.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/torchrl_demo.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_torchrl_demo.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_torchrl_demo.py:

Introduction to TorchRL
=======================

This demo was presented at ICML 2022 on the industry demo day.

.. GENERATED FROM PYTHON SOURCE LINES 7-186

It gives a good overview of TorchRL's functionalities. Feel free to reach out
to vmoens@fb.com or submit issues if you have questions or comments about it.

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

https://github.com/pytorch/rl

The PyTorch ecosystem team (Meta) has decided to invest in that library to
provide a leading platform to develop RL solutions in research settings.

It provides PyTorch- and **python-first**, low- and high-level **abstractions**
for RL that are intended to be efficient, documented and properly tested.
The code is aimed at supporting research in RL. Most of it is written in Python
in a highly modular way, such that researchers can easily swap components,
transform them or write new ones with little effort.

This repo attempts to align with the existing PyTorch ecosystem libraries in
that it has a dataset pillar (torchrl/envs), transforms, models, data utilities
(e.g. collectors and containers), etc.

TorchRL aims at having as few dependencies as possible (python standard
library, numpy and pytorch). Common environment libraries (e.g. OpenAI gym)
are only optional.

**Content**:

.. aafig::

    "torchrl"
    │
    ├── "collectors"
    │   └── "collectors.py"
    │   │
    │   └── "distributed"
    │       └── "default_configs.py"
    │       └── "generic.py"
    │       └── "ray.py"
    │       └── "rpc.py"
    │       └── "sync.py"
    ├── "data"
    │   │
    │   ├── "datasets"
    │   │   └── "atari_dqn.py"
    │   │   └── "d4rl.py"
    │   │   └── "d4rl_infos.py"
    │   │   └── "gen_dgrl.py"
    │   │   └── "minari_data.py"
    │   │   └── "openml.py"
    │   │   └── "openx.py"
    │   │   └── "roboset.py"
    │   │   └── "vd4rl.py"
    │   ├── "postprocs"
    │   │   └── "postprocs.py"
    │   ├── "replay_buffers"
    │   │   └── "replay_buffers.py"
    │   │   └── "samplers.py"
    │   │   └── "storages.py"
    │   │   └── "writers.py"
    │   ├── "rlhf"
    │   │   └── "dataset.py"
    │   │   └── "prompt.py"
    │   │   └── "reward.py"
    │   └── "tensor_specs.py"
    ├── "envs"
    │   └── "batched_envs.py"
    │   └── "common.py"
    │   └── "env_creator.py"
    │   └── "gym_like.py"
    │   ├── "libs"
    │   │   └── "brax.py"
    │   │   └── "dm_control.py"
    │   │   └── "envpool.py"
    │   │   └── "gym.py"
    │   │   └── "habitat.py"
    │   │   └── "isaacgym.py"
    │   │   └── "jumanji.py"
    │   │   └── "openml.py"
    │   │   └── "pettingzoo.py"
    │   │   └── "robohive.py"
    │   │   └── "smacv2.py"
    │   │   └── "vmas.py"
    │   ├── "model_based"
    │   │   └── "common.py"
    │   │   └── "dreamer.py"
    │   ├── "transforms"
    │   │   └── "functional.py"
    │   │   └── "gym_transforms.py"
    │   │   └── "r3m.py"
    │   │   └── "rlhf.py"
    │   │   └── "vc1.py"
    │   │   └── "vip.py"
    │   └── "vec_envs.py"
    ├── "modules"
    │   ├── "distributions"
    │   │   └── "continuous.py"
    │   │   └── "discrete.py"
    │   │   └── "truncated_normal.py"
    │   ├── "models"
    │   │   └── "decision_transformer.py"
    │   │   └── "exploration.py"
    │   │   └── "model_based.py"
    │   │   └── "models.py"
    │   │   └── "multiagent.py"
    │   │   └── "rlhf.py"
    │   ├── "planners"
    │   │   └── "cem.py"
    │   │   └── "common.py"
    │   │   └── "mppi.py"
    │   └── "tensordict_module"
    │       └── "actors.py"
    │       └── "common.py"
    │       └── "exploration.py"
    │       └── "probabilistic.py"
    │       └── "rnn.py"
    │       └── "sequence.py"
    │       └── "world_models.py"
    ├── "objectives"
    │   └── "a2c.py"
    │   └── "common.py"
    │   └── "cql.py"
    │   └── "ddpg.py"
    │   └── "decision_transformer.py"
    │   └── "deprecated.py"
    │   └── "dqn.py"
    │   └── "dreamer.py"
    │   └── "functional.py"
    │   └── "iql.py"
    │   ├── "multiagent"
    │   │   └── "qmixer.py"
    │   └── "ppo.py"
    │   └── "redq.py"
    │   └── "reinforce.py"
    │   └── "sac.py"
    │   └── "td3.py"
    │   ├── "value"
    │   │   └── "advantages.py"
    │   │   └── "functional.py"
    │   │   └── "pg.py"
    ├── "record"
    │   ├── "loggers"
    │   │   └── "common.py"
    │   │   └── "csv.py"
    │   │   └── "mlflow.py"
    │   │   └── "tensorboard.py"
    │   │   └── "wandb.py"
    │   └── "recorder.py"
    ├── "trainers"
    │   │
    │   ├── "helpers"
    │   │   └── "collectors.py"
    │   │   └── "envs.py"
    │   │   └── "logger.py"
    │   │   └── "losses.py"
    │   │   └── "models.py"
    │   │   └── "replay_buffer.py"
    │   │   └── "trainers.py"
    │   └── "trainers.py"
    └── "version.py"

Unlike other domains, RL is less about media than *algorithms*. As such, it is
harder to make truly independent components.

What TorchRL is not:

* a collection of algorithms: we do not intend to provide SOTA implementations
  of RL algorithms, but we provide these algorithms only as examples of how to
  use the library.

* a research framework: modularity in TorchRL comes in two flavors. First, we
  try to build re-usable components, such that they can be easily swapped with
  each other. Second, we do our best to ensure that components can be used
  independently of the rest of the library.

TorchRL has very few core dependencies, predominantly PyTorch and numpy. All
other dependencies (gym, torchvision, wandb / tensorboard) are optional.

Data
----

TensorDict
~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 186-191

.. code-block:: Python

    import torch
    from tensordict import TensorDict

.. GENERATED FROM PYTHON SOURCE LINES 214-216

Let's create a TensorDict. The constructor accepts many different formats,
such as passing a dict or using keyword arguments:

.. GENERATED FROM PYTHON SOURCE LINES 216-225

.. code-block:: Python

    batch_size = 5
    data = TensorDict(
        key1=torch.zeros(batch_size, 3),
        key2=torch.zeros(batch_size, 5, 6, dtype=torch.bool),
        batch_size=[batch_size],
    )
    print(data)

.. GENERATED FROM PYTHON SOURCE LINES 226-227

You can index a TensorDict along its ``batch_size``, as well as query its keys.

.. GENERATED FROM PYTHON SOURCE LINES 227-231

.. code-block:: Python

    print(data[2])
    print(data["key1"] is data.get("key1"))

.. GENERATED FROM PYTHON SOURCE LINES 232-233

The following shows how to stack multiple TensorDicts. This is particularly
useful when writing rollout loops!

.. GENERATED FROM PYTHON SOURCE LINES 233-253

.. code-block:: Python

    data1 = TensorDict(
        {
            "key1": torch.zeros(batch_size, 1),
            "key2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
        },
        batch_size=[batch_size],
    )

    data2 = TensorDict(
        {
            "key1": torch.ones(batch_size, 1),
            "key2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
        },
        batch_size=[batch_size],
    )

    data = torch.stack([data1, data2], 0)
    data.batch_size, data["key1"]

.. GENERATED FROM PYTHON SOURCE LINES 254-255

Here are some other functionalities of TensorDict: viewing, permuting, sharing
memory or expanding.

.. GENERATED FROM PYTHON SOURCE LINES 255-280

.. code-block:: Python

    print(
        "view(-1): ",
        data.view(-1).batch_size,
        data.view(-1).get("key1").shape,
    )

    print("to device: ", data.to("cpu"))

    # print("pin_memory: ", data.pin_memory())

    print("share memory: ", data.share_memory_())

    print(
        "permute(1, 0): ",
        data.permute(1, 0).batch_size,
        data.permute(1, 0).get("key1").shape,
    )

    print(
        "expand: ",
        data.expand(3, *data.batch_size).batch_size,
        data.expand(3, *data.batch_size).get("key1").shape,
    )

.. GENERATED FROM PYTHON SOURCE LINES 281-282

You can create **nested data** as well.

.. GENERATED FROM PYTHON SOURCE LINES 282-295

.. code-block:: Python

    data = TensorDict(
        source={
            "key1": torch.zeros(batch_size, 3),
            "key2": TensorDict(
                source={"sub_key1": torch.zeros(batch_size, 2, 1)},
                batch_size=[batch_size, 2],
            ),
        },
        batch_size=[batch_size],
    )
    data

.. GENERATED FROM PYTHON SOURCE LINES 296-302

Replay buffers
--------------

:ref:`Replay buffers <ref_buffers>` are a crucial component in many RL
algorithms. TorchRL provides a range of replay buffer implementations. Most
basic features will work with any data structure (list, tuple, dict), but to
use the replay buffers to their full extent and with fast read and write
access, TensorDict APIs should be preferred.

.. GENERATED FROM PYTHON SOURCE LINES 302-307

.. code-block:: Python

    from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

    rb = ReplayBuffer(collate_fn=lambda x: x)

.. GENERATED FROM PYTHON SOURCE LINES 308-310

Adding can be done with :meth:`~torchrl.data.ReplayBuffer.add` (n=1) or
:meth:`~torchrl.data.ReplayBuffer.extend` (n>1).

.. GENERATED FROM PYTHON SOURCE LINES 310-315

.. code-block:: Python

    rb.add(1)
    rb.sample(1)
    rb.extend([2, 3])
    rb.sample(3)

.. GENERATED FROM PYTHON SOURCE LINES 316-318

Prioritized Replay Buffers can also be used:

.. GENERATED FROM PYTHON SOURCE LINES 318-324

.. code-block:: Python

    rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
    rb.add(1)
    rb.sample(1)
    rb.update_priority(1, 0.5)

.. GENERATED FROM PYTHON SOURCE LINES 325-327

Here are examples of using a replay buffer with TensorDict data (stacked with
``torch.stack``). Using TensorDicts makes it easy to abstract away the
behaviour of the replay buffer for multiple use cases.

.. GENERATED FROM PYTHON SOURCE LINES 327-356

.. code-block:: Python

    collate_fn = torch.stack
    rb = ReplayBuffer(collate_fn=collate_fn)
    rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
    len(rb)

    rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
    print(len(rb))
    print(rb.sample(10))
    print(rb.sample(2).contiguous())

    torch.manual_seed(0)
    from torchrl.data import TensorDictPrioritizedReplayBuffer

    rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error")
    rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
    data_sample = rb.sample(2).contiguous()
    print(data_sample)

    print(data_sample["index"])

    data_sample["td_error"] = torch.rand(2)
    rb.update_tensordict_priority(data_sample)

    for i, val in enumerate(rb._sampler._sum_tree):
        print(i, val)
        if i == len(rb):
            break

.. GENERATED FROM PYTHON SOURCE LINES 357-363

Envs
----

TorchRL provides a range of :ref:`environment <Environment-API>` wrappers and
utilities.

Gym Environment
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 363-378

.. code-block:: Python

    try:
        import gymnasium as gym
    except ModuleNotFoundError:
        import gym

    from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend

    gym_env = gym.make("Pendulum-v1")
    env = GymWrapper(gym_env)

    env = GymEnv("Pendulum-v1")
    data = env.reset()
    env.rand_step(data)

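
A quick way to sanity-check an environment is to ask it for a short random
rollout. The snippet below is a minimal sketch, assuming the env's
:meth:`~torchrl.envs.EnvBase.rollout` method, which samples random actions when
no policy is passed:

.. code-block:: Python

    # Sketch: a 3-step rollout with random actions, returned as a TensorDict
    # whose batch_size is the number of steps actually taken.
    rollout = env.rollout(max_steps=3)
    print(rollout)
    print(rollout.batch_size)  # e.g. torch.Size([3])
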
.. GENERATED FROM PYTHON SOURCE LINES 379-382

Changing environments config
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 382-401

.. code-block:: Python

    env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
    env.reset()
    env.close()
    del env

    from torchrl.envs import (
        Compose,
        NoopResetEnv,
        ObservationNorm,
        ToTensorImage,
        TransformedEnv,
    )

    base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
    env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
    env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

.. GENERATED FROM PYTHON SOURCE LINES 402-407

Environment Transforms
~~~~~~~~~~~~~~~~~~~~~~

Transforms act like Gym wrappers, but with an API closer to torchvision's
``transforms``. There is a wide range of :ref:`transforms <transforms>` to
choose from.

.. GENERATED FROM PYTHON SOURCE LINES 407-426

.. code-block:: Python

    from torchrl.envs import (
        Compose,
        NoopResetEnv,
        ObservationNorm,
        StepCounter,
        ToTensorImage,
        TransformedEnv,
    )

    base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
    env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
    env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
    env.reset()

    print("env: ", env)
    print("last transform parent: ", env.transform[2].parent)

.. GENERATED FROM PYTHON SOURCE LINES 427-432

Vectorized Environments
~~~~~~~~~~~~~~~~~~~~~~~

Vectorized / parallel environments can provide some significant speed-ups.

.. GENERATED FROM PYTHON SOURCE LINES 432-458

.. code-block:: Python

    from torchrl.envs import ParallelEnv


    def make_env():
        # You can control whether to use gym or gymnasium for your env
        with set_gym_backend("gym"):
            return GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)


    base_env = ParallelEnv(
        4,
        make_env,
        # "fork" will break on Windows machines! Remove it there and wrap the
        # script entry point in ``if __name__ == "__main__":``
        mp_start_method="fork",
    )
    env = TransformedEnv(
        base_env, Compose(StepCounter(), ToTensorImage())
    )  # applies transforms on batch of envs
    env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))
    env.reset()

    print(env.action_spec)

    env.close()
    del env

.. GENERATED FROM PYTHON SOURCE LINES 459-468

Modules
-------

Multiple :ref:`modules <ref_modules>` (utils, models and wrappers) can be found
in the library.

Models
~~~~~~

Example of an MLP model:

.. GENERATED FROM PYTHON SOURCE LINES 468-477

.. code-block:: Python

    from torch import nn
    from torchrl.modules import ConvNet, MLP
    from torchrl.modules.models.utils import SquashDims

    net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
    print(net)
    print(net(torch.randn(10, 3)).shape)

.. GENERATED FROM PYTHON SOURCE LINES 478-480

Example of a CNN model:

.. GENERATED FROM PYTHON SOURCE LINES 480-491

.. code-block:: Python

    cnn = ConvNet(
        num_cells=[32, 64],
        kernel_sizes=[8, 4],
        strides=[2, 1],
        aggregator_class=SquashDims,
    )
    print(cnn)
    print(cnn(torch.randn(10, 3, 32, 32)).shape)  # last tensor is squashed

.. GENERATED FROM PYTHON SOURCE LINES 492-497

TensorDictModules
~~~~~~~~~~~~~~~~~

:ref:`Some modules <tdmodules>` are specifically designed to work with
tensordict inputs.

.. GENERATED FROM PYTHON SOURCE LINES 497-506

.. code-block:: Python

    from tensordict.nn import TensorDictModule

    data = TensorDict({"key1": torch.randn(10, 3)}, batch_size=[10])
    module = nn.Linear(3, 4)
    td_module = TensorDictModule(module, in_keys=["key1"], out_keys=["key2"])
    td_module(data)
    print(data)

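
Any ``nn.Module`` can be wrapped this way. As a minimal sketch (the ``td_mlp``
name and key choices below are just illustrative), the TorchRL ``MLP`` shown
above can be exposed through tensordict keys in the same manner:

.. code-block:: Python

    # Sketch: wrap a TorchRL MLP so it reads "observation" and writes "hidden".
    mlp = MLP(in_features=3, num_cells=[16], out_features=4)
    td_mlp = TensorDictModule(mlp, in_keys=["observation"], out_keys=["hidden"])
    td = TensorDict({"observation": torch.randn(10, 3)}, batch_size=[10])
    td_mlp(td)
    print(td["hidden"].shape)  # torch.Size([10, 4])
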
.. GENERATED FROM PYTHON SOURCE LINES 507-512

Sequences of Modules
~~~~~~~~~~~~~~~~~~~~

Making sequences of modules is made easy by
:class:`~tensordict.nn.TensorDictSequential`:

.. GENERATED FROM PYTHON SOURCE LINES 512-544

.. code-block:: Python

    from tensordict.nn import TensorDictSequential

    backbone_module = nn.Linear(5, 3)
    backbone = TensorDictModule(
        backbone_module, in_keys=["observation"], out_keys=["hidden"]
    )
    actor_module = nn.Linear(3, 4)
    actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
    value_module = MLP(out_features=1, num_cells=[4, 5])
    value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])

    sequence = TensorDictSequential(backbone, actor, value)
    print(sequence)

    print(sequence.in_keys, sequence.out_keys)

    data = TensorDict(
        {"observation": torch.randn(3, 5)},
        [3],
    )
    backbone(data)
    actor(data)
    value(data)

    data = TensorDict(
        {"observation": torch.randn(3, 5)},
        [3],
    )
    sequence(data)
    print(data)

.. GENERATED FROM PYTHON SOURCE LINES 545-550

Functional Programming (Ensembling / Meta-RL)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Functional calls have never been easier. Extract the parameters with
:func:`~tensordict.from_module`, and replace them with
:meth:`~tensordict.TensorDict.to_module`:

.. GENERATED FROM PYTHON SOURCE LINES 550-556

.. code-block:: Python

    from tensordict import from_module

    params = from_module(sequence)
    print("extracted params", params)

.. GENERATED FROM PYTHON SOURCE LINES 557-558

Functional call using a tensordict:

.. GENERATED FROM PYTHON SOURCE LINES 558-562

.. code-block:: Python

    with params.to_module(sequence):
        data = sequence(data)

.. GENERATED FROM PYTHON SOURCE LINES 563-569

VMAP
~~~~

Fast execution of multiple copies of a similar architecture is key to training
your models fast. :func:`~torch.vmap` is tailored to do just that:

.. GENERATED FROM PYTHON SOURCE LINES 569-583

.. code-block:: Python

    from torch import vmap

    params_expand = params.expand(4)


    def exec_sequence(params, data):
        with params.to_module(sequence):
            return sequence(data)


    tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, data)
    print(tensordict_exp)

.. GENERATED FROM PYTHON SOURCE LINES 584-588

Specialized Classes
~~~~~~~~~~~~~~~~~~~

TorchRL also provides some specialized modules that run checks on the output
values.

.. GENERATED FROM PYTHON SOURCE LINES 588-604

.. code-block:: Python

    torch.manual_seed(0)
    from torchrl.data import Bounded
    from torchrl.modules import SafeModule

    spec = Bounded(-torch.ones(3), torch.ones(3))
    base_module = nn.Linear(5, 3)
    module = SafeModule(
        module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True
    )
    data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
    module(data)["action"]

    data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
    module(data)["action"]  # safe=True projects the result within the set

.. GENERATED FROM PYTHON SOURCE LINES 605-607

The :class:`~torchrl.modules.Actor` class has a predefined output key
(``"action"``):

.. GENERATED FROM PYTHON SOURCE LINES 607-620

.. code-block:: Python

    from torchrl.modules import Actor

    base_module = nn.Linear(5, 3)
    actor = Actor(base_module, in_keys=["obs"])
    data = TensorDict({"obs": torch.randn(5)}, batch_size=[])
    actor(data)  # action is the default value

    from tensordict.nn import (
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
    )

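
Going back to the ``SafeModule`` defined in the Specialized Classes block, a
quick sanity check (a sketch, assuming the spec's ``is_in`` method) confirms
that ``safe=True`` keeps the output within the spec's bounds even for large
inputs:

.. code-block:: Python

    # Sketch: with safe=True, out-of-bound outputs are projected back into the spec.
    data = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[])
    print(spec.is_in(module(data)["action"]))  # expected: True
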
.. GENERATED FROM PYTHON SOURCE LINES 621-623

Working with probabilistic models is also made easy thanks to the
``tensordict.nn`` API:

.. GENERATED FROM PYTHON SOURCE LINES 623-642

.. code-block:: Python

    from torchrl.modules import NormalParamExtractor, TanhNormal

    td = TensorDict({"input": torch.randn(3, 5)}, [3])
    net = nn.Sequential(
        nn.Linear(5, 4), NormalParamExtractor()
    )  # splits the output in loc and scale
    module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"])
    td_module = ProbabilisticTensorDictSequential(
        module,
        ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=TanhNormal,
            return_log_prob=False,
        ),
    )
    td_module(td)
    print(td)

.. GENERATED FROM PYTHON SOURCE LINES 643-658

.. code-block:: Python

    # returning the log-probability
    td = TensorDict({"input": torch.randn(3, 5)}, [3])
    td_module = ProbabilisticTensorDictSequential(
        module,
        ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=TanhNormal,
            return_log_prob=True,
        ),
    )
    td_module(td)
    print(td)

.. GENERATED FROM PYTHON SOURCE LINES 659-662

Controlling randomness and sampling strategies is achieved via a context
manager, :class:`~torchrl.envs.set_exploration_type`:

.. GENERATED FROM PYTHON SOURCE LINES 662-675

.. code-block:: Python

    from torchrl.envs.utils import ExplorationType, set_exploration_type

    td = TensorDict({"input": torch.randn(3, 5)}, [3])

    torch.manual_seed(0)
    with set_exploration_type(ExplorationType.RANDOM):
        td_module(td)
        print("random:", td["action"])

    with set_exploration_type(ExplorationType.DETERMINISTIC):
        td_module(td)
        print("mode:", td["action"])

.. GENERATED FROM PYTHON SOURCE LINES 676-680

Using Environments and Modules
------------------------------

Let us see how environments and modules can be combined:

.. GENERATED FROM PYTHON SOURCE LINES 680-708

.. code-block:: Python

    from torchrl.envs.utils import step_mdp

    env = GymEnv("Pendulum-v1")

    action_spec = env.action_spec
    actor_module = nn.Linear(3, 1)
    actor = SafeModule(
        actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"]
    )

    torch.manual_seed(0)
    env.set_seed(0)

    max_steps = 100
    data = env.reset()
    data_stack = TensorDict(batch_size=[max_steps])
    for i in range(max_steps):
        actor(data)
        data_stack[i] = env.step(data)
        if data["done"].any():
            break
        data = step_mdp(data)  # roughly equivalent to obs = next_obs

    tensordicts_prealloc = data_stack.clone()
    print("total steps:", i)
    print(data_stack)

.. GENERATED FROM PYTHON SOURCE LINES 709-727

.. code-block:: Python

    # equivalent
    torch.manual_seed(0)
    env.set_seed(0)

    max_steps = 100
    data = env.reset()
    data_stack = []
    for i in range(max_steps):
        actor(data)
        data_stack.append(env.step(data))
        if data["done"].any():
            break
        data = step_mdp(data)  # roughly equivalent to obs = next_obs

    tensordicts_stack = torch.stack(data_stack, 0)
    print("total steps:", i)
    print(tensordicts_stack)

.. GENERATED FROM PYTHON SOURCE LINES 728-731

.. code-block:: Python

    (tensordicts_stack == tensordicts_prealloc).all()

.. GENERATED FROM PYTHON SOURCE LINES 732-743

.. code-block:: Python

    torch.manual_seed(0)
    env.set_seed(0)
    tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
    tensordict_rollout

    (tensordict_rollout == tensordicts_prealloc).all()

    from tensordict.nn import TensorDictModule

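
The rollout output is itself a TensorDict laid out along the time dimension.
As a small sketch (assuming the standard ``("next", "reward")`` layout used by
TorchRL rollouts), per-step rewards can be read from the nested entry:

.. code-block:: Python

    # Sketch: the rollout has batch_size [T]; rewards live under ("next", "reward").
    print(tensordict_rollout.batch_size)
    print(tensordict_rollout["next", "reward"].sum())  # undiscounted return of the trajectory
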
.. GENERATED FROM PYTHON SOURCE LINES 744-749

Collectors
----------

We also provide a set of :ref:`data collectors <ref_collectors>` that
automatically gather as many frames per batch as required. They work in
settings ranging from single-node, single-worker to multi-node, multi-worker.

.. GENERATED FROM PYTHON SOURCE LINES 749-755

.. code-block:: Python

    from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector
    from torchrl.envs import EnvCreator, SerialEnv
    from torchrl.envs.libs.gym import GymEnv

.. GENERATED FROM PYTHON SOURCE LINES 756-762

EnvCreator makes sure that we can send a lambda function from process to
process. We use a :class:`~torchrl.envs.SerialEnv` for simplicity (single
worker), but for larger jobs a :class:`~torchrl.envs.ParallelEnv`
(multi-worker) would be better suited.

.. note:: Multiprocessed envs and multiprocessed collectors can be combined!

.. GENERATED FROM PYTHON SOURCE LINES 762-772

.. code-block:: Python

    parallel_env = SerialEnv(
        3,
        EnvCreator(lambda: GymEnv("Pendulum-v1")),
    )
    create_env_fn = [parallel_env, parallel_env]

    actor_module = nn.Linear(3, 1)
    actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

.. GENERATED FROM PYTHON SOURCE LINES 773-776

Sync multiprocessed data collector
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 776-788

.. code-block:: Python

    devices = ["cpu", "cpu"]

    collector = MultiSyncDataCollector(
        create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
        policy=actor,
        total_frames=240,
        max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
        frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
        device=devices,
    )

.. GENERATED FROM PYTHON SOURCE LINES 789-798

.. code-block:: Python

    for i, d in enumerate(collector):
        if i == 0:
            print(d)  # trajectories are split automatically in [6 workers x 10 steps]
        collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
    print(i)
    collector.shutdown()
    del collector

.. GENERATED FROM PYTHON SOURCE LINES 799-805

Async multiprocessed data collector
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This class allows you to collect data while the model is training. This is
particularly useful in off-policy settings as it decouples the inference from
the model training. Data is delivered on a first-ready-first-served basis
(workers will queue their results):

.. GENERATED FROM PYTHON SOURCE LINES 805-825

.. code-block:: Python

    collector = MultiaSyncDataCollector(
        create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
        policy=actor,
        total_frames=240,
        max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early
        frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
        device=devices,
    )
    for i, d in enumerate(collector):
        if i == 0:
            print(d)  # trajectories are split automatically in [6 workers x 10 steps]
        collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
    print(i)
    collector.shutdown()
    del collector
    del create_env_fn
    del parallel_env

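
For quick, single-process experiments, :class:`~torchrl.collectors.SyncDataCollector`
exposes the same iteration interface without spawning workers. The snippet
below is a minimal sketch reusing the ``actor`` defined above (frame counts are
arbitrary):

.. code-block:: Python

    # Sketch: a single-process collector with the same iteration interface.
    from torchrl.collectors import SyncDataCollector

    collector = SyncDataCollector(
        lambda: GymEnv("Pendulum-v1"),
        policy=actor,
        frames_per_batch=50,
        total_frames=200,
    )
    for batch in collector:
        print(batch.batch_size)  # 50 frames per batch
        break
    collector.shutdown()
    del collector
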
.. GENERATED FROM PYTHON SOURCE LINES 826-830

Objectives
----------

:ref:`Objectives <ref_objectives>` are the main entry points when coding up a
new algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 830-850

.. code-block:: Python

    from torchrl.objectives import DDPGLoss

    actor_module = nn.Linear(3, 1)
    actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])


    class ConcatModule(nn.Linear):
        def forward(self, obs, action):
            return super().forward(torch.cat([obs, action], -1))


    value_module = ConcatModule(4, 1)
    value = TensorDictModule(
        value_module, in_keys=["observation", "action"], out_keys=["state_action_value"]
    )

    loss_fn = DDPGLoss(actor, value)
    loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99)

.. GENERATED FROM PYTHON SOURCE LINES 851-871

.. code-block:: Python

    data = TensorDict(
        {
            "observation": torch.randn(10, 3),
            "next": {
                "observation": torch.randn(10, 3),
                "reward": torch.randn(10, 1),
                "done": torch.zeros(10, 1, dtype=torch.bool),
            },
            "action": torch.randn(10, 1),
        },
        batch_size=[10],
        device="cpu",
    )
    loss_td = loss_fn(data)

    print(loss_td)

    print(data)

.. GENERATED FROM PYTHON SOURCE LINES 872-888

Installing the Library
----------------------

The library is on PyPI: *pip install torchrl*

See the `README <https://github.com/pytorch/rl/blob/main/README.md>`_ for more
information.

Contributing
------------

We are actively looking for contributors and early users. If you're working in
RL (or just curious), try it! Give us feedback: what will make TorchRL
successful is how well it covers researchers' needs. To do that, we need their
input! Since the library is nascent, it is a great time for you to shape it the
way you want!

See the `Contributing guide <https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md>`_
for more info.


.. _sphx_glr_download_tutorials_torchrl_demo.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torchrl_demo.ipynb <torchrl_demo.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torchrl_demo.py <torchrl_demo.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: torchrl_demo.zip <torchrl_demo.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_