
Source code for torchrl.objectives.dreamer

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.utils import NestedKey

from torchrl._utils import timeit
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _GAMMA_LMBDA_DEPREC_ERROR,
    default_value_kwargs,
    distance_loss,
    hold_out_net,
    ValueEstimators,
)  # distance_loss,
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

class DreamerModelLoss(LossModule):
    """Dreamer Model Loss.

    Computes the loss of the dreamer world model. The loss is composed of the
    KL divergence between the prior and posterior of the RSSM, the
    reconstruction loss over the reconstructed observation and the reward
    loss over the predicted reward.

    Reference: Hafner et al., "Dream to Control: Learning Behaviors by Latent
    Imagination", https://arxiv.org/abs/1912.01603.

    Args:
        world_model (TensorDictModule): the world model. ``world_model[0][-1]``
            is exposed as ``self.decoder`` and ``world_model[1]`` as
            ``self.reward_model``.
        lambda_kl (:obj:`float`, optional): the weight of the KL divergence
            loss. Default: 1.0.
        lambda_reco (:obj:`float`, optional): the weight of the reconstruction
            loss. Default: 1.0.
        lambda_reward (:obj:`float`, optional): the weight of the reward loss.
            Default: 1.0.
        reco_loss (str, optional): the reconstruction loss, passed as the loss
            function name to ``distance_loss``. Default: ``"l2"``.
        reward_loss (str, optional): the reward loss, passed as the loss
            function name to ``distance_loss``. Default: ``"l2"``.
        free_nats (int, optional): the free nats, i.e. the lower bound the KL
            term is clamped to. Default: 3.
        delayed_clamp (bool, optional): if ``True``, the KL clamping occurs
            after averaging. If ``False`` (default), the KL divergence is
            clamped to the free nats value first and then averaged.
        global_average (bool, optional): if ``True``, the losses will be
            averaged over all dimensions. Otherwise, a sum will be performed
            over all non-batch/time dimensions and an average over batch and
            time. Default: ``False``.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using
        '.set_keys(key_name=key_value)' and their default values.

        Attributes:
            reward (NestedKey): The reward is expected to be in the tensordict
                key ("next", reward). After the world model has run, this key
                is read back as the predicted reward. Defaults to ``"reward"``.
            true_reward (NestedKey): The ground-truth reward is moved to the
                tensordict key ("next", true_reward) before the world model
                runs. Defaults to ``"true_reward"``.
            prior_mean (NestedKey): The prior mean is expected to be in the
                tensordict key ("next", prior_mean). Defaults to ``"prior_mean"``.
            prior_std (NestedKey): The prior std is expected to be in the
                tensordict key ("next", prior_std). Defaults to ``"prior_std"``.
            posterior_mean (NestedKey): The posterior mean is expected to be in
                the tensordict key ("next", posterior_mean). Defaults to
                ``"posterior_mean"``.
            posterior_std (NestedKey): The posterior std is expected to be in
                the tensordict key ("next", posterior_std). Defaults to
                ``"posterior_std"``.
            pixels (NestedKey): The pixels are expected to be in the tensordict
                key ("next", pixels). Defaults to ``"pixels"``.
            reco_pixels (NestedKey): The reconstructed pixels are expected to
                be in the tensordict key ("next", reco_pixels). Defaults to
                ``"reco_pixels"``.
        """

        reward: NestedKey = "reward"
        true_reward: NestedKey = "true_reward"
        prior_mean: NestedKey = "prior_mean"
        prior_std: NestedKey = "prior_std"
        posterior_mean: NestedKey = "posterior_mean"
        posterior_std: NestedKey = "posterior_std"
        pixels: NestedKey = "pixels"
        reco_pixels: NestedKey = "reco_pixels"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys

    decoder: TensorDictModule
    reward_model: TensorDictModule
    world_model: TensorDictModule

    def __init__(
        self,
        world_model: TensorDictModule,
        *,
        lambda_kl: float = 1.0,
        lambda_reco: float = 1.0,
        lambda_reward: float = 1.0,
        reco_loss: str | None = None,
        reward_loss: str | None = None,
        free_nats: int = 3,
        delayed_clamp: bool = False,
        global_average: bool = False,
    ):
        super().__init__()
        self.world_model = world_model
        self.reco_loss = reco_loss if reco_loss is not None else "l2"
        self.reward_loss = reward_loss if reward_loss is not None else "l2"
        self.lambda_kl = lambda_kl
        self.lambda_reco = lambda_reco
        self.lambda_reward = lambda_reward
        self.free_nats = free_nats
        self.delayed_clamp = delayed_clamp
        self.global_average = global_average
        # Written straight into __dict__ to bypass nn.Module submodule
        # registration: both are sub-parts of ``world_model``, which already
        # owns their parameters.
        self.__dict__["decoder"] = self.world_model[0][-1]
        self.__dict__["reward_model"] = self.world_model[1]

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # This loss does not use a value estimator: nothing to forward.
        pass

    def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
        """Computes the world-model losses on a batch of real data.

        Args:
            tensordict (TensorDict): input data holding at least
                ("next", reward) and ("next", pixels), plus whatever the
                world model reads.

        Returns:
            a tuple ``(losses, output)`` where ``losses`` holds the weighted
            ``loss_model_kl``, ``loss_model_reco`` and ``loss_model_reward``
            entries and ``output`` is the detached world-model output.
        """
        # Shallow clone so that renaming the reward key does not mutate the
        # caller's tensordict.
        tensordict = tensordict.clone(recurse=False)
        # Keep the ground-truth reward aside: the world model writes its
        # prediction under the original reward key.
        tensordict.rename_key_(
            ("next", self.tensor_keys.reward),
            ("next", self.tensor_keys.true_reward),
        )
        tensordict = self.world_model(tensordict)

        kl_loss = self.kl_loss(
            tensordict.get(("next", self.tensor_keys.prior_mean)),
            tensordict.get(("next", self.tensor_keys.prior_std)),
            tensordict.get(("next", self.tensor_keys.posterior_mean)),
            tensordict.get(("next", self.tensor_keys.posterior_std)),
        ).unsqueeze(-1)

        reco_loss = distance_loss(
            tensordict.get(("next", self.tensor_keys.pixels)),
            tensordict.get(("next", self.tensor_keys.reco_pixels)),
            self.reco_loss,
        )
        if not self.global_average:
            # Sum over the last three (presumably C, H, W) dimensions before
            # averaging over batch/time.
            reco_loss = reco_loss.sum((-3, -2, -1))
        reco_loss = reco_loss.mean().unsqueeze(-1)

        reward_loss = distance_loss(
            tensordict.get(("next", self.tensor_keys.true_reward)),
            tensordict.get(("next", self.tensor_keys.reward)),
            self.reward_loss,
        )
        if not self.global_average:
            reward_loss = reward_loss.squeeze(-1)
        reward_loss = reward_loss.mean().unsqueeze(-1)

        td_out = TensorDict(
            loss_model_kl=self.lambda_kl * kl_loss,
            loss_model_reco=self.lambda_reco * reco_loss,
            loss_model_reward=self.lambda_reward * reward_loss,
        )
        self._clear_weakrefs(tensordict, td_out)
        return (
            td_out,
            tensordict.detach(),
        )

    @staticmethod
    def normal_log_probability(x, mean, std):
        """Returns the Gaussian log-density of ``x``, up to an additive constant.

        The ``-0.5 * log(2 * pi)`` normalisation term is intentionally omitted.
        """
        return (
            -0.5 * ((x - mean) / std).pow(2) - std.log()
        )  # - 0.5 * math.log(2 * math.pi)

    def kl_loss(
        self,
        prior_mean: torch.Tensor,
        prior_std: torch.Tensor,
        posterior_mean: torch.Tensor,
        posterior_std: torch.Tensor,
    ) -> torch.Tensor:
        """Closed-form KL(posterior || prior) between element-wise Gaussians.

        Unless ``global_average`` is set, the divergence is summed over the
        last dimension. It is then clamped to at least ``free_nats`` and
        averaged; the order of these two steps is controlled by
        ``delayed_clamp``.

        Returns:
            a scalar tensor.
        """
        kl = (
            torch.log(prior_std / posterior_std)
            + (posterior_std**2 + (prior_mean - posterior_mean) ** 2)
            / (2 * prior_std**2)
            - 0.5
        )
        if not self.global_average:
            kl = kl.sum(-1)
        if self.delayed_clamp:
            kl = kl.mean().clamp_min(self.free_nats)
        else:
            kl = kl.clamp_min(self.free_nats).mean()
        return kl
class DreamerActorLoss(LossModule):
    """Dreamer Actor Loss.

    Computes the loss of the dreamer actor. The actor loss is computed as the
    negative average lambda return.

    Reference: Hafner et al., "Dream to Control: Learning Behaviors by Latent
    Imagination", https://arxiv.org/abs/1912.01603.

    Args:
        actor_model (TensorDictModule): the actor model.
        value_model (TensorDictModule): the value model.
        model_based_env (DreamerEnv): the model based environment.
        imagination_horizon (int, optional): The number of steps to unroll the
            model. Defaults to ``15``.
        discount_loss (bool, optional): if ``True``, the loss is discounted with
            a gamma discount factor. Defaults to ``True``.
        gamma (float, optional): deprecated; passing any value raises a
            ``TypeError``. Configure gamma via ``make_value_estimator``.
        lmbda (float, optional): deprecated; passing any value raises a
            ``TypeError``. Configure lmbda via ``make_value_estimator``.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using
        '.set_keys(key_name=key_value)' and their default values.

        Attributes:
            belief (NestedKey): The input tensordict key where the belief is
                expected. Defaults to ``"belief"``.
            reward (NestedKey): The reward is expected to be in the tensordict
                key ("next", reward). Defaults to ``"reward"``.
            value (NestedKey): The value is expected to be in the tensordict
                key ("next", value). Will be used for the underlying value
                estimator. Defaults to ``"state_value"``.
            done (NestedKey): The input tensordict key where the flag if a
                trajectory is done is expected ("next", done). Defaults to
                ``"done"``.
            terminated (NestedKey): The input tensordict key where the flag if
                a trajectory is terminated is expected ("next", terminated).
                Defaults to ``"terminated"``.
        """

        belief: NestedKey = "belief"
        reward: NestedKey = "reward"
        value: NestedKey = "state_value"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator = ValueEstimators.TDLambda

    value_model: TensorDictModule
    actor_model: TensorDictModule

    def __init__(
        self,
        actor_model: TensorDictModule,
        value_model: TensorDictModule,
        model_based_env: DreamerEnv,
        *,
        imagination_horizon: int = 15,
        discount_loss: bool = True,  # for consistency with paper
        gamma: float | None = None,
        lmbda: float | None = None,
    ):
        super().__init__()
        self.actor_model = actor_model
        # Not registered as a submodule: the value network is trained by its
        # own loss and is only used (frozen) here.
        self.__dict__["value_model"] = value_model
        self.model_based_env = model_based_env
        self.imagination_horizon = imagination_horizon
        self.discount_loss = discount_loss
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
        if lmbda is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        if self._value_estimator is not None:
            self._value_estimator.set_keys(
                value=self._tensor_keys.value,
            )

    def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
        """Imagines a rollout from the given latent states and scores the actor.

        Args:
            tensordict (TensorDict): holds the starting ``"state"`` and belief
                entries; every other entry is dropped.

        Returns:
            a tuple ``(losses, fake_data)`` where ``losses`` holds
            ``loss_actor`` and ``fake_data`` is the detached imagined rollout,
            augmented with a ``"lambda_target"`` entry.
        """
        tensordict = tensordict.select("state", self.tensor_keys.belief).detach()

        with timeit("actor_loss/time-rollout"), hold_out_net(
            self.model_based_env
        ), set_exploration_type(ExplorationType.RANDOM):
            tensordict = self.model_based_env.reset(tensordict.copy())
            fake_data = self.model_based_env.rollout(
                max_steps=self.imagination_horizon,
                policy=self.actor_model,
                auto_reset=False,
                tensordict=tensordict,
            )
            next_tensordict = step_mdp(fake_data, keep_other=True)
            with hold_out_net(self.value_model):
                next_tensordict = self.value_model(next_tensordict)

        reward = fake_data.get(("next", self.tensor_keys.reward))
        next_value = next_tensordict.get(self.tensor_keys.value)
        lambda_target = self.lambda_target(reward, next_value)
        fake_data.set("lambda_target", lambda_target)

        if self.discount_loss:
            # Cumulative discount along the time dimension (-2): 1, g, g^2, ...
            # Moved to lambda_target's device, which is always defined (a
            # tensordict's device may be None).
            gamma = self.value_estimator.gamma.to(lambda_target.device)
            discount = gamma.expand(lambda_target.shape).clone()
            discount[..., 0, :] = 1
            discount = discount.cumprod(dim=-2)
            actor_loss = -(lambda_target * discount).sum((-2, -1)).mean()
        else:
            actor_loss = -lambda_target.sum((-2, -1)).mean()
        loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
        self._clear_weakrefs(tensordict, loss_tensordict)
        return loss_tensordict, fake_data.detach()

    def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Computes the value-estimator target for an imagined trajectory.

        The done/terminated flags are all ``False``: imagined rollouts are
        treated as never ending within the horizon.
        """
        done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
        terminated = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
        input_tensordict = TensorDict(
            {
                ("next", self.tensor_keys.reward): reward,
                ("next", self.tensor_keys.value): value,
                ("next", self.tensor_keys.done): done,
                ("next", self.tensor_keys.terminated): terminated,
            },
            [],
        )
        return self.value_estimator.value_estimate(input_tensordict)

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        """Builds the value estimator used to compute lambda targets.

        Args:
            value_type (ValueEstimators, optional): the estimator type.
                Defaults to ``default_value_estimator`` (TD(lambda)).
            **hyperparams: overrides for the estimator's default kwargs.

        Raises:
            NotImplementedError: for GAE or an unknown value type.
        """
        if value_type is None:
            value_type = self.default_value_estimator
        self.value_type = value_type
        # Targets are computed from precomputed values, so no network is used.
        value_net = None
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        if value_type is ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(
                **hp,
                value_network=value_net,
            )
        elif value_type is ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(
                **hp,
                value_network=value_net,
            )
        elif value_type is ValueEstimators.GAE:
            if hasattr(self, "lmbda"):
                hp["lmbda"] = self.lmbda
            raise NotImplementedError(
                f"Value type {value_type} it not implemented for loss {type(self)}."
            )
        elif value_type is ValueEstimators.TDLambda:
            if hasattr(self, "lmbda"):
                hp["lmbda"] = self.lmbda
            self._value_estimator = TDLambdaEstimator(
                **hp,
                value_network=value_net,
                vectorized=True,  # TODO: vectorized version seems not to be similar to the non vectorised
            )
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

        tensor_keys = {
            "value": self.tensor_keys.value,
            "value_target": "value_target",
        }
        self._value_estimator.set_keys(**tensor_keys)
class DreamerValueLoss(LossModule):
    """Dreamer Value Loss.

    Computes the loss of the dreamer value model. The value loss is computed
    between the predicted value and the lambda target.

    Reference: Hafner et al., "Dream to Control: Learning Behaviors by Latent
    Imagination", https://arxiv.org/abs/1912.01603.

    Args:
        value_model (TensorDictModule): the value model.
        value_loss (str, optional): the loss to use for the value loss, passed
            as the loss function name to ``distance_loss``. Default: ``"l2"``.
        discount_loss (bool, optional): if ``True``, the loss is discounted with
            a gamma discount factor. Default: ``True``.
        gamma (:obj:`float`, optional): the gamma discount factor. Default: ``0.99``.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using
        '.set_keys(key_name=key_value)' and their default values.

        Attributes:
            value (NestedKey): The input tensordict key where the state value
                is expected. Defaults to ``"state_value"``.
        """

        value: NestedKey = "state_value"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys

    value_model: TensorDictModule

    def __init__(
        self,
        value_model: TensorDictModule,
        value_loss: str | None = None,
        discount_loss: bool = True,  # for consistency with paper
        gamma: float = 0.99,
    ):
        super().__init__()
        self.value_model = value_model
        self.value_loss = value_loss if value_loss is not None else "l2"
        self.gamma = gamma
        self.discount_loss = discount_loss

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # This loss consumes precomputed lambda targets: nothing to forward.
        pass

    def forward(self, fake_data) -> tuple[TensorDict, TensorDict]:
        """Regresses the value model onto the lambda targets of a rollout.

        Args:
            fake_data (TensorDict): an imagined rollout carrying a
                ``"lambda_target"`` entry, such as the one returned by
                ``DreamerActorLoss``.

        Returns:
            a tuple ``(losses, fake_data)`` where ``losses`` holds
            ``loss_value`` and ``fake_data`` is the unmodified input.
        """
        lambda_target = fake_data.get("lambda_target")
        # Run the value model on a selection of its inputs only, so that its
        # output is written to this selection and not to ``fake_data``.
        tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
        self.value_model(tensordict_select)

        value_loss = distance_loss(
            tensordict_select.get(self.tensor_keys.value),
            lambda_target,
            self.value_loss,
        )
        if self.discount_loss:
            # Cumulative discount along the time dimension (-2): 1, g, g^2, ...
            discount = self.gamma * torch.ones_like(
                lambda_target, device=lambda_target.device
            )
            discount[..., 0, :] = 1
            discount = discount.cumprod(dim=-2)
            value_loss = discount * value_loss
        value_loss = value_loss.sum((-1, -2)).mean()

        loss_tensordict = TensorDict({"loss_value": value_loss}, [])
        self._clear_weakrefs(fake_data, loss_tensordict)
        return loss_tensordict, fake_data


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources