.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/coding_ddpg.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_coding_ddpg.py:

TorchRL objectives: Coding a DDPG loss
======================================

**Author**: `Vincent Moens `_

.. _coding_ddpg:

.. GENERATED FROM PYTHON SOURCE LINES 12-67

Overview
--------

TorchRL separates the training of RL algorithms into various pieces that will be
assembled in your training script: the environment, the data collection and
storage, the model and finally the loss function.

TorchRL losses (or "objectives") are stateful objects that contain the trainable
parameters (policy and value models). This tutorial will guide you through the
steps to code a loss from the ground up using TorchRL. To this aim, we will be
focusing on DDPG, which is a relatively straightforward algorithm to code.
`Deep Deterministic Policy Gradient `_ (DDPG) is a simple continuous control
algorithm. It consists of learning a parametric value function for an
action-observation pair, and then learning a policy that outputs actions that
maximize this value function given a certain observation.

What you will learn:

- how to write a loss module and customize its value estimator;
- how to build an environment in TorchRL, including transforms (for example,
  data normalization) and parallel execution;
- how to design a policy and value network;
- how to collect data from your environment efficiently and store them in a
  replay buffer;
- how to store trajectories (and not transitions) in your replay buffer;
- how to evaluate your model.

Prerequisites
~~~~~~~~~~~~~

This tutorial assumes that you have completed the `PPO tutorial `_ which gives
an overview of the TorchRL components and dependencies, such as
:class:`tensordict.TensorDict` and :class:`tensordict.nn.TensorDictModule`,
although it should be sufficiently transparent to be understood without a deep
understanding of these classes.

.. note::
  We do not aim to give a SOTA implementation of the algorithm, but rather to
  provide a high-level illustration of TorchRL's loss implementations and the
  library features that are to be used in the context of this algorithm.

Imports and setup
-----------------

.. code-block:: bash

   %%bash
   pip3 install torchrl mujoco glfw

.. GENERATED FROM PYTHON SOURCE LINES 67-74

.. code-block:: Python

    import multiprocessing  # used below to check the start method

    import torch
    import tqdm

.. GENERATED FROM PYTHON SOURCE LINES 95-96

We will execute the policy on CUDA if available

.. GENERATED FROM PYTHON SOURCE LINES 96-104

.. code-block:: Python

    is_fork = multiprocessing.get_start_method() == "fork"
    device = (
        torch.device(0)
        if torch.cuda.is_available() and not is_fork
        else torch.device("cpu")
    )
    collector_device = torch.device("cpu")  # Change the device to ``cuda`` to use CUDA

.. GENERATED FROM PYTHON SOURCE LINES 105-191

TorchRL :class:`~torchrl.objectives.LossModule`
-----------------------------------------------

TorchRL provides a series of losses to use in your training scripts.
The aim is to have losses that are easily reusable/swappable and that have
a simple signature.

The main characteristics of TorchRL losses are:

- They are stateful objects: they contain a copy of the trainable parameters
  such that ``loss_module.parameters()`` gives whatever is needed to train the
  algorithm.
- They follow the ``TensorDict`` convention: the :meth:`torch.nn.Module.forward`
  method will receive a TensorDict as input that contains all the necessary
  information to return a loss value.

      >>> data = replay_buffer.sample()
      >>> loss_dict = loss_module(data)

- They output a :class:`tensordict.TensorDict` instance with the loss values
  written under a ``"loss_<smth>"`` entry, where ``smth`` is a string describing
  the loss. Additional keys in the ``TensorDict`` may be useful metrics to log
  during training time.

.. note::
  The reason we return independent losses is to let the user use a different
  optimizer for different sets of parameters, for instance. Summing the losses
  can simply be done via

      >>> loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith("loss_"))

The ``__init__`` method
~~~~~~~~~~~~~~~~~~~~~~~

The parent class of all losses is :class:`~torchrl.objectives.LossModule`.
Like many other components of the library, its
:meth:`~torchrl.objectives.LossModule.forward` method expects as input a
:class:`tensordict.TensorDict` instance sampled from an experience replay
buffer, or any similar data structure. Using this format makes it possible to
re-use the module across modalities, or in complex settings where the model
needs to read multiple entries, for instance. In other words, it allows us to
code a loss module that is oblivious to the data type that is being given to it
and that focuses on running the elementary steps of the loss function and only
those.

To keep the tutorial as didactic as we can, we'll be displaying each method of
the class independently and we'll be populating the class at a later stage.

Let us start with the :meth:`~torchrl.objectives.LossModule.__init__` method.
DDPG aims at solving a control task with a simple strategy: training a policy
to output actions that maximize the value predicted by a value network. Hence,
our loss module needs to receive two networks in its constructor: an actor and
a value network. We expect both of these to be TensorDict-compatible objects,
such as :class:`tensordict.nn.TensorDictModule`. Our loss function will need
to compute a target value and fit the value network to this, and generate an
action and fit the policy such that its value estimate is maximized.

The crucial step of the :meth:`LossModule.__init__` method is the call to
:meth:`~torchrl.objectives.LossModule.convert_to_functional`. This method will
extract the parameters from the module and convert it to a functional module.
Strictly speaking, this is not necessary and one may perfectly code all the
losses without it. However, we encourage its usage, for the following reason.

The reason TorchRL does this is that RL algorithms often execute the same model
with different sets of parameters, called "trainable" and "target" parameters.
The "trainable" parameters are those that the optimizer needs to fit. The
"target" parameters are usually a copy of the former with some time lag
(absolute or diluted through a moving average).
These target parameters are used to compute the value associated with the next
observation. One of the advantages of using a set of target parameters for the
value model that do not exactly match the current configuration is that they
provide a pessimistic bound on the value function being computed.

Pay attention to the ``create_target_params`` keyword argument below: this
argument tells the :meth:`~torchrl.objectives.LossModule.convert_to_functional`
method to create a set of target parameters in the loss module to be used for
target value computation.
If this is set to ``False`` (see the actor network for instance) the
``target_actor_network_params`` attribute will still be accessible, but this
will just return a **detached** version of the actor parameters.

Later, we will see how the target parameters should be updated in TorchRL.

.. GENERATED FROM PYTHON SOURCE LINES 191-223

.. code-block:: Python

    from tensordict.nn import TensorDictModule


    def _init(
        self,
        actor_network: TensorDictModule,
        value_network: TensorDictModule,
    ) -> None:
        super(type(self), self).__init__()

        self.convert_to_functional(
            actor_network,
            "actor_network",
            create_target_params=True,
        )
        self.convert_to_functional(
            value_network,
            "value_network",
            create_target_params=True,
            compare_against=list(actor_network.parameters()),
        )

        self.actor_in_keys = actor_network.in_keys

        # Since the value we'll be using is based on the actor and value network,
        # we put them together in a single actor-critic container.
        actor_critic = ActorCriticWrapper(actor_network, value_network)
        self.actor_critic = actor_critic
        self.loss_function = "l2"

.. GENERATED FROM PYTHON SOURCE LINES 224-240

The value estimator loss method
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In many RL algorithms, the value network (or Q-value network) is trained based
on an empirical value estimate. This can be bootstrapped (TD(0), low variance,
high bias), meaning that the target value is obtained using the next reward and
the value estimate of the next state, and nothing else; or a Monte-Carlo
estimate can be obtained (TD(1)), in which case the whole sequence of upcoming
rewards will be used (high variance, low bias). An intermediate estimator
(TD(:math:`\lambda`)) can also be used to trade off bias and variance.
TorchRL makes it easy to use one or the other estimator via the
:class:`~torchrl.objectives.utils.ValueEstimators` Enum class, which contains
pointers to all the value estimators implemented. Let us define the default
value function here. We will take the simplest version (TD(0)), and show later
on how this can be changed.

.. GENERATED FROM PYTHON SOURCE LINES 240-245

.. code-block:: Python

    from torchrl.objectives.utils import ValueEstimators

    default_value_estimator = ValueEstimators.TD0

.. GENERATED FROM PYTHON SOURCE LINES 246-249

We also need to give some instructions to DDPG on how to build the value
estimator, depending on the user's query. Depending on the estimator provided,
we will build the corresponding module to be used at train time:

.. GENERATED FROM PYTHON SOURCE LINES 249-275

.. code-block:: Python

    from torchrl.objectives.utils import default_value_kwargs
    from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


    def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        value_key = "state_action_value"
        if value_type == ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
        elif value_type == ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} is not implemented for loss {type(self)}."
            )
        elif value_type == ValueEstimators.TDLambda:
            self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")
        self._value_estimator.set_keys(value=value_key)
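To make the role of the target parameters concrete, it helps to write out the
TD(0) target that the value network will be regressed onto. This is the
standard DDPG formulation rather than a line taken from the TorchRL code;
:math:`\theta^-` and :math:`\phi^-` denote the target value and target policy
parameters and :math:`d_{t+1}` the done flag:

.. math::

   y_t = r_{t+1} + \gamma \, (1 - d_{t+1}) \, Q_{\theta^-}\big(s_{t+1}, \pi_{\phi^-}(s_{t+1})\big)

The TD(:math:`\lambda`) variant replaces this single-step bootstrap with an
exponentially weighted mixture of n-step returns, which is what the
:class:`~torchrl.objectives.value.TDLambdaEstimator` computes.
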
.. GENERATED FROM PYTHON SOURCE LINES 276-292

The ``make_value_estimator`` method can, but does not need to, be called: if
not, the :class:`~torchrl.objectives.LossModule` will query this method with
its default estimator.

The actor loss method
~~~~~~~~~~~~~~~~~~~~~

The central piece of an RL algorithm is the training loss for the actor.
In the case of DDPG, this function is quite simple: we just need to compute
the value associated with an action computed using the policy and optimize
the actor weights to maximize this value.

When computing this value, we must make sure to take the value parameters out
of the graph, otherwise the actor and value losses will be mixed up.
For this, we use a detached version of the value parameters below; the
:func:`~torchrl.objectives.utils.hold_out_params` function can also be used
for this purpose.

.. GENERATED FROM PYTHON SOURCE LINES 292-308

.. code-block:: Python

    def _loss_actor(
        self,
        tensordict,
    ) -> torch.Tensor:
        td_copy = tensordict.select(*self.actor_in_keys)
        # Get an action from the actor network: since we made it functional, we need to pass the params
        with self.actor_network_params.to_module(self.actor_network):
            td_copy = self.actor_network(td_copy)
        # get the value associated with that action
        with self.value_network_params.detach().to_module(self.value_network):
            td_copy = self.value_network(td_copy)
        return -td_copy.get("state_action_value")

.. GENERATED FROM PYTHON SOURCE LINES 309-315

The value loss method
~~~~~~~~~~~~~~~~~~~~~

We now need to optimize our value network parameters.
To do this, we will rely on the value estimator of our class:

.. GENERATED FROM PYTHON SOURCE LINES 315-352

.. code-block:: Python

    from torchrl.objectives.utils import distance_loss


    def _loss_value(
        self,
        tensordict,
    ):
        td_copy = tensordict.clone()

        # Q(s, a): predicted value for the stored state-action pairs
        with self.value_network_params.to_module(self.value_network):
            self.value_network(td_copy)
        pred_val = td_copy.get("state_action_value").squeeze(-1)

        # we manually reconstruct the parameters of the actor-critic, where the first
        # set of parameters belongs to the actor and the second to the value function.
        target_params = TensorDict(
            {
                "module": {
                    "0": self.target_actor_network_params,
                    "1": self.target_value_network_params,
                }
            },
            batch_size=self.target_actor_network_params.batch_size,
            device=self.target_actor_network_params.device,
        )
        with target_params.to_module(self.value_estimator):
            target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

        # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
        loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
        td_error = (pred_val - target_value).pow(2)

        return loss_value, td_error, pred_val, target_value

.. GENERATED FROM PYTHON SOURCE LINES 353-359

Putting things together in a forward call
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The only missing piece is the forward method, which will glue together the
value and actor losses, collect the cost values and write them in a
``TensorDict`` delivered to the user.

.. GENERATED FROM PYTHON SOURCE LINES 359-403
.. code-block:: Python

    from tensordict import TensorDict, TensorDictBase


    def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
        loss_value, td_error, pred_val, target_value = self.loss_value(
            input_tensordict,
        )
        td_error = td_error.detach()
        td_error = td_error.unsqueeze(input_tensordict.ndimension())
        if input_tensordict.device is not None:
            td_error = td_error.to(input_tensordict.device)
        input_tensordict.set(
            "td_error",
            td_error,
            inplace=True,
        )
        loss_actor = self.loss_actor(input_tensordict)
        return TensorDict(
            source={
                "loss_actor": loss_actor.mean(),
                "loss_value": loss_value.mean(),
                "pred_value": pred_val.mean().detach(),
                "target_value": target_value.mean().detach(),
                "pred_value_max": pred_val.max().detach(),
                "target_value_max": target_value.max().detach(),
            },
            batch_size=[],
        )


    from torchrl.objectives import LossModule


    class DDPGLoss(LossModule):
        default_value_estimator = default_value_estimator
        make_value_estimator = make_value_estimator

        __init__ = _init
        forward = _forward
        loss_value = _loss_value
        loss_actor = _loss_actor

.. GENERATED FROM PYTHON SOURCE LINES 404-442

Now that we have our loss, we can use it to train a policy to solve a control
task.

Environment
-----------

In most algorithms, the first thing that needs to be taken care of is the
construction of the environment as it conditions the remainder of the training
script.

For this example, we will be using the ``"cheetah"`` task. The goal is to make
a half-cheetah run as fast as possible.

In TorchRL, one can create such a task by relying on ``dm_control`` or ``gym``:

.. code-block:: python

   env = GymEnv("HalfCheetah-v4")

or

.. code-block:: python

   env = DMControlEnv("cheetah", "run")

By default, these environments disable rendering. Training from states is
usually easier than training from images. To keep things simple, we focus on
learning from states only. To pass the pixels to the ``tensordicts`` that are
collected by :func:`env.step()`, simply pass the ``from_pixels=True`` argument
to the constructor:

.. code-block:: python

   env = GymEnv("HalfCheetah-v4", from_pixels=True, pixels_only=True)

We write a :func:`make_env` helper function that will create an environment
with either one of the two backends considered above (``dm-control`` or
``gym``).

.. GENERATED FROM PYTHON SOURCE LINES 442-477

.. code-block:: Python

    from torchrl.envs.libs.dm_control import DMControlEnv
    from torchrl.envs.libs.gym import GymEnv

    env_library = None
    env_name = None


    def make_env(from_pixels=False):
        """Create a base ``env``."""
        global env_library
        global env_name

        if backend == "dm_control":
            env_name = "cheetah"
            env_task = "run"
            env_args = (env_name, env_task)
            env_library = DMControlEnv
        elif backend == "gym":
            env_name = "HalfCheetah-v4"
            env_args = (env_name,)
            env_library = GymEnv
        else:
            raise NotImplementedError

        env_kwargs = {
            "device": device,
            "from_pixels": from_pixels,
            "pixels_only": from_pixels,
            "frame_skip": 2,
        }
        env = env_library(*env_args, **env_kwargs)
        return env

.. GENERATED FROM PYTHON SOURCE LINES 478-505

Transforms
~~~~~~~~~~

Now that we have a base environment, we may want to modify its representation
to make it more policy-friendly. In TorchRL, transforms are appended to the
base environment in a specialized :class:`torchrl.envs.TransformedEnv` class.

- It is common in DDPG to rescale the reward using some heuristic value. We
  will multiply the reward by 5 in this example.
- If we are using :mod:`dm_control`, it is also important to build an interface
  between the simulator, which works with double precision numbers, and our
  script, which presumably uses single precision ones. This transformation goes
  both ways: when calling :func:`env.step`, our actions will need to be
  represented in double precision, and the output will need to be transformed
  to single precision. The :class:`~torchrl.envs.DoubleToFloat` transform does
  exactly this: the ``in_keys`` list refers to the keys that will need to be
  transformed from double to float, while the ``in_keys_inv`` refers to those
  that need to be transformed to double before being passed to the environment.
- We concatenate the state keys together using the
  :class:`~torchrl.envs.CatTensors` transform.
- Finally, we also leave the possibility of normalizing the states: we will
  take care of computing the normalizing constants later on.

.. GENERATED FROM PYTHON SOURCE LINES 505-557

.. code-block:: Python

    from torchrl.envs import (
        CatTensors,
        DoubleToFloat,
        EnvCreator,
        InitTracker,
        ObservationNorm,
        ParallelEnv,
        RewardScaling,
        StepCounter,
        TransformedEnv,
    )


    def make_transformed_env(
        env,
    ):
        """Apply transforms to the ``env`` (such as reward scaling and state normalization)."""

        env = TransformedEnv(env)

        # we append transforms one by one, although we might as well create the
        # transformed environment using the `env = TransformedEnv(base_env, transforms)`
        # syntax.
        env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))

        # We concatenate all states into a single "observation_vector"
        # even if there is a single tensor, it'll be renamed to "observation_vector".
        # This facilitates the downstream operations as we know the name of the
        # output tensor.
        # In some environments (not half-cheetah), there may be more than one
        # observation vector: in this case this code snippet will concatenate them
        # all.
        selected_keys = list(env.observation_spec.keys())
        out_key = "observation_vector"
        env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

        # we normalize the states, but for now let's just instantiate a stateless
        # version of the transform
        env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))

        env.append_transform(DoubleToFloat())

        env.append_transform(StepCounter(max_frames_per_traj))

        # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU)
        # exploration:
        env.append_transform(InitTracker())

        return env

.. GENERATED FROM PYTHON SOURCE LINES 558-582

Parallel execution
~~~~~~~~~~~~~~~~~~

The following helper function allows us to run environments in parallel.
Running environments in parallel can significantly speed up the collection
throughput. When using a transformed environment, we need to choose whether we
want to execute the transform individually for each environment, or centralize
the data and transform it in batch. Both approaches are easy to code:

.. code-block:: python

   env = ParallelEnv(
       lambda: TransformedEnv(GymEnv("HalfCheetah-v4"), transforms),
       num_workers=4
   )
   env = TransformedEnv(
       ParallelEnv(lambda: GymEnv("HalfCheetah-v4"), num_workers=4),
       transforms
   )

To leverage the vectorization capabilities of PyTorch, we adopt the second
method, where the transforms are applied to the batched data coming out of the
parallel environment:

.. GENERATED FROM PYTHON SOURCE LINES 582-617
.. code-block:: Python

    def parallel_env_constructor(
        env_per_collector,
        transform_state_dict,
    ):
        if env_per_collector == 1:

            def make_t_env():
                env = make_transformed_env(make_env())
                env.transform[2].init_stats(3)
                env.transform[2].loc.copy_(transform_state_dict["loc"])
                env.transform[2].scale.copy_(transform_state_dict["scale"])
                return env

            env_creator = EnvCreator(make_t_env)
            return env_creator

        parallel_env = ParallelEnv(
            num_workers=env_per_collector,
            create_env_fn=EnvCreator(lambda: make_env()),
            create_env_kwargs=None,
            pin_memory=False,
        )
        env = make_transformed_env(parallel_env)
        # we call `init_stats` for a limited number of steps, just to instantiate
        # the lazy buffers.
        env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])
        env.transform[2].load_state_dict(transform_state_dict)
        return env


    # The backend can be ``gym`` or ``dm_control``
    backend = "gym"

.. GENERATED FROM PYTHON SOURCE LINES 618-633

.. note::

   ``frame_skip`` batches multiple steps together with a single action.
   If > 1, the other frame counts (for example, frames_per_batch, total_frames)
   need to be adjusted to have a consistent total number of frames collected
   across experiments. This is important as raising the frame-skip but keeping
   the total number of frames unchanged may seem like cheating: all things
   compared, a dataset of 10M elements collected with a frame-skip of 2 and
   another with a frame-skip of 1 actually have a ratio of interactions with
   the environment of 2:1! In a nutshell, one should be cautious about the
   frame-count of a training script when dealing with frame skipping, as this
   may lead to biased comparisons between training strategies.

Scaling the reward helps us control the signal magnitude for more efficient
learning.

.. GENERATED FROM PYTHON SOURCE LINES 633-635

.. code-block:: Python

    reward_scaling = 5.0

.. GENERATED FROM PYTHON SOURCE LINES 636-638

We also define when a trajectory will be truncated. A thousand steps (500 if
frame-skip = 2) is a good number to use for the cheetah task:

.. GENERATED FROM PYTHON SOURCE LINES 638-641

.. code-block:: Python

    max_frames_per_traj = 500

.. GENERATED FROM PYTHON SOURCE LINES 642-652

Normalization of the observations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To compute the normalizing statistics, we run an arbitrary number of random
steps in the environment and compute the mean and standard deviation of the
collected observations. The :func:`ObservationNorm.init_stats()` method can be
used for this purpose. To get the summary statistics, we create a dummy
environment, run it for a given number of steps, and compute the summary
statistics of the collected data.

.. GENERATED FROM PYTHON SOURCE LINES 652-664

.. code-block:: Python

    def get_env_stats():
        """Gets the stats of an environment."""
        proof_env = make_transformed_env(make_env())
        t = proof_env.transform[2]
        t.init_stats(init_env_steps)
        transform_state_dict = t.state_dict()
        proof_env.close()
        return transform_state_dict

.. GENERATED FROM PYTHON SOURCE LINES 665-668

Normalization stats
~~~~~~~~~~~~~~~~~~~

Number of random steps used for stats computation using ``ObservationNorm``

.. GENERATED FROM PYTHON SOURCE LINES 668-673

.. code-block:: Python

    init_env_steps = 5000

    transform_state_dict = get_env_stats()

.. GENERATED FROM PYTHON SOURCE LINES 674-675

Number of environments in each data collector

.. GENERATED FROM PYTHON SOURCE LINES 675-677

.. code-block:: Python

    env_per_collector = 4
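If you are curious about what ``init_stats`` produced, the state dict of the
``ObservationNorm`` transform simply holds the location and scale tensors used
for normalization (the same ``"loc"`` and ``"scale"`` entries that
``parallel_env_constructor`` copies above). The quick inspection below is a
sketch and is not part of the original script:

.. code-block:: Python

    # Sketch: peek at the normalization constants computed above.
    # ``transform_state_dict`` holds the buffers of the ``ObservationNorm``
    # transform, one value per dimension of the "observation_vector" entry.
    for key, tensor in transform_state_dict.items():
        print(key, tuple(tensor.shape), tensor.dtype)
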
.. GENERATED FROM PYTHON SOURCE LINES 678-680

We pass the stats computed earlier to normalize the output of our environment:

.. GENERATED FROM PYTHON SOURCE LINES 680-689

.. code-block:: Python

    parallel_env = parallel_env_constructor(
        env_per_collector=env_per_collector,
        transform_state_dict=transform_state_dict,
    )


    from torchrl.data import CompositeSpec

.. GENERATED FROM PYTHON SOURCE LINES 690-721

Building the model
------------------

We now turn to the setup of the model. As we have seen, DDPG requires a value
network, trained to estimate the value of a state-action pair, and a
parametric actor that learns how to select actions that maximize this value.

Recall that building a TorchRL module requires two steps:

- writing the :class:`torch.nn.Module` that will be used as network,
- wrapping the network in a :class:`tensordict.nn.TensorDictModule` where the
  data flow is handled by specifying the input and output keys.

In more complex scenarios, :class:`tensordict.nn.TensorDictSequential` can
also be used.

The Q-Value network is wrapped in a :class:`~torchrl.modules.ValueOperator`
that automatically sets the ``out_keys`` to ``"state_action_value"`` for
q-value networks and ``"state_value"`` for other value networks.

TorchRL provides a built-in version of the DDPG networks as presented in the
original paper. These can be found under :class:`~torchrl.modules.DdpgMlpActor`
and :class:`~torchrl.modules.DdpgMlpQNet`.

Since we use lazy modules, it is necessary to materialize the lazy modules
before being able to move the policy from device to device and perform other
operations. Hence, it is good practice to run the modules with a small sample
of data, which we obtain here by resetting a proof environment.

.. GENERATED FROM PYTHON SOURCE LINES 721-781

.. code-block:: Python

    from torchrl.modules import (
        ActorCriticWrapper,
        DdpgMlpActor,
        DdpgMlpQNet,
        OrnsteinUhlenbeckProcessWrapper,
        ProbabilisticActor,
        TanhDelta,
        ValueOperator,
    )


    def make_ddpg_actor(
        transform_state_dict,
        device="cpu",
    ):
        proof_environment = make_transformed_env(make_env())
        proof_environment.transform[2].init_stats(3)
        proof_environment.transform[2].load_state_dict(transform_state_dict)

        out_features = proof_environment.action_spec.shape[-1]

        actor_net = DdpgMlpActor(
            action_dim=out_features,
        )

        in_keys = ["observation_vector"]
        out_keys = ["param"]

        actor = TensorDictModule(
            actor_net,
            in_keys=in_keys,
            out_keys=out_keys,
        )

        actor = ProbabilisticActor(
            actor,
            distribution_class=TanhDelta,
            in_keys=["param"],
            spec=CompositeSpec(action=proof_environment.action_spec),
        ).to(device)

        q_net = DdpgMlpQNet()

        in_keys = in_keys + ["action"]
        qnet = ValueOperator(
            in_keys=in_keys,
            module=q_net,
        ).to(device)

        # initialize lazy modules
        qnet(actor(proof_environment.reset().to(device)))
        return actor, qnet


    actor, qnet = make_ddpg_actor(
        transform_state_dict=transform_state_dict,
        device=device,
    )

.. GENERATED FROM PYTHON SOURCE LINES 782-788

Exploration
~~~~~~~~~~~

The policy is wrapped in a
:class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` exploration module,
as suggested in the original paper.
Let's define the number of frames before OU noise reaches its minimum value.

.. GENERATED FROM PYTHON SOURCE LINES 788-798

.. code-block:: Python

    annealing_frames = 1_000_000

    actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
        actor,
        annealing_num_steps=annealing_frames,
    ).to(device)
    if device == torch.device("cpu"):
        actor_model_explore.share_memory()
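Before wiring the data collector, it can be reassuring to check that the
exploration policy and the transformed environment are compatible. The short
sanity check below is a sketch that is not part of the generated script; it
reuses the helpers defined above (``make_env``, ``make_transformed_env`` and
``transform_state_dict``) and assumes the
:func:`torchrl.envs.utils.set_exploration_type` context manager:

.. code-block:: Python

    from torchrl.envs.utils import ExplorationType, set_exploration_type

    # Sketch: run a handful of noisy steps to check shapes and tensordict keys.
    check_env = make_transformed_env(make_env())
    check_env.transform[2].init_stats(3)  # instantiate the lazy buffers
    check_env.transform[2].load_state_dict(transform_state_dict)
    with set_exploration_type(ExplorationType.RANDOM), torch.no_grad():
        sanity_rollout = check_env.rollout(max_steps=5, policy=actor_model_explore)
    # expect "observation_vector", "action" and ("next", "reward") entries
    print(sanity_rollout)
    check_env.close()
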
.. GENERATED FROM PYTHON SOURCE LINES 799-843

Data collector
--------------

TorchRL provides specialized classes to help you collect data by executing the
policy in the environment. These "data collectors" iteratively compute the
action to be executed at a given time, then execute a step in the environment
and reset it when required. Data collectors are designed to help developers
have tight control over the number of frames per batch of data, over the
(a)sync nature of this collection and over the resources allocated to the data
collection (for example GPU, number of workers, and so on).

Here we will use :class:`~torchrl.collectors.SyncDataCollector`, a simple,
single-process data collector. TorchRL offers other collectors, such as
:class:`~torchrl.collectors.MultiaSyncDataCollector`, which executes rollouts
in an asynchronous manner (for example, data will be collected while the
policy is being optimized, thereby decoupling the training and data
collection).

The parameters to specify are:

- an environment factory or an environment,
- the policy,
- the total number of frames before the collector is considered empty,
- the maximum number of frames per trajectory (useful for non-terminating
  environments, like ``dm_control`` ones).

.. note::

   The ``max_frames_per_traj`` passed to the collector will have the effect of
   registering a new :class:`~torchrl.envs.StepCounter` transform with the
   environment used for inference. We can achieve the same result manually, as
   we do in this script.

One should also pass:

- the number of frames in each batch collected,
- the number of random steps executed independently from the policy,
- the devices used for policy execution,
- the devices used to store data before the data is passed to the main
  process.

The total frames we will use during training should be around 1M.

.. GENERATED FROM PYTHON SOURCE LINES 843-845

.. code-block:: Python

    total_frames = 10_000  # 1_000_000

.. GENERATED FROM PYTHON SOURCE LINES 846-854

The number of frames returned by the collector at each iteration of the outer
loop is equal to the length of each sub-trajectory times the number of
environments run in parallel in each collector.

In other words, we expect batches from the collector to have a shape
``[env_per_collector, traj_len]`` where
``traj_len=frames_per_batch/env_per_collector``:

.. GENERATED FROM PYTHON SOURCE LINES 854-874

.. code-block:: Python

    traj_len = 200
    frames_per_batch = env_per_collector * traj_len
    init_random_frames = 5000
    num_collectors = 2

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import ExplorationType

    collector = SyncDataCollector(
        parallel_env,
        policy=actor_model_explore,
        total_frames=total_frames,
        frames_per_batch=frames_per_batch,
        init_random_frames=init_random_frames,
        reset_at_each_iter=False,
        split_trajs=False,
        device=collector_device,
        exploration_type=ExplorationType.RANDOM,
    )

.. GENERATED FROM PYTHON SOURCE LINES 875-885

Evaluator: building your recorder object
----------------------------------------

As the training data is obtained using some exploration strategy, the true
performance of our algorithm needs to be assessed in deterministic mode. We do
this using a dedicated class, ``Recorder``, which executes the policy in the
environment at a given frequency and returns some statistics obtained from
these simulations.

The following helper function builds this object:

.. GENERATED FROM PYTHON SOURCE LINES 885-906
.. code-block:: Python

    from torchrl.trainers import Recorder


    def make_recorder(actor_model_explore, transform_state_dict, record_interval):
        base_env = make_env()
        environment = make_transformed_env(base_env)
        environment.transform[2].init_stats(
            3
        )  # must be instantiated to load the state dict
        environment.transform[2].load_state_dict(transform_state_dict)

        recorder_obj = Recorder(
            record_frames=1000,
            policy_exploration=actor_model_explore,
            environment=environment,
            exploration_type=ExplorationType.MEAN,
            record_interval=record_interval,
        )
        return recorder_obj

.. GENERATED FROM PYTHON SOURCE LINES 907-908

We will be recording the performance every 10 batches collected

.. GENERATED FROM PYTHON SOURCE LINES 908-921

.. code-block:: Python

    record_interval = 10

    recorder = make_recorder(
        actor_model_explore, transform_state_dict, record_interval=record_interval
    )

    from torchrl.data.replay_buffers import (
        LazyMemmapStorage,
        PrioritizedSampler,
        RandomSampler,
        TensorDictReplayBuffer,
    )

.. GENERATED FROM PYTHON SOURCE LINES 922-935

Replay buffer
-------------

Replay buffers come in two flavors: prioritized (where some error signal is
used to give a higher likelihood of sampling to some items than others) and
regular, circular experience replay.

TorchRL replay buffers are composable: one can pick the storage, sampling and
writing strategies. It is also possible to store tensors on disk using a
memory-mapped array. The following function takes care of creating the replay
buffer with the desired hyperparameters:

.. GENERATED FROM PYTHON SOURCE LINES 935-962

.. code-block:: Python

    from torchrl.envs import RandomCropTensorDict


    def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb=False):
        if prb:
            sampler = PrioritizedSampler(
                max_capacity=buffer_size,
                alpha=0.7,
                beta=0.5,
            )
        else:
            sampler = RandomSampler()
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=buffer_scratch_dir,
            ),
            batch_size=batch_size,
            sampler=sampler,
            pin_memory=False,
            prefetch=prefetch,
            transform=RandomCropTensorDict(random_crop_len, sample_dim=1),
        )
        return replay_buffer

.. GENERATED FROM PYTHON SOURCE LINES 963-964

We'll store the replay buffer in a temporary directory on disk

.. GENERATED FROM PYTHON SOURCE LINES 964-970

.. code-block:: Python

    import tempfile

    tmpdir = tempfile.TemporaryDirectory()
    buffer_scratch_dir = tmpdir.name

.. GENERATED FROM PYTHON SOURCE LINES 971-992

Replay buffer storage and batch size
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The TorchRL replay buffer counts the number of elements along the first
dimension. Since we'll be feeding trajectories to our buffer, we need to adapt
the buffer size by dividing it by the length of the sub-trajectories yielded
by our data collector.

Regarding the batch-size, our sampling strategy will consist of sampling
trajectories of length ``traj_len=200`` before selecting sub-trajectories of
length ``random_crop_len=25`` on which the loss will be computed. This
strategy balances the choice of storing whole trajectories of a certain length
with the need to provide samples with sufficient heterogeneity to our loss.
The following figure shows the dataflow from a collector that gets 8 frames in
each batch with 2 environments run in parallel, feeds them to a replay buffer
that contains 1000 trajectories and samples sub-trajectories of 2 time steps
each.

.. figure:: /_static/img/replaybuffer_traj.png
   :alt: Storing trajectories in the replay buffer

Let's start with the number of frames stored in the buffer
.. GENERATED FROM PYTHON SOURCE LINES 992-1001

.. code-block:: Python

    def ceil_div(x, y):
        return -(-x // y)


    buffer_size = 1_000_000
    buffer_size = ceil_div(buffer_size, traj_len)

.. GENERATED FROM PYTHON SOURCE LINES 1002-1003

The prioritized replay buffer is disabled by default

.. GENERATED FROM PYTHON SOURCE LINES 1003-1005

.. code-block:: Python

    prb = False

.. GENERATED FROM PYTHON SOURCE LINES 1006-1008

We also need to define how many updates we'll be doing per batch of data
collected. This is known as the update-to-data or ``UTD`` ratio:

.. GENERATED FROM PYTHON SOURCE LINES 1008-1010

.. code-block:: Python

    update_to_data = 64

.. GENERATED FROM PYTHON SOURCE LINES 1011-1012

We'll be feeding the loss with trajectories of length 25:

.. GENERATED FROM PYTHON SOURCE LINES 1012-1014

.. code-block:: Python

    random_crop_len = 25

.. GENERATED FROM PYTHON SOURCE LINES 1015-1019

In the original paper, the authors perform one update with a batch of 64
elements for each frame collected. Here, we reproduce the same ratio while
performing several updates at each batch collection. We adapt our batch-size
to achieve the same update-per-frame ratio:

.. GENERATED FROM PYTHON SOURCE LINES 1019-1030

.. code-block:: Python

    batch_size = ceil_div(64 * frames_per_batch, update_to_data * random_crop_len)

    replay_buffer = make_replay_buffer(
        buffer_size=buffer_size,
        batch_size=batch_size,
        random_crop_len=random_crop_len,
        prefetch=3,
        prb=prb,
    )

.. GENERATED FROM PYTHON SOURCE LINES 1031-1038

Loss module construction
------------------------

We build our loss module with the actor and ``qnet`` we've just created.
Because we have target parameters to update, we **must** create a target
network updater.

.. GENERATED FROM PYTHON SOURCE LINES 1038-1045

.. code-block:: Python

    gamma = 0.99
    lmbda = 0.9
    tau = 0.001  # Decay factor for the target network

    loss_module = DDPGLoss(actor, qnet)

.. GENERATED FROM PYTHON SOURCE LINES 1046-1047

Let's use the TD(:math:`\lambda`) estimator!

.. GENERATED FROM PYTHON SOURCE LINES 1047-1049

.. code-block:: Python

    loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda)

.. GENERATED FROM PYTHON SOURCE LINES 1050-1068

.. note::
  Off-policy learning usually dictates a TD(0) estimator. Here, we use a
  TD(:math:`\lambda`) estimator, which will introduce some bias as the
  trajectory that follows a certain state has been collected with an outdated
  policy. This trick, like the multi-step trick that can be used during data
  collection, is one of those "hacks" that we usually find to work well in
  practice despite the fact that they introduce some bias in the return
  estimates.

Target network updater
~~~~~~~~~~~~~~~~~~~~~~

Target networks are a crucial part of off-policy RL algorithms. Updating the
target network parameters is made easy thanks to the
:class:`~torchrl.objectives.HardUpdate` and
:class:`~torchrl.objectives.SoftUpdate` classes. They're built with the loss
module as argument, and the update is achieved via a call to
``updater.step()`` at the appropriate location in the training loop.

.. GENERATED FROM PYTHON SOURCE LINES 1068-1073

.. code-block:: Python

    from torchrl.objectives.utils import SoftUpdate

    target_net_updater = SoftUpdate(loss_module, eps=1 - tau)

.. GENERATED FROM PYTHON SOURCE LINES 1074-1078

Optimizer
~~~~~~~~~

Finally, we will use the Adam optimizer for the policy and value network:

.. GENERATED FROM PYTHON SOURCE LINES 1078-1089

..
code-block:: Python from torch import optim optimizer_actor = optim.Adam( loss_module.actor_network_params.values(True, True), lr=1e-4, weight_decay=0.0 ) optimizer_value = optim.Adam( loss_module.value_network_params.values(True, True), lr=1e-3, weight_decay=1e-2 ) total_collection_steps = total_frames // frames_per_batch .. GENERATED FROM PYTHON SOURCE LINES 1090-1096 Time to train the policy ------------------------ The training loop is pretty straightforward now that we have built all the modules we need. .. GENERATED FROM PYTHON SOURCE LINES 1096-1182 .. code-block:: Python rewards = [] rewards_eval = [] # Main loop collected_frames = 0 pbar = tqdm.tqdm(total=total_frames) r0 = None for i, tensordict in enumerate(collector): # update weights of the inference policy collector.update_policy_weights_() if r0 is None: r0 = tensordict["next", "reward"].mean().item() pbar.update(tensordict.numel()) # extend the replay buffer with the new data current_frames = tensordict.numel() collected_frames += current_frames replay_buffer.extend(tensordict.cpu()) # optimization steps if collected_frames >= init_random_frames: for _ in range(update_to_data): # sample from replay buffer sampled_tensordict = replay_buffer.sample().to(device) # Compute loss loss_dict = loss_module(sampled_tensordict) # optimize loss_dict["loss_actor"].backward() gn1 = torch.nn.utils.clip_grad_norm_( loss_module.actor_network_params.values(True, True), 10.0 ) optimizer_actor.step() optimizer_actor.zero_grad() loss_dict["loss_value"].backward() gn2 = torch.nn.utils.clip_grad_norm_( loss_module.value_network_params.values(True, True), 10.0 ) optimizer_value.step() optimizer_value.zero_grad() gn = (gn1**2 + gn2**2) ** 0.5 # update priority if prb: replay_buffer.update_tensordict_priority(sampled_tensordict) # update target network target_net_updater.step() rewards.append( ( i, tensordict["next", "reward"].mean().item(), ) ) td_record = recorder(None) if td_record is not None: rewards_eval.append((i, td_record["r_evaluation"].item())) if len(rewards_eval) and collected_frames >= init_random_frames: target_value = loss_dict["target_value"].item() loss_value = loss_dict["loss_value"].item() loss_actor = loss_dict["loss_actor"].item() rn = sampled_tensordict["next", "reward"].mean().item() rs = sampled_tensordict["next", "reward"].std().item() pbar.set_description( f"reward: {rewards[-1][1]: 4.2f} (r0 = {r0: 4.2f}), " f"reward eval: reward: {rewards_eval[-1][1]: 4.2f}, " f"reward normalized={rn :4.2f}/{rs :4.2f}, " f"grad norm={gn: 4.2f}, " f"loss_value={loss_value: 4.2f}, " f"loss_actor={loss_actor: 4.2f}, " f"target value: {target_value: 4.2f}" ) # update the exploration strategy actor_model_explore.step(current_frames) collector.shutdown() del collector .. rst-class:: sphx-glr-script-out .. 
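As an optional sanity check (a sketch that is not part of the generated
script), we can roll out the trained policy deterministically on a fresh
transformed environment and look at the return it achieves. The names below
reuse the helpers defined earlier in this tutorial, and the reported reward is
the one scaled by ``RewardScaling``; the
:func:`torchrl.envs.utils.set_exploration_type` context manager is assumed:

.. code-block:: Python

    from torchrl.envs.utils import ExplorationType, set_exploration_type

    # Sketch: deterministic evaluation of the trained actor.
    eval_env = make_transformed_env(make_env())
    eval_env.transform[2].init_stats(3)  # instantiate the lazy buffers
    eval_env.transform[2].load_state_dict(transform_state_dict)
    with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
        eval_rollout = eval_env.rollout(max_steps=1000, policy=actor)
    print("evaluation return (scaled):", eval_rollout["next", "reward"].sum().item())
    eval_env.close()
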
.. GENERATED FROM PYTHON SOURCE LINES 1183-1192

Experiment results
------------------

We make a simple plot of the average rewards during training. We can observe
that our policy learned quite well to solve the task.

.. note::
  As already mentioned above, to get a more reasonable performance, use a
  greater value for ``total_frames``, for example, 1M.

.. GENERATED FROM PYTHON SOURCE LINES 1192-1203

.. code-block:: Python

    from matplotlib import pyplot as plt

    plt.figure()
    plt.plot(*zip(*rewards), label="training")
    plt.plot(*zip(*rewards_eval), label="eval")
    plt.legend()
    plt.xlabel("iter")
    plt.ylabel("reward")
    plt.tight_layout()

.. GENERATED FROM PYTHON SOURCE LINES 1204-1225

Conclusion
----------

In this tutorial, we have learned how to code a loss module in TorchRL given
the concrete example of DDPG.

The key takeaways are:

- How to use the :class:`~torchrl.objectives.LossModule` class to code up a new
  loss component;
- How to use (or not) a target network, and how to update its parameters;
- How to create an optimizer associated with a loss module.

Next Steps
----------

To iterate further on this loss module we might consider:

- Using `@dispatch` (see `[Feature] Distpatch IQL loss module `_.)
- Allowing flexible TensorDict keys.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 32.545 seconds)

   **Estimated memory usage:** 4333 MB

.. _sphx_glr_download_tutorials_coding_ddpg.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: coding_ddpg.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: coding_ddpg.py `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_