Get started with your own first training loop

Author: Vincent Moens


To run this tutorial in a notebook, add an installation cell at the beginning containing:

!pip install tensordict
!pip install torchrl

Time to wrap up everything we’ve learned so far in this Getting Started series!

In this tutorial, we will be writing the most basic training loop there is using only components we have presented in the previous lessons.

We’ll be using DQN with a CartPole environment as a prototypical example.

We will deliberately keep the commentary to a minimum, only linking each section to the related tutorial.

Building the environment

We’ll be using a gym environment with a StepCounter transform. If you need a refresher, these features are presented in the environment tutorial.

import torch

torch.manual_seed(0)  # fix the seed for reproducibility

import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
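
To make the role of the step counter concrete, here is a minimal pure-Python sketch of the idea: a wrapper that adds a `step_count` entry to each transition and resets it when an episode starts. `ToyEnv` and `StepCountingEnv` are hypothetical names used for illustration only; this is not how TorchRL implements `StepCounter`, just the bookkeeping it performs.

```python
class ToyEnv:
    """Hypothetical environment whose episodes end after `horizon` steps."""

    def __init__(self, horizon=3):
        self.horizon = horizon
        self._t = 0

    def reset(self):
        self._t = 0
        return {"observation": 0.0}

    def step(self, action):
        self._t += 1
        done = self._t >= self.horizon
        return {"observation": float(self._t), "reward": 1.0, "done": done}


class StepCountingEnv:
    """Wraps an env and adds a `step_count` entry to every transition."""

    def __init__(self, env):
        self.env = env
        self._count = 0

    def reset(self):
        # A new episode starts: the counter goes back to zero.
        self._count = 0
        out = self.env.reset()
        out["step_count"] = 0
        return out

    def step(self, action):
        out = self.env.step(action)
        self._count += 1
        out["step_count"] = self._count
        return out


env_sketch = StepCountingEnv(ToyEnv(horizon=3))
first = env_sketch.reset()
counts = []
done = False
while not done:
    transition = env_sketch.step(action=0)
    counts.append(transition["step_count"])
    done = transition["done"]
# `counts` now holds the step index of each transition in the episode.
```

In the training loop, the largest `step_count` found in the replay buffer is the quantity used as the stopping criterion.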

Designing a policy

The next step is to build our policy. We’ll be making a regular, deterministic version of the actor to be used within the loss module and during evaluation. Next, we will augment it with an exploration module for inference.

from torchrl.modules import EGreedyModule, MLP, QValueModule

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(env.action_spec))
exploration_module = EGreedyModule(
    env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
policy_explore = Seq(policy, exploration_module)
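
As an illustration of what the exploration module does, the sketch below implements epsilon-greedy selection with a linearly annealed epsilon in plain Python. The helper names and the final value `eps_end=0.1` are assumptions made for this example (check the `EGreedyModule` documentation for the actual defaults); it mirrors the idea, not TorchRL's implementation.

```python
import random


def annealed_eps(step, eps_init=0.5, eps_end=0.1, annealing_num_steps=100_000):
    """Linearly decay epsilon from eps_init to eps_end, then hold it constant."""
    fraction = min(step / annealing_num_steps, 1.0)
    return eps_init + fraction * (eps_end - eps_init)


def epsilon_greedy(action_values, eps, rng):
    """Pick a random action with probability eps, otherwise the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(action_values))
    return max(range(len(action_values)), key=lambda a: action_values[a])


rng = random.Random(0)
q = [0.2, 1.3]  # hypothetical action values for the two CartPole actions
greedy_action = epsilon_greedy(q, eps=0.0, rng=rng)  # eps=0: always greedy
eps_start = annealed_eps(0)
eps_half = annealed_eps(50_000)
eps_final = annealed_eps(200_000)  # past the annealing horizon
```

Setting `eps=0.0` recovers the deterministic policy, which is why the non-exploring `policy` is the one handed to the loss module and used for evaluation.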

Data Collector and replay buffer

Here comes the data part: we need a data collector to easily get batches of data and a replay buffer to store that data for training.

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer

init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
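collector = SyncDataCollector(
    env,
    policy_explore,
    frames_per_batch=frames_per_batch,
    total_frames=-1,  # collect indefinitely; the training loop decides when to stop
    init_random_frames=init_rand_steps,  # random actions while the buffer warms up
)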
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam
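
The storage side can be pictured as a fixed-capacity ring buffer with uniform sampling. The following pure-Python sketch uses a hypothetical `RingReplayBuffer` class to show that behaviour; it is a simplified stand-in under that assumption, not the `ReplayBuffer` or `LazyTensorStorage` implementation.

```python
import random


class RingReplayBuffer:
    """Hypothetical fixed-capacity buffer: oldest entries are overwritten first."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self._items = []
        self._cursor = 0  # next position to write to
        self._rng = random.Random(seed)

    def __len__(self):
        return len(self._items)

    def extend(self, transitions):
        for transition in transitions:
            if len(self._items) < self.capacity:
                self._items.append(transition)
            else:
                self._items[self._cursor] = transition
            self._cursor = (self._cursor + 1) % self.capacity

    def sample(self, batch_size):
        # Uniform sampling with replacement over the stored transitions.
        return [self._rng.choice(self._items) for _ in range(batch_size)]


buffer = RingReplayBuffer(capacity=5)
buffer.extend(range(3))     # stores 0, 1, 2
buffer.extend(range(3, 8))  # stores 3, 4, then overwrites 0, 1, 2 with 5, 6, 7
stored = sorted(buffer._items)
batch = buffer.sample(4)
```

The same two calls, `extend` to write a collected batch and `sample` to draw a training batch, are the only buffer operations the training loop needs.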

Loss module and optimizer

We build our loss as indicated in the dedicated tutorial, with its optimizer and target parameter updater:

from torchrl.objectives import DQNLoss, SoftUpdate

loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
optim = Adam(loss.parameters(), lr=0.02)
updater = SoftUpdate(loss, eps=0.99)
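
For intuition about the pieces above, the sketch below writes out a one-step DQN target and a soft (Polyak) target update on plain Python values. The function names and the discount `gamma=0.99` are assumptions for illustration, and reading `eps` as the weight kept on the old target parameters is our interpretation of `SoftUpdate(loss, eps=0.99)`; consult the TorchRL documentation for the authoritative definitions.

```python
def td_target(reward, done, next_action_values, gamma=0.99):
    """One-step target: r + gamma * max_a Q_target(s', a), no bootstrap at episode end."""
    bootstrap = 0.0 if done else gamma * max(next_action_values)
    return reward + bootstrap


def soft_update(target_params, online_params, eps=0.99):
    """Move each target parameter a fraction (1 - eps) toward the online value."""
    return [eps * t + (1.0 - eps) * o for t, o in zip(target_params, online_params)]


# Hypothetical target-network values for the next state.
y_mid_episode = td_target(reward=1.0, done=False, next_action_values=[0.5, 2.0])
y_terminal = td_target(reward=1.0, done=True, next_action_values=[0.5, 2.0])

# Repeated soft updates make the target parameters trail the online ones.
target = [0.0, 0.0]
online = [1.0, -1.0]
for _ in range(100):
    target = soft_update(target, online)
```

With `eps=0.99`, the target network only closes 1% of the gap to the online network at each update, which keeps the regression targets slowly moving.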


Logger

We’ll be using a CSV logger to log our results and save rendered videos.

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)

Training loop

Instead of fixing a specific number of iterations to run, we will keep training the network until it reaches a target performance, arbitrarily defined as a trajectory of more than 200 steps in the environment (with CartPole, longer trajectories mean better performance).

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max()
    if len(rb) > init_rand_steps:
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128)
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            exploration_module.step(data.numel())
            # Update target params
            updater.step()
    # Log every 10 collected batches
    if i % 10 == 0:
        torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
    # Count frames and episodes once per collected batch
    total_count += data.numel()
    total_episodes += data["next", "done"].sum()
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)
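
A quick back-of-the-envelope check of the schedule implied by the hyperparameters above, assuming each collected batch holds exactly `frames_per_batch` transitions. The derived names (`first_training_batch`, `replay_ratio`, `batches_until_full`) are introduced here for illustration only.

```python
init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
batch_size = 128  # sample size used in rb.sample(128)
buffer_capacity = 100_000  # size given to LazyTensorStorage

# Optimization starts at the first batch for which len(rb) > init_rand_steps.
first_training_batch = init_rand_steps // frames_per_batch + 1

# Each collected batch triggers `optim_steps` gradient updates.
samples_per_collected_batch = optim_steps * batch_size
replay_ratio = samples_per_collected_batch / frames_per_batch

# Number of collected batches before the buffer starts overwriting old data.
batches_until_full = buffer_capacity // frames_per_batch
```

Under these assumptions, the first 50 batches only fill the buffer, and from the 51st batch onward about 12.8 transitions are sampled for training for every new transition collected.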


Rendering

Finally, we run the environment for as many steps as we can and save the video locally (notice that we are not exploring).

record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

This is what your rendered CartPole video will look like after a full training loop:


This concludes our series of “Getting started with TorchRL” tutorials! Feel free to share feedback about it on GitHub.

Total running time of the script: (0 minutes 21.525 seconds)

Estimated memory usage: 12 MB
