
Introduction to TorchRL

This demo was presented at ICML 2022 on the industry demo day.

It gives a good overview of TorchRL functionalities. Feel free to reach out to vmoens@fb.com or submit issues if you have questions or comments about it.

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

https://github.com/pytorch/rl

The PyTorch ecosystem team (Meta) has decided to invest in this library to provide a leading platform for developing RL solutions in research settings.

It provides PyTorch- and Python-first, low- and high-level abstractions for RL that are intended to be efficient, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them or write new ones with little effort.

This repo attempts to align with the existing PyTorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), etc. TorchRL aims to have as few dependencies as possible (Python standard library, numpy and PyTorch). Common environment libraries (e.g. OpenAI gym) are optional.


Unlike other domains, RL is less about media than algorithms. As such, it is harder to make truly independent components.

What TorchRL is not:

  • a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms, but we provide these algorithms only as examples of how to use the library.

  • a research framework: modularity in TorchRL comes in two flavors. First, we try to build reusable components that can easily be swapped with each other. Second, we do our best to ensure that components can be used independently of the rest of the library.

TorchRL has very few core dependencies, predominantly PyTorch and numpy. All other dependencies (gym, torchvision, wandb / tensorboard) are optional.

Data

TensorDict

import torch
from tensordict import TensorDict

Let’s create a TensorDict.

batch_size = 5
tensordict = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 3),
        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
print(tensordict)
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

You can index a TensorDict as well as query keys.

print(tensordict[2])
print(tensordict["key 1"] is tensordict.get("key 1"))
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
True

The following shows how to stack multiple TensorDicts.

tensordict1 = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 1),
        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

tensordict2 = TensorDict(
    source={
        "key 1": torch.ones(batch_size, 1),
        "key 2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

tensordict = torch.stack([tensordict1, tensordict2], 0)
tensordict.batch_size, tensordict["key 1"]
(torch.Size([2, 5]), tensor([[[0.],
         [0.],
         [0.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.]]]))

Here are some other functionalities of TensorDict.

print(
    "view(-1): ",
    tensordict.view(-1).batch_size,
    tensordict.view(-1).get("key 1").shape,
)

print("to device: ", tensordict.to("cpu"))

# print("pin_memory: ", tensordict.pin_memory())

print("share memory: ", tensordict.share_memory_())

print(
    "permute(1, 0): ",
    tensordict.permute(1, 0).batch_size,
    tensordict.permute(1, 0).get("key 1").shape,
)

print(
    "expand: ",
    tensordict.expand(3, *tensordict.batch_size).batch_size,
    tensordict.expand(3, *tensordict.batch_size).get("key 1").shape,
)
view(-1):  torch.Size([10]) torch.Size([10, 1])
to device:  TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([2, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([2, 5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2, 5]),
    device=cpu,
    is_shared=False)
share memory:  TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([2, 5, 1]), device=cpu, dtype=torch.float32, is_shared=True),
        key 2: Tensor(shape=torch.Size([2, 5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([2, 5]),
    device=None,
    is_shared=True)
permute(1, 0):  torch.Size([5, 2]) torch.Size([5, 2, 1])
expand:  torch.Size([3, 2, 5]) torch.Size([3, 2, 5, 1])

You can create a nested TensorDict as well.

tensordict = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 3),
        "key 2": TensorDict(
            source={"sub-key 1": torch.zeros(batch_size, 2, 1)},
            batch_size=[batch_size, 2],
        ),
    },
    batch_size=[batch_size],
)
tensordict
TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: TensorDict(
            fields={
                sub-key 1: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

Replay buffers

from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer
rb = ReplayBuffer(collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
[1]
rb.extend([2, 3])
rb.sample(3)
[2, 1, 3]
rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
rb.update_priority(1, 0.5)

Here are examples of using a replay buffer with TensorDicts.

collate_fn = torch.stack
rb = ReplayBuffer(collate_fn=collate_fn)
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)
1
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
print(len(rb))
print(rb.sample(10))
print(rb.sample(2).contiguous())
3
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
tensordict_sample = rb.sample(2).contiguous()
tensordict_sample
TensorDict(
    fields={
        _weight: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        index: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
tensordict_sample["index"]
tensor([0, 0])
tensordict_sample["td_error"] = torch.rand(2)
rb.update_tensordict_priority(tensordict_sample)

for i, val in enumerate(rb._sampler._sum_tree):
    print(i, val)
    if i == len(rb):
        break

try:
    import gymnasium as gym
except ModuleNotFoundError:
    import gym
0 0.28791671991348267
1 1.0
2 0.0
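
To give some intuition for the alpha and beta arguments used above, here is a plain-Python sketch of the textbook proportional prioritization scheme: sampling probabilities are priorities raised to the power alpha and normalized, and importance weights are (N * P(i)) ** (-beta) rescaled by their maximum. This is an illustration under those assumptions, not TorchRL's sampler, whose details (sum-tree storage, numerical safeguards, weight normalization) may differ. The priority values are made up.

```python
import random


def sampling_probabilities(priorities, alpha):
    """P(i) = p_i**alpha / sum_k p_k**alpha."""
    scaled = [p**alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probabilities, beta):
    """w_i = (N * P(i))**(-beta), rescaled so that the largest weight is 1."""
    n = len(probabilities)
    raw = [(n * p) ** (-beta) for p in probabilities]
    largest = max(raw)
    return [w / largest for w in raw]


priorities = [0.5, 1.0, 2.0]  # hypothetical priorities
probs = sampling_probabilities(priorities, alpha=0.7)
weights = importance_weights(probs, beta=1.1)

# Draw indices in proportion to the sampling probabilities.
rng = random.Random(0)
indices = rng.choices(range(len(priorities)), weights=probs, k=5)
print(probs, weights, indices)
```

With alpha = 0 every item would be sampled uniformly; larger values of alpha favor high-priority items, and the weights compensate for the resulting bias.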

Envs

from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend

gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
env = GymEnv("Pendulum-v1")
tensordict = env.reset()
env.rand_step(tensordict)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

Changing environments config

env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        pixels: Tensor(shape=torch.Size([500, 500, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
env.close()
del env
from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['pixels']),
            ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels'])))
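
The arithmetic applied to the pixels entry by the last two transforms can be sketched in plain PyTorch. This is only an approximation of what the transforms do: it assumes that ToTensorImage moves the channel dimension first and rescales uint8 values to [0, 1], and that ObservationNorm in its default configuration applies the affine map obs * scale + loc. The random frame stands in for a rendered image.

```python
import torch

frame = torch.randint(0, 256, (500, 500, 3), dtype=torch.uint8)  # H x W x C

# ToTensorImage-like step: channels first, float values in [0, 1].
image = frame.permute(2, 0, 1).to(torch.float32) / 255

# ObservationNorm-like step with loc=2, scale=1 (affine form assumed).
loc, scale = 2.0, 1.0
normalized = image * scale + loc

print(image.shape, normalized.min().item(), normalized.max().item())
```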

Transforms

from torchrl.envs import (
    Compose,
    NoopResetEnv,
    ObservationNorm,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)

base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['pixels']),
            ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels'])))
env.reset()
Traceback (most recent call last):
  File "/pytorch/rl/docs/source/reference/generated/tutorials/torchrl_demo.py", line 424, in <module>
    env.reset()
  File "/pytorch/rl/torchrl/envs/common.py", line 2256, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/pytorch/rl/torchrl/envs/transforms/transforms.py", line 823, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/pytorch/rl/torchrl/envs/transforms/transforms.py", line 1138, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
  File "/pytorch/rl/torchrl/envs/transforms/transforms.py", line 4533, in _reset
    tensordict = parent.rand_step(tensordict)
  File "/pytorch/rl/torchrl/envs/common.py", line 2413, in rand_step
    return self.step(tensordict)
  File "/pytorch/rl/torchrl/envs/common.py", line 1592, in step
    next_tensordict = self._step(tensordict)
  File "/pytorch/rl/torchrl/envs/transforms/transforms.py", line 788, in _step
    next_tensordict = self.base_env._step(tensordict_in)
  File "/pytorch/rl/torchrl/envs/gym_like.py", line 308, in _step
    ) = self._output_transform(self._env.step(action))
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/core.py", line 385, in step
    return self.observation(observation), reward, terminated, truncated, info
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/wrappers/pixel_observation.py", line 174, in observation
    pixel_observation = self._add_pixel_observation(observation)
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/wrappers/pixel_observation.py", line 186, in _add_pixel_observation
    pixel_observations = {
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/wrappers/pixel_observation.py", line 187, in <dictcomp>
    pixel_key: self._render(**self._render_kwargs[pixel_key])
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/wrappers/pixel_observation.py", line 204, in _render
    render = self.env.render(*args, **kwargs)
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/core.py", line 329, in render
    return self.env.render(*args, **kwargs)
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/wrappers/order_enforcing.py", line 51, in render
    return self.env.render(*args, **kwargs)
  File "/pytorch/rl/env/lib/python3.10/site-packages/gym/envs/classic_control/pendulum.py", line 231, in render
    scale_img = pygame.transform.smoothscale(
TypeError: size must be two numbers
print("env: ", env)
print("last transform parent: ", env.transform[2].parent)

Vectorized Environments

from torchrl.envs import ParallelEnv


def make_env():
    # You can control whether to use gym or gymnasium for your env
    with set_gym_backend("gym"):
        return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)


base_env = ParallelEnv(
    4,
    make_env,
    mp_start_method="fork",  # "fork" is not available on Windows: remove this argument there and guard the script with if __name__ == "__main__"
)
env = TransformedEnv(
    base_env, Compose(StepCounter(), ToTensorImage())
)  # applies transforms on batch of envs
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
env.reset()
print(env.action_spec)

env.close()
del env

Modules

Models

Example of an MLP model:

from torch import nn
from torchrl.modules import ConvNet, MLP
from torchrl.modules.models.utils import SquashDims

net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)
print(net(torch.randn(10, 3)).shape)

Example of a CNN model:

cnn = ConvNet(
    num_cells=[32, 64],
    kernel_sizes=[8, 4],
    strides=[2, 1],
    aggregator_class=SquashDims,
)
print(cnn)
print(cnn(torch.randn(10, 3, 32, 32)).shape)  # last tensor is squashed
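
The size of the squashed output can be worked out by hand. The sketch below builds a plain torch.nn stack with the same kernel sizes and strides and checks the arithmetic; it assumes zero padding and ELU activations and uses nn.Flatten in place of SquashDims, so it is an approximation of the ConvNet above rather than its actual implementation.

```python
import torch
from torch import nn


def conv_out(size, kernel, stride, padding=0):
    # Standard output-size formula for a convolution along one spatial axis.
    return (size + 2 * padding - kernel) // stride + 1


side = conv_out(conv_out(32, kernel=8, stride=2), kernel=4, stride=1)
flat_features = 64 * side * side  # channels x height x width

plain_cnn = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=8, stride=2),
    nn.ELU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=1),
    nn.ELU(),
    nn.Flatten(),  # merges channel and spatial dimensions
)
out = plain_cnn(torch.randn(10, 3, 32, 32))
print(side, flat_features, out.shape)
```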

TensorDictModules

from tensordict.nn import TensorDictModule

tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
td_module = TensorDictModule(module, in_keys=["key 1"], out_keys=["key 2"])
td_module(tensordict)
print(tensordict)

Sequences of Modules

from tensordict.nn import TensorDictSequential

backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(
    backbone_module, in_keys=["observation"], out_keys=["hidden"]
)
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])

sequence = TensorDictSequential(backbone, actor, value)
print(sequence)
print(sequence.in_keys, sequence.out_keys)
tensordict = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
backbone(tensordict)
actor(tensordict)
value(tensordict)
tensordict = TensorDict(
    {"observation": torch.randn(3, 5)},
    [3],
)
sequence(tensordict)
print(tensordict)

Functional Programming (Ensembling / Meta-RL)

from tensordict import TensorDict

params = TensorDict.from_module(sequence)
print("extracted params", params)

Functional call using a TensorDict of parameters:

with params.to_module(sequence):
    sequence(tensordict)

Using vectorized map for model ensembling

from torch import vmap

params_expand = params.expand(4)


def exec_sequence(params, data):
    with params.to_module(sequence):
        return sequence(data)


tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict)
print(tensordict_exp)
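
For readers more familiar with plain PyTorch, the same ensembling idea can be sketched with torch.func, without TensorDict. This is an analogue for illustration, not the mechanism used above: here the four members have distinct parameters (the demo expands a single parameter set), and the model is a bare nn.Linear chosen for brevity.

```python
import torch
from torch import nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)
models = [nn.Linear(3, 4) for _ in range(4)]

# Each parameter gains a leading dimension of size 4 (one slice per model).
params, _buffers = stack_module_state(models)

base = models[0]  # only used as a stateless "skeleton"


def call_one(single_params, x):
    # Run the skeleton with one member's parameters.
    return functional_call(base, single_params, (x,))


x = torch.randn(10, 3)
# Map over the parameter dimension, share the same input across members.
ensemble_out = torch.vmap(call_one, in_dims=(0, None))(params, x)
print(ensemble_out.shape)
```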

Specialized Classes

torch.manual_seed(0)
from torchrl.data import Bounded
from torchrl.modules import SafeModule

spec = Bounded(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
module = SafeModule(
    module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
)
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(tensordict)["action"]
tensordict = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
module(tensordict)["action"]  # safe=True projects the result within the set
from torchrl.modules import Actor

base_module = nn.Linear(5, 3)
actor = Actor(base_module, in_keys=["obs"])
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(tensordict)  # action is the default value

from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
# Probabilistic modules
from torchrl.modules import NormalParamExtractor, TanhNormal

td = TensorDict({"input": torch.randn(3, 5)}, [3])
net = nn.Sequential(
    nn.Linear(5, 4), NormalParamExtractor()
)  # splits the output in loc and scale
module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=False,
    ),
)
td_module(td)
print(td)
# returning the log-probability
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictSequential(
    module,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)
td_module(td)
print(td)
# Sampling vs mode / mean
from torchrl.envs.utils import ExplorationType, set_exploration_type

td = TensorDict({"input": torch.randn(3, 5)}, [3])

torch.manual_seed(0)
with set_exploration_type(ExplorationType.RANDOM):
    td_module(td)
    print("random:", td["action"])

with set_exploration_type(ExplorationType.DETERMINISTIC):
    td_module(td)
    print("mode:", td["action"])
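
A tanh-squashed Gaussian can also be sketched with torch.distributions alone, which may help to see what sampling and log-probabilities look like. This is a plain-PyTorch analogue, not TorchRL's TanhNormal, which has its own options and its own definition of the deterministic value; tanh(loc) is shown only as one simple deterministic choice, and the loc and scale values are placeholders.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)
loc = torch.zeros(3, 2)  # placeholder parameters: batch of 3, action dim 2
scale = torch.ones(3, 2)

dist = TransformedDistribution(Normal(loc, scale), [TanhTransform(cache_size=1)])

action = dist.rsample()  # stochastic action, analogous to ExplorationType.RANDOM
log_prob = dist.log_prob(action).sum(-1)  # sum over the action dimension

squashed_loc = torch.tanh(loc)  # one simple deterministic alternative
print(action.shape, log_prob.shape, squashed_loc.shape)
```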

Using Environments and Modules

from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")

action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = SafeModule(
    actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
)

torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
tensordict = env.reset()
tensordicts = TensorDict({}, [max_steps])
for i in range(max_steps):
    actor(tensordict)
    tensordicts[i] = env.step(tensordict)
    if tensordict["done"].any():
        break
    tensordict = step_mdp(tensordict)  # roughly equivalent to obs = next_obs

tensordicts_prealloc = tensordicts.clone()
print("total steps:", i)
print(tensordicts)
# equivalent
torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
tensordict = env.reset()
tensordicts = []
for i in range(max_steps):
    actor(tensordict)
    tensordicts.append(env.step(tensordict))
    if tensordict["done"].any():
        break
    tensordict = step_mdp(tensordict)  # roughly equivalent to obs = next_obs
tensordicts_stack = torch.stack(tensordicts, 0)
print("total steps:", i)
print(tensordicts_stack)
(tensordicts_stack == tensordicts_prealloc).all()
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout


(tensordict_rollout == tensordicts_prealloc).all()

from tensordict.nn import TensorDictModule

Collectors

from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector

from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.libs.gym import GymEnv

EnvCreator makes sure that we can send a lambda function from process to process. We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited.

parallel_env = SerialEnv(
    3,
    EnvCreator(lambda: GymEnv("Pendulum-v1")),
)
create_env_fn = [parallel_env, parallel_env]

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

Sync data collector

devices = ["cpu", "cpu"]

collector = MultiSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
# async data collector: keeps working while you update your model
collector = MultiaSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    device=devices,
)

for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
collector.shutdown()
del collector
del create_env_fn
del parallel_env
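
The numbers passed to the collectors determine how many batches are produced and how each batch is laid out. The arithmetic behind the "[6 workers x 10 steps]" comment can be checked with a few lines of plain Python; the variable names are chosen here for illustration.

```python
total_frames = 240
frames_per_batch = 60
num_sub_collectors = 2  # length of create_env_fn
envs_per_sub_collector = 3  # SerialEnv(3, ...)

num_batches = total_frames // frames_per_batch
num_envs = num_sub_collectors * envs_per_sub_collector
steps_per_env = frames_per_batch // num_envs

print(num_batches, num_envs, steps_per_env)
```

Under these settings one would expect the loop index printed after each collection loop to be num_batches - 1.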

Objectives

# TorchRL delivers meta-RL compatible loss functions
# Disclaimer: This API may change in the future
from torchrl.objectives import DDPGLoss

actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])


class ConcatModule(nn.Linear):
    def forward(self, obs, action):
        return super().forward(torch.cat([obs, action], -1))


value_module = ConcatModule(4, 1)
value = TensorDictModule(
    value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
)

loss_fn = DDPGLoss(actor, value)
loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)
tensordict = TensorDict(
    {
        "observation": torch.randn(10, 3),
        "next": {
            "observation": torch.randn(10, 3),
            "reward": torch.randn(10, 1),
            "done": torch.zeros(10, 1, dtype=torch.bool),
        },
        "action": torch.randn(10, 1),
    },
    batch_size=[10],
    device="cpu",
)
loss_td = loss_fn(tensordict)
print(loss_td)
print(tensordict)
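
As a reminder of what a one-step value target looks like, here is a plain-PyTorch sketch of the textbook TD(0) target and a squared-error value loss. It is an illustration under stated assumptions, not the internals of DDPGLoss: the random tensors stand in for network outputs, and the choice of a squared-error loss is made here for simplicity.

```python
import torch

torch.manual_seed(0)
gamma = 0.99
reward = torch.randn(10, 1)
done = torch.zeros(10, 1, dtype=torch.bool)
done[-1] = True  # pretend the last transition ends an episode

next_value = torch.randn(10, 1)  # stand-in for Q(s', pi(s'))
current_value = torch.randn(10, 1, requires_grad=True)  # stand-in for Q(s, a)

# Bootstrapped target: no bootstrapping after a terminal transition.
td_target = reward + gamma * (~done).float() * next_value

# The target is treated as a constant when differentiating the value loss.
value_loss = torch.nn.functional.mse_loss(current_value, td_target.detach())
value_loss.backward()
print(td_target.shape, value_loss.item())
```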

State of the Library

TorchRL is currently an alpha release: there may be bugs, and there is no guarantee against BC-breaking changes. We should be able to move to a beta release by the end of the year. Our roadmap to get there comprises:

  • Distributed solutions

  • Offline RL

  • Greater support for meta-RL

  • Multi-task and hierarchical RL

Contributing

We are actively looking for contributors and early users. If you’re working in RL (or just curious), try it! Give us feedback: the success of TorchRL will depend on how well it covers researchers’ needs, and to get there we need their input. Since the library is nascent, it is a great time for you to shape it the way you want!

Installing the Library

The library is on PyPI: pip install torchrl
