torchrl.envs package

TorchRL offers an API to handle environments of different backends, such as gym, dm-control, dm-lab, model-based environments as well as custom environments. The goal is to be able to swap environments in an experiment with little or no effort, even if these environments are simulated using different libraries. TorchRL offers some out-of-the-box environment wrappers under torchrl.envs.libs, which we hope can be easily imitated for other libraries. The parent class EnvBase is a torch.nn.Module subclass that implements some typical environment methods using tensordict.TensorDict as a data organiser. This allows this class to be generic and to handle an arbitrary number of inputs and outputs, as well as nested or batched data structures.

Each env will have the following attributes:

  • env.batch_size: a torch.Size representing the number of envs batched together.

  • env.device: the device where the input and output tensordict are expected to live. The environment device does not mean that the actual step operations will be computed on device (this is the responsibility of the backend, with which TorchRL can do little). The device of an environment just represents the device where the data is to be expected when input to the environment or retrieved from it. TorchRL takes care of mapping the data to the desired device. This is especially useful for transforms (see below). For parametric environments (e.g. model-based environments), the device does represent the hardware that will be used to compute the operations.

  • env.observation_spec: a CompositeSpec object containing all the observation key-spec pairs.

  • env.state_spec: a CompositeSpec object containing all the input key-spec pairs (except action). For most stateful environments, this container will be empty.

  • env.action_spec: a TensorSpec object representing the action spec.

  • env.reward_spec: a TensorSpec object representing the reward spec.

  • env.done_spec: a TensorSpec object representing the done-flag spec.

  • env.input_spec: a CompositeSpec object containing all the input keys ("_action_spec" and "_state_spec"). It is locked and should not be modified directly.

  • env.output_spec: a CompositeSpec object containing all the output keys ("_observation_spec", "_reward_spec" and "_done_spec"). It is locked and should not be modified directly.

Importantly, the environment spec shapes should contain the batch size, e.g. an environment with env.batch_size == torch.Size([4]) should have an env.action_spec with shape torch.Size([4, action_size]). This is helpful when preallocating tensors, checking shape consistency, etc.
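This rule can be illustrated with plain tuples standing in for torch.Size and spec shapes (the shapes compose the same way):

```python
# Sketch of the batch-size / spec-shape consistency rule, with plain
# tuples standing in for torch.Size and TensorSpec shapes.
batch_size = (4,)            # env.batch_size
action_size = 6
action_spec_shape = (4, 6)   # env.action_spec.shape starts with the batch size

# The leading dimensions of every spec shape must match the batch size:
assert action_spec_shape[: len(batch_size)] == batch_size
print("specs are consistent with the batch size")
```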

With these, the following methods are implemented:

  • env.reset(): a reset method that may (but does not have to) take a tensordict.TensorDict input. It returns the first tensordict of a rollout, usually containing a "done" state and a set of observations. If not present, a “reward” key will be instantiated with 0s and the appropriate shape.

  • env.step(): a step method that takes a tensordict.TensorDict input containing an input action as well as other inputs (for model-based or stateless environments, for instance).

  • env.set_seed(): a seeding method that will return the next seed to be used in a multi-env setting. This next seed is deterministically computed from the preceding one, so that multiple environments can be seeded with different seeds without the risk of overlapping seeds across consecutive experiments, while still producing reproducible results.
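The idea behind seed chaining can be sketched in plain Python (the derivation TorchRL actually uses differs; next_seed below is a hypothetical stand-in):

```python
# Sketch of deterministic seed chaining: each environment's seed is a
# deterministic function of the previous one, so seeding N envs from a
# single root seed is reproducible.
# Illustration only, not TorchRL's actual derivation.

def next_seed(seed: int) -> int:
    # A simple LCG step for illustration; any deterministic,
    # well-spread mapping works.
    return (seed * 6364136223846793005 + 1442695040888963407) % 2**63

seeds = [42]                 # root seed passed to env.set_seed()
for _ in range(3):           # chain seeds for three more environments
    seeds.append(next_seed(seeds[-1]))

print(seeds)                 # four distinct, reproducible seeds
```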

  • env.rollout(): executes a rollout in the environment for a maximum number of steps (max_steps=N) and using a policy (policy=model). The policy should be coded using a tensordict.nn.TensorDictModule (or any other tensordict.TensorDict-compatible module). The resulting tensordict.TensorDict instance will be marked with a trailing "time" named dimension that can be used by other modules to handle this batched dimension appropriately.

The following figure summarizes how a rollout is executed in torchrl.


TorchRL rollouts using TensorDict.

In brief, a TensorDict is created by the reset() method, then populated with an action by the policy before being passed to the step() method which writes the observations, done flag and reward under the "next" entry. The result of this call is stored for delivery and the "next" entry is gathered by the step_mdp() function.
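This data flow can be sketched with plain dictionaries standing in for TensorDicts (env_reset, env_step, policy and the toy termination rule below are hypothetical stand-ins, not TorchRL's implementations):

```python
# Sketch of the rollout data flow: reset -> policy -> step -> step_mdp.
# Plain dicts stand in for TensorDicts; all functions are toy stand-ins.

def env_reset():
    return {"observation": 0.0, "done": False}

def policy(td):
    td["action"] = 1.0                      # a constant toy policy
    return td

def env_step(td):
    next_obs = td["observation"] + td["action"]
    # The step writes observations, reward and done under the "next" entry:
    td["next"] = {"observation": next_obs, "reward": 1.0, "done": next_obs >= 3.0}
    return td

def step_mdp(td):
    # Gathers the "next" entry to build the root of the following step.
    return {"observation": td["next"]["observation"], "done": td["next"]["done"]}

rollout = []
td = env_reset()
while True:
    td = env_step(policy(td))
    rollout.append(td)                      # stored for delivery
    if td["next"]["done"]:
        break
    td = step_mdp(td)

print(len(rollout))  # → 3: the toy env terminates once the observation reaches 3
```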


The Gym(nasium) API recently shifted to splitting the "done" state into a terminated flag (the env is done and results should not be trusted) and a truncated flag (the maximum number of steps is reached). In TorchRL, "done" usually refers to "terminated". Truncation is achieved via the StepCounter transform class, and the output key will be "truncated" unless chosen to be something else (e.g. StepCounter(max_steps=100, truncated_key="done")). TorchRL’s collectors and rollout methods will look for one of these keys when assessing whether the env should be reset.
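The logic can be sketched in plain Python (a toy version of what StepCounter does, not the actual implementation):

```python
# Sketch of StepCounter-style truncation: after max_steps, "truncated"
# is set and the env should be reset. Toy logic, not TorchRL's code.
max_steps = 3
step_count = 0
terminated = False           # the env itself never terminates in this toy run

truncated_flags = []
for _ in range(5):
    step_count += 1
    truncated = step_count >= max_steps
    done = terminated or truncated    # collectors reset on either flag
    truncated_flags.append(truncated)
    if done:
        step_count = 0                # a reset clears the counter

print(truncated_flags)  # → [False, False, True, False, False]
```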


The torchrl.collectors.utils.split_trajectories function can be used to slice adjacent trajectories. It relies on a "traj_ids" entry in the input tensordict or, if "traj_ids" is missing, on the conjunction of the "done" and "truncated" keys.
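The splitting logic can be sketched in plain Python (an illustration of the idea, not the torchrl.collectors.utils implementation):

```python
# Sketch of trajectory splitting from a flat sequence of steps, using the
# end-of-trajectory flag ("done" or "truncated") as the cut point.
# Illustration only, not the torchrl.collectors.utils implementation.

steps = [
    {"t": 0, "done": False}, {"t": 1, "done": True},    # trajectory 0
    {"t": 0, "done": False}, {"t": 1, "done": False},
    {"t": 2, "done": True},                             # trajectory 1
]

trajectories, current = [], []
for step in steps:
    current.append(step)
    if step["done"]:
        trajectories.append(current)
        current = []
if current:                   # keep a trailing, unterminated trajectory
    trajectories.append(current)

print([len(t) for t in trajectories])  # → [2, 3]
```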


In some contexts, it can be useful to mark the first step of a trajectory. TorchRL provides such functionality through the InitTracker transform.

Our environment tutorial provides more information on how to design a custom environment from scratch.

EnvBase(*args[, _inplace_update, _batch_locked])

Abstract environment parent class.

GymLikeEnv(*args, **kwargs)

A gym-like env is an environment.

EnvMetaData(tensordict, specs, batch_size, ...)

A class for environment meta-data storage and passing in multiprocessed settings.

Vectorized envs

Vectorized (or better: parallel) environments are a common feature in Reinforcement Learning, where executing the environment step can be CPU-intensive. Some libraries such as gym3 or EnvPool offer interfaces to execute batches of environments simultaneously. While they often offer a very competitive computational advantage, they do not necessarily scale to the wide variety of environment libraries supported by TorchRL. Therefore, TorchRL offers its own, generic ParallelEnv class to run multiple environments in parallel. As this class inherits from SerialEnv, it enjoys the exact same API as other environments. Of course, a ParallelEnv will have a batch size that corresponds to its environment count:

It is important that your environment specs match the inputs and outputs that it sends and receives, as ParallelEnv will create buffers from these specs to communicate with the spawned processes. Check the check_env_specs() method for a sanity check.

Parallel environment
     >>> def make_env():
     ...     return GymEnv("Pendulum-v1", from_pixels=True, g=9.81, device="cuda:0")
     >>> env = make_env()
     >>> check_env_specs(env)  # this must pass for ParallelEnv to work
     >>> env = ParallelEnv(4, make_env)
     >>> print(env.batch_size)
     torch.Size([4])

ParallelEnv makes it possible to retrieve the attributes from its contained environments: one can simply call:

Parallel environment attributes
     >>> a, b, c, d = env.g  # gets the g-force of the various envs, which we set to 9.81 before
     >>> print(a)

It is also possible to reset some but not all of the environments:

Parallel environment reset
     >>> tensordict = TensorDict({"reset_workers": [True, False, True, True]}, [4])
     >>> env.reset(tensordict)
     TensorDict(
         fields={
             done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
             pixels: Tensor(torch.Size([4, 500, 500, 3]), dtype=torch.uint8),
             reset_workers: Tensor(torch.Size([4]), dtype=torch.bool)},
         batch_size=torch.Size([4]))


A note on performance: launching a ParallelEnv can take quite some time, as it requires launching as many python instances as there are processes. Due to the time it takes to run import torch (and other imports), starting the parallel env can be a bottleneck. This is why, for instance, TorchRL tests are so slow. Once the environment is launched, a great speedup should be observed.


TorchRL requires precise specs: Another thing to take into consideration is that ParallelEnv (as well as data collectors) will create data buffers based on the environment specs to pass data from one process to another. This means that a misspecified spec (input, observation or reward) will cause a failure at runtime, as the data cannot be written to the preallocated buffer. In general, an environment should be tested using the check_env_specs() test function before being used in a ParallelEnv. This function will raise an assertion error whenever the preallocated buffer and the collected data mismatch.

We also offer the SerialEnv class that enjoys the exact same API but is executed serially. This is mostly useful for testing purposes, when one wants to assess the behaviour of a ParallelEnv without launching the subprocesses.

In addition to ParallelEnv, which offers process-based parallelism, we also provide a way to create multithreaded environments with MultiThreadedEnv. This class uses the EnvPool library underneath, which allows for higher performance but at the same time restricts flexibility: one can only create environments implemented in EnvPool. This covers many popular RL environment types (Atari, Classic Control, etc.), but one cannot use an arbitrary TorchRL environment, as is possible with ParallelEnv. Run the scripts in benchmarks/ to compare the performance of different ways to parallelize batched environments.

SerialEnv(*args[, _inplace_update, ...])

Creates a series of environments in the same process. Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely.

ParallelEnv(*args[, _inplace_update, ...])

Creates one environment per process.

MultiThreadedEnv(*args[, _inplace_update, ...])

Multithreaded execution of environments based on EnvPool.

EnvCreator(create_env_fn[, ...])

Environment creator class.


In most cases, the raw output of an environment must be treated before being passed to another object (such as a policy or a value operator). To do this, TorchRL provides a set of transforms that aim at reproducing the transform logic of torch.distributions.Transform and torchvision.transforms. Our environment tutorial provides more information on how to design a custom transform.

Transformed environments are built using the TransformedEnv primitive. Composed transforms are built using the Compose class:

Transformed environment
     >>> base_env = GymEnv("Pendulum-v1", from_pixels=True, device="cuda:0")
     >>> transform = Compose(ToTensorImage(in_keys=["pixels"]), Resize(64, 64, in_keys=["pixels"]))
     >>> env = TransformedEnv(base_env, transform)

By default, the transformed environment will inherit the device of the base_env that is passed to it. The transforms will then be executed on that device. It is now apparent that this can bring a significant speedup depending on the kind of operations that are to be computed.

A great advantage of environment wrappers is that one can consult the environment up to that wrapper. The same can be achieved with TorchRL transformed environments: the parent attribute will return a new TransformedEnv with all the transforms up to the transform of interest. Re-using the example above:

Transform parent
     >>> resize_parent = env.transform[-1].parent  # returns the same as TransformedEnv(base_env, transform[:-1])

Transformed environments can be used with vectorized environments. Since each transform uses an "in_keys"/"out_keys" set of keyword arguments, it is also easy to route the transform graph to each component of the observation data (e.g. pixels or states, etc.).

Transforms also have an inv method that is called before the action is applied to the environment, executed in reverse order over the composed transform chain: this makes it possible to apply transforms to data in the input tensordict before the action is taken in the environment. The keys to be included in this inverse transform are passed through the "in_keys_inv" keyword argument:

Inverse transform
     >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step
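The ordering can be sketched with two invertible toy transforms (AddOne and TimesTwo are hypothetical stand-ins, not TorchRL classes): forward transforms run first-to-last on outputs, while inv runs last-to-first on inputs such as the action.

```python
# Sketch of forward vs inverse ordering in a composed transform chain.
# AddOne/TimesTwo are toy stand-ins, not TorchRL transforms.

class AddOne:
    def forward(self, x): return x + 1
    def inv(self, x): return x - 1

class TimesTwo:
    def forward(self, x): return x * 2
    def inv(self, x): return x / 2

chain = [AddOne(), TimesTwo()]

def apply_forward(x):
    # Observations flow through the chain first-to-last.
    for t in chain:
        x = t.forward(x)
    return x

def apply_inv(x):
    # Actions flow through the chain last-to-first, undoing each step.
    for t in reversed(chain):
        x = t.inv(x)
    return x

print(apply_forward(3))             # → 8, i.e. (3 + 1) * 2
print(apply_inv(apply_forward(3)))  # → 3.0, the round trip is the identity
```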

Cloning transforms

Because transforms appended to an environment are “registered” to this environment through the transform.parent property, when manipulating transforms we should keep in mind that the parent may come and go depending on what is being done with the transform. Here are some examples: if we get a single transform from a Compose object, this transform will keep its parent:

>>> third_transform = env.transform[2]
>>> assert third_transform.parent is not None

This means that using this transform for another environment is prohibited, as the other environment would replace the parent and this may lead to unexpected behaviours. Fortunately, the Transform class comes with a clone() method that will erase the parent while keeping the identity of all the registered buffers:

>>> TransformedEnv(base_env, third_transform)  # raises an Exception as third_transform already has a parent
>>> TransformedEnv(base_env, third_transform.clone())  # works

On a single process or if the buffers are placed in shared memory, this will result in all the cloned transforms keeping the same behaviour even if the buffers are changed in place (which is what will happen with the CatFrames transform, for instance). In distributed settings, this may not hold and one should be careful about the expected behaviour of the cloned transforms in this context. Finally, notice that indexing multiple transforms from a Compose transform may also result in loss of parenthood for these transforms: the reason is that indexing a Compose transform results in another Compose transform that does not have a parent environment. Hence, we have to clone the sub-transforms to be able to create this other composition:

>>> env = TransformedEnv(base_env, Compose(transform1, transform2, transform3))
>>> last_two = env.transform[-2:]
>>> assert isinstance(last_two, Compose)
>>> assert last_two.parent is None
>>> assert last_two[0] is not transform2
>>> assert isinstance(last_two[0], type(transform2))  # and the buffers will match
>>> assert last_two[1] is not transform3
>>> assert isinstance(last_two[1], type(transform3))  # and the buffers will match

Transform(in_keys[, out_keys, in_keys_inv, ...])

Environment transform parent class.

TransformedEnv(*args[, _inplace_update, ...])

A transformed environment.

BinarizeReward([in_keys, out_keys])

Maps the reward to a binary value: 0 if the reward is null, 1 otherwise.

CatFrames(N, dim[, in_keys, out_keys, padding])

Concatenates successive observation frames into a single tensor.

CatTensors([in_keys, out_key, dim, ...])

Concatenates several keys in a single tensor.

CenterCrop(w[, h, in_keys, out_keys])

Crops the center of an image.


Compose(*transforms)

Composes a chain of transforms.

DiscreteActionProjection(...[, action_key, ...])

Projects discrete actions from a high dimensional space to a low dimensional space.

DoubleToFloat([in_keys, in_keys_inv])

Maps double-precision entries to float; entries listed in in_keys_inv (such as the action) are mapped from float back to double before being passed to the environment.


ExcludeTransform(*excluded_keys)

Excludes keys from the input tensordict.


FiniteTensorDictCheck()

This transform checks that all items of the tensordict are finite, and raises an exception if they are not.

FlattenObservation(first_dim, last_dim[, ...])

Flatten adjacent dimensions of a tensor.


FrameSkipTransform([frame_skip])

A frame-skip transform.

GrayScale([in_keys, out_keys])

Turns a pixel observation to grayscale.

gSDENoise([state_dim, action_dim, shape])

A gSDE noise initializer.


InitTracker([init_key])

Reset tracker.

KLRewardTransform(actor[, coef, in_keys, ...])

A transform to add a KL[pi_current||pi_0] correction term to the reward.

NoopResetEnv([noops, random])

Runs a series of random actions when an environment is reset.

ObservationNorm([loc, scale, in_keys, ...])

Observation affine transformation layer.

ObservationTransform([in_keys, out_keys, ...])

Abstract class for transformations of the observations.


PinMemoryTransform()

Calls pin_memory on the tensordict to facilitate writing on CUDA devices.

R3MTransform(*args, **kwargs)

R3M Transform class.

RandomCropTensorDict(sub_seq_len[, ...])

A trajectory sub-sampler for ReplayBuffer and modules.

RenameTransform(in_keys, out_keys[, ...])

A transform to rename entries in the output tensordict.

Resize(w, h[, interpolation, in_keys, out_keys])

Resizes a pixel observation.

RewardClipping([clamp_min, clamp_max, ...])

Clips the reward between clamp_min and clamp_max.

RewardScaling(loc, scale[, in_keys, ...])

Affine transform of the reward.

RewardSum([in_keys, out_keys])

Tracks episode cumulative rewards.

Reward2GoTransform([gamma, in_keys, out_keys])

Calculates the reward to go based on the episode reward and a discount factor.


SelectTransform(*selected_keys)

Selects keys from the input tensordict.

SqueezeTransform(*args, **kwargs)

Removes a dimension of size one at the specified position.

StepCounter([max_steps, truncated_key])

Counts the steps from a reset and sets the done state to True after a certain number of steps.

TargetReturn(target_return[, mode, in_keys, ...])

Sets a target return for the agent to achieve in the environment.

TensorDictPrimer([primers, random, ...])

A primer for TensorDict initialization at reset time.

TimeMaxPool([in_keys, out_keys, T])

Takes the maximum value in each position over the last T observations.

ToTensorImage([from_int, unsqueeze, dtype, ...])

Transforms a numpy-like image (W x H x C) to a pytorch image (C x W x H).

UnsqueezeTransform(*args, **kwargs)

Inserts a dimension of size one at the specified position.

VecNorm([in_keys, shared_td, lock, decay, ...])

Moving average normalization layer for torchrl environments.

VIPRewardTransform(*args, **kwargs)

A VIP transform to compute rewards based on embedded similarity.

VIPTransform(*args, **kwargs)

VIP Transform class.


Recorders are transforms that register data as they come in, for logging purposes.

TensorDictRecorder(out_file_base[, ...])

TensorDict recorder.

VideoRecorder(logger, tag[, in_keys, skip, ...])

Video Recorder transform.


step_mdp(tensordict[, next_tensordict, ...])

Creates a new tensordict that reflects a step in time of the input tensordict.


Returns all the supported libraries.


set_exploration_mode

alias of set_interaction_mode


set_exploration_type

alias of set_interaction_type


exploration_mode()

Deprecated. Returns the current sampling mode.


exploration_type()

Returns the current sampling type.

check_env_specs(env[, return_contiguous, ...])

Tests an environment's specs against the results of a short rollout.


make_composite_from_td(data)

Creates a CompositeSpec instance from a tensordict, assuming all values are unbounded.


ModelBasedEnvBase(*args, **kwargs)

Basic environment for Model Based RL algorithms.

model_based.dreamer.DreamerEnv(*args, **kwargs)

Dreamer simulation environment.


TorchRL’s mission is to make the training of control and decision algorithms as easy as possible, irrespective of the simulator being used (if any). Multiple wrappers are available for DMControl, Habitat, Jumanji and, naturally, for Gym.

This last library has a special status in the RL community as the most widely used framework for coding simulators. Its successful API has been foundational and has inspired many other frameworks, among which TorchRL. However, Gym has gone through multiple design changes, and it is sometimes hard to accommodate these as an external library: users usually have their “preferred” version of the library. Moreover, gym is now maintained by another group under the “gymnasium” name, which does not facilitate code compatibility. In practice, we must consider that users may have both gym and gymnasium installed in the same virtual environment, and we must allow both to work concurrently. Fortunately, TorchRL provides a solution for this problem: a special decorator, set_gym_backend, makes it possible to control which library will be used in the relevant functions:

>>> from torchrl.envs.libs.gym import GymEnv, set_gym_backend, gym_backend
>>> import gymnasium, gym
>>> with set_gym_backend(gymnasium):
...     print(gym_backend())
...     env1 = GymEnv("Pendulum-v1")
<module 'gymnasium' from '/path/to/venv/python3.9/site-packages/gymnasium/'>
>>> with set_gym_backend(gym):
...     print(gym_backend())
...     env2 = GymEnv("Pendulum-v1")
<module 'gym' from '/path/to/venv/python3.9/site-packages/gym/'>
>>> print(env1._env.env.env)
<gymnasium.envs.classic_control.pendulum.PendulumEnv at 0x15147e190>
>>> print(env2._env.env.env)
<gym.envs.classic_control.pendulum.PendulumEnv at 0x1629916a0>

We can see that the two libraries modify the value returned by gym_backend(), which can be further used to indicate which library needs to be used for the current computation. set_gym_backend is also a decorator: we can use it to tell a specific function which gym backend needs to be used during its execution. The torchrl.envs.libs.gym.gym_backend() function allows you to gather the current gym backend or any of its modules:

>>> import mo_gymnasium
>>> with set_gym_backend("gym"):
...     wrappers = gym_backend('wrappers')
...     print(wrappers)
<module 'gym.wrappers' from '/path/to/venv/python3.9/site-packages/gym/wrappers/'>
>>> with set_gym_backend("gymnasium"):
...     wrappers = gym_backend('wrappers')
...     print(wrappers)
<module 'gymnasium.wrappers' from '/path/to/venv/python3.9/site-packages/gymnasium/wrappers/'>

Another tool that comes in handy with gym and other external dependencies is the torchrl._utils.implement_for class. Decorating a function with @implement_for will tell torchrl that, depending on the version indicated, a specific behaviour is to be expected. This allows us to easily support multiple versions of gym without requiring any effort from the user. For example, assuming that our virtual environment has v0.26.2 installed, the following function will return 1 when queried:

>>> from torchrl._utils import implement_for
>>> @implement_for("gym", None, "0.26.0")
... def fun():
...     return 0
>>> @implement_for("gym", "0.26.0", None)
... def fun():
...     return 1
>>> fun()
1

brax.BraxEnv(*args[, _inplace_update, ...])

Google Brax environment wrapper.

brax.BraxWrapper(*args[, _inplace_update, ...])

Google Brax environment wrapper.

dm_control.DMControlEnv(*args, **kwargs)

DeepMind Control lab environment wrapper.

dm_control.DMControlWrapper(*args, **kwargs)

DeepMind Control lab environment wrapper.

gym.GymEnv(*args, **kwargs)

OpenAI Gym environment wrapper.

gym.GymWrapper(*args, **kwargs)

OpenAI Gym environment wrapper.

gym.MOGymEnv(*args, **kwargs)

FARAMA MO-Gymnasium environment wrapper.

gym.MOGymWrapper(*args, **kwargs)

FARAMA MO-Gymnasium environment wrapper.


set_gym_backend(backend)

Sets the gym backend to a certain value.


gym_backend([submodule])

Returns the gym backend, or a submodule of it.

habitat.HabitatEnv(*args, **kwargs)

A wrapper for habitat envs.

jumanji.JumanjiEnv(*args, **kwargs)

Jumanji environment wrapper.

jumanji.JumanjiWrapper(*args, **kwargs)

Jumanji environment wrapper.

openml.OpenMLEnv(*args[, _inplace_update, ...])

An environment interface to OpenML data to be used in bandits contexts.

vmas.VmasEnv(*args[, _inplace_update, ...])

Vmas environment wrapper.

vmas.VmasWrapper(*args[, _inplace_update, ...])

Vmas environment wrapper.

