Source code for torchrl.modules.models.models

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import dataclasses

import warnings
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Type, Union

import torch
from tensordict.nn import dispatch, TensorDictModuleBase
from torch import nn
from torch.nn import functional as F

from torchrl._utils import prod
from torchrl.data.utils import DEVICE_TYPING
from torchrl.modules.models.decision_transformer import DecisionTransformer
from torchrl.modules.models.utils import (
    _find_depth,
    create_on_device,
    LazyMapping,
    SquashDims,
    Squeeze2dLayer,
    SqueezeLayer,
)


[docs]class MLP(nn.Sequential): """A multi-layer perceptron. If MLP receives more than one input, it concatenates them all along the last dimension before passing the resulting tensor through the network. This is aimed at allowing for a seamless interface with calls of the type of >>> model(state, action) # compute state-action value In the future, this feature may be moved to the ProbabilisticTDModule, though it would require it to handle different cases (vectors, images, ...) Args: in_features (int, optional): number of input features; out_features (int, list of int): number of output features. If iterable of integers, the output is reshaped to the desired shape; depth (int, optional): depth of the network. A depth of 0 will produce a single linear layer network with the desired input and output size. A length of 1 will create 2 linear layers etc. If no depth is indicated, the depth information should be contained in the num_cells argument (see below). If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells. default: 32; activation_class (Type[nn.Module]): activation class to be used. default: nn.Tanh activation_kwargs (dict, optional): kwargs to be used with the activation class; norm_class (Type, optional): normalization class, if any. norm_kwargs (dict, optional): kwargs to be used with the normalization layers; dropout (float, optional): dropout probability. Defaults to ``None`` (no dropout); bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter. default: True; single_bias_last_layer (bool): if ``True``, the last dimension of the bias of the last layer will be a singleton dimension. default: True; layer_class (Type[nn.Module]): class to be used for the linear layers; layer_kwargs (dict, optional): kwargs for the linear layers; activate_last_layer (bool): whether the MLP output should be activated. This is useful when the MLP output is used as the input for another module. default: False. device (Optional[DEVICE_TYPING]): device to create the module on. 
Examples: >>> # All of the following examples provide valid, working MLPs >>> mlp = MLP(in_features=3, out_features=6, depth=0) # MLP consisting of a single 3 x 6 linear layer >>> print(mlp) MLP( (0): Linear(in_features=3, out_features=6, bias=True) ) >>> mlp = MLP(in_features=3, out_features=6, depth=4, num_cells=32) >>> print(mlp) MLP( (0): Linear(in_features=3, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=32, bias=True) (3): Tanh() (4): Linear(in_features=32, out_features=32, bias=True) (5): Tanh() (6): Linear(in_features=32, out_features=32, bias=True) (7): Tanh() (8): Linear(in_features=32, out_features=6, bias=True) ) >>> mlp = MLP(out_features=6, depth=4, num_cells=32) # LazyLinear for the first layer >>> print(mlp) MLP( (0): LazyLinear(in_features=0, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=32, bias=True) (3): Tanh() (4): Linear(in_features=32, out_features=32, bias=True) (5): Tanh() (6): Linear(in_features=32, out_features=32, bias=True) (7): Tanh() (8): Linear(in_features=32, out_features=6, bias=True) ) >>> mlp = MLP(out_features=6, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg >>> print(mlp) MLP( (0): LazyLinear(in_features=0, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=33, bias=True) (3): Tanh() (4): Linear(in_features=33, out_features=34, bias=True) (5): Tanh() (6): Linear(in_features=34, out_features=35, bias=True) (7): Tanh() (8): Linear(in_features=35, out_features=6, bias=True) ) >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35]) # returns a view of the output tensor with shape [*, 6, 7] >>> print(mlp) MLP( (0): LazyLinear(in_features=0, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=33, bias=True) (3): Tanh() (4): Linear(in_features=33, out_features=34, bias=True) (5): Tanh() (6): Linear(in_features=34, out_features=35, bias=True) (7): Tanh() (8): Linear(in_features=35, out_features=42, bias=True) ) >>> from torchrl.modules import NoisyLinear >>> mlp = MLP(out_features=(6, 7), num_cells=[32, 33, 34, 35], layer_class=NoisyLinear) # uses NoisyLinear layers >>> print(mlp) MLP( (0): NoisyLazyLinear(in_features=0, out_features=32, bias=False) (1): Tanh() (2): NoisyLinear(in_features=32, out_features=33, bias=True) (3): Tanh() (4): NoisyLinear(in_features=33, out_features=34, bias=True) (5): Tanh() (6): NoisyLinear(in_features=34, out_features=35, bias=True) (7): Tanh() (8): NoisyLinear(in_features=35, out_features=42, bias=True) ) """ def __init__( self, in_features: Optional[int] = None, out_features: Union[int, Sequence[int]] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence, int]] = None, activation_class: Type[nn.Module] = nn.Tanh, activation_kwargs: Optional[dict] = None, norm_class: Optional[Type[nn.Module]] = None, norm_kwargs: Optional[dict] = None, dropout: Optional[float] = None, bias_last_layer: bool = True, single_bias_last_layer: bool = False, layer_class: Type[nn.Module] = nn.Linear, layer_kwargs: Optional[dict] = None, activate_last_layer: bool = False, device: Optional[DEVICE_TYPING] = None, ): if out_features is None: raise ValueError("out_features must be specified for MLP.") default_num_cells = 32 if num_cells is None: if depth is None: num_cells = [default_num_cells] * 3 depth = 3 else: num_cells = [default_num_cells] * depth self.in_features = in_features _out_features_num = out_features if not isinstance(out_features, Number): _out_features_num = 
prod(out_features) self.out_features = out_features self._out_features_num = _out_features_num self.activation_class = activation_class self.activation_kwargs = ( activation_kwargs if activation_kwargs is not None else {} ) self.norm_class = norm_class self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.dropout = dropout self.bias_last_layer = bias_last_layer self.single_bias_last_layer = single_bias_last_layer self.layer_class = layer_class self.layer_kwargs = layer_kwargs if layer_kwargs is not None else {} self.activate_last_layer = activate_last_layer if single_bias_last_layer: raise NotImplementedError if not (isinstance(num_cells, Sequence) or depth is not None): raise RuntimeError( "If num_cells is provided as an integer, \ depth must be provided too." ) self.num_cells = ( list(num_cells) if isinstance(num_cells, Sequence) else [num_cells] * depth ) self.depth = depth if depth is not None else len(self.num_cells) if not (len(self.num_cells) == depth or depth is None): raise RuntimeError( "depth and num_cells length conflict, \ consider matching or specifying a constant num_cells argument together with a a desired depth" ) layers = self._make_net(device) super().__init__(*layers) def _make_net(self, device: Optional[DEVICE_TYPING]) -> List[nn.Module]: layers = [] in_features = [self.in_features] + self.num_cells out_features = self.num_cells + [self._out_features_num] for i, (_in, _out) in enumerate(zip(in_features, out_features)): _bias = self.bias_last_layer if i == self.depth else True if _in is not None: layers.append( create_on_device( self.layer_class, device, _in, _out, bias=_bias, **self.layer_kwargs, ) ) else: try: lazy_version = LazyMapping[self.layer_class] except KeyError: raise KeyError( f"The lazy version of {self.layer_class.__name__} is not implemented yet. " "Consider providing the input feature dimensions explicitely when creating an MLP module" ) layers.append( create_on_device( lazy_version, device, _out, bias=_bias, **self.layer_kwargs ) ) if i < self.depth or self.activate_last_layer: if self.dropout is not None: layers.append(create_on_device(nn.Dropout, device, p=self.dropout)) if self.norm_class is not None: layers.append( create_on_device(self.norm_class, device, **self.norm_kwargs) ) layers.append( create_on_device( self.activation_class, device, **self.activation_kwargs ) ) return layers
    def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
        if len(inputs) > 1:
            inputs = (torch.cat([*inputs], -1),)
        out = super().forward(*inputs)
        if not isinstance(self.out_features, Number):
            out = out.view(*out.shape[:-1], *self.out_features)
        return out
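Because ``forward`` concatenates multiple inputs along the last dimension, an ``MLP`` can be used directly as a state-action value network. A minimal usage sketch (the 4-dim state, 3-dim action and hidden width below are arbitrary, illustrative values):

    >>> import torch
    >>> from torchrl.modules import MLP
    >>> value_net = MLP(in_features=7, out_features=1, depth=2, num_cells=32)
    >>> state, action = torch.randn(8, 4), torch.randn(8, 3)
    >>> value = value_net(state, action)  # inputs are concatenated into an 8 x 7 tensor
    >>> value.shape
    torch.Size([8, 1])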
[docs]class ConvNet(nn.Sequential): """A convolutional neural network. Args: in_features (int, optional): number of input features; depth (int, optional): depth of the network. A depth of 1 will produce a single linear layer network with the desired input size, and with an output size equal to the last element of the num_cells argument. If no depth is indicated, the depth information should be contained in the num_cells argument (see below). If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to the depth. num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells. default: [32, 32, 32]; kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the depth, defined by the num_cells or depth arguments. strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the depth, defined by the num_cells or depth arguments. activation_class (Type[nn.Module]): activation class to be used. default: nn.Tanh activation_kwargs (dict, optional): kwargs to be used with the activation class; norm_class (Type, optional): normalization class, if any; norm_kwargs (dict, optional): kwargs to be used with the normalization layers; bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter. default: True; aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain. default: SquashDims; aggregator_kwargs (dict, optional): kwargs for the aggregator_class; squeeze_output (bool): whether the output should be squeezed of its singleton dimensions. default: False. device (Optional[DEVICE_TYPING]): device to create the module on. 
Examples: >>> # All of the following examples provide valid, working MLPs >>> cnet = ConvNet(in_features=3, depth=1, num_cells=[32,]) # MLP consisting of a single 3 x 6 linear layer >>> print(cnet) ConvNet( (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) (1): ELU(alpha=1.0) (2): SquashDims() ) >>> cnet = ConvNet(in_features=3, depth=4, num_cells=32) >>> print(cnet) ConvNet( (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) (1): ELU(alpha=1.0) (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1)) (3): ELU(alpha=1.0) (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1)) (5): ELU(alpha=1.0) (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1)) (7): ELU(alpha=1.0) (8): SquashDims() ) >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg >>> print(cnet) ConvNet( (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) (1): ELU(alpha=1.0) (2): Conv2d(32, 33, kernel_size=(3, 3), stride=(1, 1)) (3): ELU(alpha=1.0) (4): Conv2d(33, 34, kernel_size=(3, 3), stride=(1, 1)) (5): ELU(alpha=1.0) (6): Conv2d(34, 35, kernel_size=(3, 3), stride=(1, 1)) (7): ELU(alpha=1.0) (8): SquashDims() ) >>> cnet = ConvNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3)]) # defines kernels, possibly rectangular >>> print(cnet) ConvNet( (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1)) (1): ELU(alpha=1.0) (2): Conv2d(32, 33, kernel_size=(4, 4), stride=(1, 1)) (3): ELU(alpha=1.0) (4): Conv2d(33, 34, kernel_size=(5, 5), stride=(1, 1)) (5): ELU(alpha=1.0) (6): Conv2d(34, 35, kernel_size=(2, 3), stride=(1, 1)) (7): ELU(alpha=1.0) (8): SquashDims() ) """ def __init__( self, in_features: Optional[int] = None, depth: Optional[int] = None, num_cells: Union[Sequence, int] = None, kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3, strides: Union[Sequence, int] = 1, paddings: Union[Sequence, int] = 0, activation_class: Type[nn.Module] = nn.ELU, activation_kwargs: Optional[dict] = None, norm_class: Optional[Type[nn.Module]] = None, norm_kwargs: Optional[dict] = None, bias_last_layer: bool = True, aggregator_class: Optional[Type[nn.Module]] = SquashDims, aggregator_kwargs: Optional[dict] = None, squeeze_output: bool = False, device: Optional[DEVICE_TYPING] = None, ): if num_cells is None: num_cells = [32, 32, 32] self.in_features = in_features self.activation_class = activation_class self.activation_kwargs = ( activation_kwargs if activation_kwargs is not None else {} ) self.norm_class = norm_class self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.bias_last_layer = bias_last_layer self.aggregator_class = aggregator_class self.aggregator_kwargs = ( aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 3} ) self.squeeze_output = squeeze_output # self.single_bias_last_layer = single_bias_last_layer depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings) self.depth = depth if depth == 0: raise ValueError("Null depth is not permitted with ConvNet.") for _field, _value in zip( ["num_cells", "kernel_sizes", "strides", "paddings"], [num_cells, kernel_sizes, strides, paddings], ): _depth = depth setattr( self, _field, (_value if isinstance(_value, Sequence) else [_value] * _depth), ) if not (isinstance(_value, Sequence) or _depth is not None): raise RuntimeError( f"If {_field} is provided as an integer, " "depth must be provided too." 
) if not (len(getattr(self, _field)) == _depth or _depth is None): raise RuntimeError( f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, " + f"consider matching or specifying a constant {_field} argument together with a a desired depth" ) self.out_features = self.num_cells[-1] self.depth = len(self.kernel_sizes) layers = self._make_net(device) super().__init__(*layers) def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: layers = [] in_features = [self.in_features] + self.num_cells[: self.depth] out_features = self.num_cells + [self.out_features] kernel_sizes = self.kernel_sizes strides = self.strides paddings = self.paddings for i, (_in, _out, _kernel, _stride, _padding) in enumerate( zip(in_features, out_features, kernel_sizes, strides, paddings) ): _bias = (i < len(in_features) - 1) or self.bias_last_layer if _in is not None: layers.append( nn.Conv2d( _in, _out, kernel_size=_kernel, stride=_stride, bias=_bias, padding=_padding, device=device, ) ) else: layers.append( nn.LazyConv2d( _out, kernel_size=_kernel, stride=_stride, bias=_bias, padding=_padding, device=device, ) ) layers.append( create_on_device( self.activation_class, device, **self.activation_kwargs ) ) if self.norm_class is not None: layers.append( create_on_device(self.norm_class, device, **self.norm_kwargs) ) if self.aggregator_class is not None: layers.append( create_on_device( self.aggregator_class, device, **self.aggregator_kwargs ) ) if self.squeeze_output: layers.append(Squeeze2dLayer()) return layers
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        *batch, C, L, W = inputs.shape
        if len(batch) > 1:
            inputs = inputs.flatten(0, len(batch) - 1)
        out = super(ConvNet, self).forward(inputs)
        if len(batch) > 1:
            out = out.unflatten(0, batch)
        return out
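``ConvNet.forward`` accepts any number of leading batch dimensions and flattens/unflattens them around the convolutional stack. A minimal sketch (the 32 x 32 input and the two-layer configuration are arbitrary, illustrative values; the trailing feature size follows from the kernel arithmetic, 64 x 28 x 28 = 50176 here):

    >>> import torch
    >>> from torchrl.modules import ConvNet
    >>> cnet = ConvNet(in_features=3, num_cells=[32, 64], kernel_sizes=3, strides=1)
    >>> pixels = torch.randn(4, 5, 3, 32, 32)  # two leading batch dimensions
    >>> out = cnet(pixels)
    >>> out.shape
    torch.Size([4, 5, 50176])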
Conv2dNet = ConvNet
[docs]class Conv3dNet(nn.Sequential): """A 3D-convolutional neural network. Args: in_features (int, optional): number of input features. A lazy implementation that automatically retrieves the input size will be used if none is provided. depth (int, optional): depth of the network. A depth of 1 will produce a single linear layer network with the desired input size, and with an output size equal to the last element of the num_cells argument. If no depth is indicated, the depth information should be contained in the num_cells argument (see below). If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to the depth. num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells. default: ``[32, 32, 32]`` or ``[32] * depth` is depth is not ``None``. kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the conv network. If iterable, the length must match the depth, defined by the num_cells or depth arguments. strides (int or Sequence[int]): Stride(s) of the conv network. If iterable, the length must match the depth, defined by the num_cells or depth arguments. activation_class (Type[nn.Module]): activation class to be used. default: nn.Tanh activation_kwargs (dict, optional): kwargs to be used with the activation class; norm_class (Type, optional): normalization class, if any; norm_kwargs (dict, optional): kwargs to be used with the normalization layers; bias_last_layer (bool): if ``True``, the last Linear layer will have a bias parameter. default: True; aggregator_class (Type[nn.Module]): aggregator to use at the end of the chain. default: SquashDims; aggregator_kwargs (dict, optional): kwargs for the aggregator_class; squeeze_output (bool): whether the output should be squeezed of its singleton dimensions. default: False. device (Optional[DEVICE_TYPING]): device to create the module on. 
Examples: >>> # All of the following examples provide valid, working MLPs >>> cnet = Conv3dNet(in_features=3, depth=1, num_cells=[32,]) >>> print(cnet) Conv3dNet( (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (1): ELU(alpha=1.0) (2): SquashDims() ) >>> cnet = Conv3dNet(in_features=3, depth=4, num_cells=32) >>> print(cnet) Conv3dNet( (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (1): ELU(alpha=1.0) (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (3): ELU(alpha=1.0) (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (5): ELU(alpha=1.0) (6): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (7): ELU(alpha=1.0) (8): SquashDims() ) >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35]) # defines the depth by the num_cells arg >>> print(cnet) Conv3dNet( (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (1): ELU(alpha=1.0) (2): Conv3d(32, 33, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (3): ELU(alpha=1.0) (4): Conv3d(33, 34, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (5): ELU(alpha=1.0) (6): Conv3d(34, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (7): ELU(alpha=1.0) (8): SquashDims() ) >>> cnet = Conv3dNet(in_features=3, num_cells=[32, 33, 34, 35], kernel_sizes=[3, 4, 5, (2, 3, 4)]) # defines kernels, possibly rectangular >>> print(cnet) Conv3dNet( (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1)) (1): ELU(alpha=1.0) (2): Conv3d(32, 33, kernel_size=(4, 4, 4), stride=(1, 1, 1)) (3): ELU(alpha=1.0) (4): Conv3d(33, 34, kernel_size=(5, 5, 5), stride=(1, 1, 1)) (5): ELU(alpha=1.0) (6): Conv3d(34, 35, kernel_size=(2, 3, 4), stride=(1, 1, 1)) (7): ELU(alpha=1.0) (8): SquashDims() ) """ def __init__( self, in_features: Optional[int] = None, depth: Optional[int] = None, num_cells: Union[Sequence, int] = None, kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 3, strides: Union[Sequence, int] = 1, paddings: Union[Sequence, int] = 0, activation_class: Type[nn.Module] = nn.ELU, activation_kwargs: Optional[dict] = None, norm_class: Optional[Type[nn.Module]] = None, norm_kwargs: Optional[dict] = None, bias_last_layer: bool = True, aggregator_class: Optional[Type[nn.Module]] = SquashDims, aggregator_kwargs: Optional[dict] = None, squeeze_output: bool = False, device: Optional[DEVICE_TYPING] = None, ): if num_cells is None: if depth is None: num_cells = [32, 32, 32] else: num_cells = [32] * depth self.in_features = in_features self.activation_class = activation_class self.activation_kwargs = ( activation_kwargs if activation_kwargs is not None else {} ) self.norm_class = norm_class self.norm_kwargs = norm_kwargs if norm_kwargs is not None else {} self.bias_last_layer = bias_last_layer self.aggregator_class = aggregator_class self.aggregator_kwargs = ( aggregator_kwargs if aggregator_kwargs is not None else {"ndims_in": 4} ) self.squeeze_output = squeeze_output # self.single_bias_last_layer = single_bias_last_layer depth = _find_depth(depth, num_cells, kernel_sizes, strides, paddings) self.depth = depth if depth == 0: raise ValueError("Null depth is not permitted with Conv3dNet.") for _field, _value in zip( ["num_cells", "kernel_sizes", "strides", "paddings"], [num_cells, kernel_sizes, strides, paddings], ): _depth = depth setattr( self, _field, (_value if isinstance(_value, Sequence) else [_value] * _depth), ) if not (len(getattr(self, _field)) == _depth or _depth is None): raise ValueError( f"depth={depth} and {_field}={len(getattr(self, _field))} length conflict, " + f"consider matching or specifying a constant 
{_field} argument together with a a desired depth" ) self.out_features = self.num_cells[-1] self.depth = len(self.kernel_sizes) layers = self._make_net(device) super().__init__(*layers) def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: layers = [] in_features = [self.in_features] + self.num_cells[: self.depth] out_features = self.num_cells + [self.out_features] kernel_sizes = self.kernel_sizes strides = self.strides paddings = self.paddings for i, (_in, _out, _kernel, _stride, _padding) in enumerate( zip(in_features, out_features, kernel_sizes, strides, paddings) ): _bias = (i < len(in_features) - 1) or self.bias_last_layer if _in is not None: layers.append( nn.Conv3d( _in, _out, kernel_size=_kernel, stride=_stride, bias=_bias, padding=_padding, device=device, ) ) else: layers.append( nn.LazyConv3d( _out, kernel_size=_kernel, stride=_stride, bias=_bias, padding=_padding, device=device, ) ) layers.append( create_on_device( self.activation_class, device, **self.activation_kwargs ) ) if self.norm_class is not None: layers.append( create_on_device(self.norm_class, device, **self.norm_kwargs) ) if self.aggregator_class is not None: layers.append( create_on_device( self.aggregator_class, device, **self.aggregator_kwargs ) ) if self.squeeze_output: layers.append(SqueezeLayer((-3, -2, -1))) return layers
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        try:
            *batch, C, D, L, W = inputs.shape
        except ValueError as err:
            raise ValueError(
                f"The input value of {self.__class__.__name__} must have at least 4 dimensions, got {inputs.ndim} instead."
            ) from err
        if len(batch) > 1:
            inputs = inputs.flatten(0, len(batch) - 1)
        out = super().forward(inputs)
        if len(batch) > 1:
            out = out.unflatten(0, batch)
        return out
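The same batching logic applies to 5D (and higher) volumetric inputs. A minimal sketch, assuming ``Conv3dNet`` is importable from ``torchrl.modules`` (the 8 x 8 x 8 volume and single-layer configuration are arbitrary, illustrative values; the trailing size is 32 x 6 x 6 x 6 = 6912):

    >>> import torch
    >>> from torchrl.modules import Conv3dNet
    >>> cnet = Conv3dNet(in_features=3, num_cells=[32], depth=1)
    >>> voxels = torch.randn(2, 3, 8, 8, 8)
    >>> out = cnet(voxels)
    >>> out.shape
    torch.Size([2, 6912])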
class DuelingMlpDQNet(nn.Module): """Creates a Dueling MLP Q-network. Presented in https://arxiv.org/abs/1511.06581 Args: out_features (int): number of features for the advantage network out_features_value (int): number of features for the value network mlp_kwargs_feature (dict, optional): kwargs for the feature network. Default is >>> mlp_kwargs_feature = { ... 'num_cells': [256, 256], ... 'activation_class': nn.ELU, ... 'out_features': 256, ... 'activate_last_layer': True, ... } mlp_kwargs_output (dict, optional): kwargs for the advantage and value networks. Default is >>> mlp_kwargs_output = { ... "depth": 1, ... "activation_class": nn.ELU, ... "num_cells": 512, ... "bias_last_layer": True, ... } device (Optional[DEVICE_TYPING]): device to create the module on. """ def __init__( self, out_features: int, out_features_value: int = 1, mlp_kwargs_feature: Optional[dict] = None, mlp_kwargs_output: Optional[dict] = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() mlp_kwargs_feature = ( mlp_kwargs_feature if mlp_kwargs_feature is not None else {} ) _mlp_kwargs_feature = { "num_cells": [256, 256], "out_features": 256, "activation_class": nn.ELU, "activate_last_layer": True, } _mlp_kwargs_feature.update(mlp_kwargs_feature) self.features = MLP(device=device, **_mlp_kwargs_feature) _mlp_kwargs_output = { "depth": 1, "activation_class": nn.ELU, "num_cells": 512, "bias_last_layer": True, } mlp_kwargs_output = mlp_kwargs_output if mlp_kwargs_output is not None else {} _mlp_kwargs_output.update(mlp_kwargs_output) self.out_features = out_features self.out_features_value = out_features_value self.advantage = MLP( out_features=out_features, device=device, **_mlp_kwargs_output ) self.value = MLP( out_features=out_features_value, device=device, **_mlp_kwargs_output ) for layer in self.modules(): if isinstance(layer, (nn.Conv2d, nn.Linear)) and isinstance( layer.bias, torch.Tensor ): layer.bias.data.zero_() def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) advantage = self.advantage(x) value = self.value(x) return value + advantage - advantage.mean(dim=-1, keepdim=True)
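A minimal usage sketch for ``DuelingMlpDQNet`` (the 8-dim observation and the 4 actions below are arbitrary, illustrative values; the feature trunk uses lazy linear layers, so shapes are only materialized on the first call):

    >>> import torch
    >>> from torchrl.modules.models.models import DuelingMlpDQNet
    >>> net = DuelingMlpDQNet(out_features=4)
    >>> q_values = net(torch.randn(16, 8))  # value + advantage - advantage.mean()
    >>> q_values.shape
    torch.Size([16, 4])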
[docs]class DuelingCnnDQNet(nn.Module): """Dueling CNN Q-network. Presented in https://arxiv.org/abs/1511.06581 Args: out_features (int): number of features for the advantage network out_features_value (int): number of features for the value network cnn_kwargs (dict, optional): kwargs for the feature network. Default is >>> cnn_kwargs = { ... 'num_cells': [32, 64, 64], ... 'strides': [4, 2, 1], ... 'kernels': [8, 4, 3], ... } mlp_kwargs (dict, optional): kwargs for the advantage and value network. Default is >>> mlp_kwargs = { ... "depth": 1, ... "activation_class": nn.ELU, ... "num_cells": 512, ... "bias_last_layer": True, ... } device (Optional[DEVICE_TYPING]): device to create the module on. """ def __init__( self, out_features: int, out_features_value: int = 1, cnn_kwargs: Optional[dict] = None, mlp_kwargs: Optional[dict] = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} _cnn_kwargs = { "num_cells": [32, 64, 64], "strides": [4, 2, 1], "kernel_sizes": [8, 4, 3], } _cnn_kwargs.update(cnn_kwargs) self.features = ConvNet(device=device, **_cnn_kwargs) _mlp_kwargs = { "depth": 1, "activation_class": nn.ELU, "num_cells": 512, "bias_last_layer": True, } mlp_kwargs = mlp_kwargs if mlp_kwargs is not None else {} _mlp_kwargs.update(mlp_kwargs) self.out_features = out_features self.out_features_value = out_features_value self.advantage = MLP(out_features=out_features, device=device, **_mlp_kwargs) self.value = MLP(out_features=out_features_value, device=device, **_mlp_kwargs) for layer in self.modules(): if isinstance(layer, (nn.Conv2d, nn.Linear)) and isinstance( layer.bias, torch.Tensor ): layer.bias.data.zero_()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        advantage = self.advantage(x)
        value = self.value(x)
        return value + advantage - advantage.mean(dim=-1, keepdim=True)
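A minimal usage sketch for ``DuelingCnnDQNet`` with an Atari-like frame (the 84 x 84 x 3 input and the 6 actions are arbitrary, illustrative values; the convolutional trunk is lazily initialized):

    >>> import torch
    >>> from torchrl.modules import DuelingCnnDQNet
    >>> net = DuelingCnnDQNet(out_features=6)
    >>> q_values = net(torch.randn(8, 3, 84, 84))
    >>> q_values.shape
    torch.Size([8, 6])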
[docs]class DistributionalDQNnet(TensorDictModuleBase): """Distributional Deep Q-Network. Args: DQNet (nn.Module): (deprecated) Q-Network with output length equal to the number of atoms: output.shape = [*batch, atoms, actions]. in_keys (list of str or tuples of str): input keys to the log-softmax operation. Defaults to ``["action_value"]``. out_keys (list of str or tuples of str): output keys to the log-softmax operation. Defaults to ``["action_value"]``. """ _wrong_out_feature_dims_error = ( "DistributionalDQNnet requires dqn output to be at least " "2-dimensional, with dimensions *Batch x #Atoms x #Actions. Got {0} " "instead." ) def __init__(self, DQNet: nn.Module = None, in_keys=None, out_keys=None): super().__init__() if DQNet is not None: warnings.warn( f"Passing a network to {type(self)} is going to be deprecated.", category=DeprecationWarning, ) if not ( not isinstance(DQNet.out_features, Number) and len(DQNet.out_features) > 1 ): raise RuntimeError(self._wrong_out_feature_dims_error) self.dqn = DQNet if in_keys is None: in_keys = ["action_value"] if out_keys is None: out_keys = ["action_value"] self.in_keys = in_keys self.out_keys = out_keys
    @dispatch(auto_batch_size=False)
    def forward(self, tensordict):
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            q_values = tensordict.get(in_key)
            if self.dqn is not None:
                q_values = self.dqn(q_values)
            if q_values.ndimension() < 2:
                raise RuntimeError(
                    self._wrong_out_feature_dims_error.format(q_values.shape)
                )
            tensordict.set(out_key, F.log_softmax(q_values, dim=-2))
        return tensordict
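A minimal usage sketch: the module log-softmaxes the ``"action_value"`` entry over the atoms dimension (``dim=-2``), leaving the shape unchanged. The 51 atoms and 6 actions below are arbitrary, illustrative values:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.modules import DistributionalDQNnet
    >>> net = DistributionalDQNnet()
    >>> td = TensorDict({"action_value": torch.randn(2, 51, 6)}, batch_size=[2])
    >>> td = net(td)
    >>> td["action_value"].shape  # log-probabilities over the 51 atoms, per action
    torch.Size([2, 51, 6])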
def ddpg_init_last_layer(
    module: nn.Sequential,
    scale: float = 6e-4,
    device: Optional[DEVICE_TYPING] = None,
) -> None:
    """Initializer for the last layer of DDPG.

    Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING",
    https://arxiv.org/pdf/1509.02971.pdf
    """
    for last_layer in reversed(module):
        if isinstance(last_layer, (nn.Linear, nn.Conv2d)):
            break
    else:
        raise RuntimeError("Could not find a nn.Linear / nn.Conv2d to initialize.")
    last_layer.weight.data.copy_(
        torch.rand_like(last_layer.weight.data, device=device) * scale - scale / 2
    )
    if last_layer.bias is not None:
        last_layer.bias.data.copy_(
            torch.rand_like(last_layer.bias.data, device=device) * scale - scale / 2
        )
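A minimal sketch of how this helper is used by the DDPG modules below: it rewrites the last ``nn.Linear`` (or ``nn.Conv2d``) of a freshly built network with small uniform weights in ``[-scale / 2, scale / 2]``. The layer sizes here are arbitrary, illustrative values:

    >>> import torch
    >>> from torchrl.modules import MLP
    >>> from torchrl.modules.models.models import ddpg_init_last_layer
    >>> mlp = MLP(in_features=4, out_features=2, depth=1, num_cells=32)
    >>> ddpg_init_last_layer(mlp, scale=6e-3)
    >>> mlp[-1].weight.abs().max() <= 3e-3  # final layer now has tiny weights
    tensor(True)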
[docs]class DdpgCnnActor(nn.Module): """DDPG Convolutional Actor class. Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Convolutional Actor takes as input an observation (some simple transformation of the observed pixels) and returns an action vector from it. It is trained to maximise the value returned by the DDPG Q Value network. Args: action_dim (int): length of the action vector. conv_net_kwargs (dict, optional): kwargs for the ConvNet. default: { 'in_features': None, "num_cells": [32, 64, 64], "kernel_sizes": [8, 4, 3], "strides": [4, 2, 1], "paddings": [0, 0, 1], 'activation_class': nn.ELU, 'norm_class': None, 'aggregator_class': SquashDims, 'aggregator_kwargs': {"ndims_in": 3}, 'squeeze_output': True, } mlp_net_kwargs: kwargs for MLP. Default: { 'in_features': None, 'out_features': action_dim, 'depth': 2, 'num_cells': 200, 'activation_class': nn.ELU, 'bias_last_layer': True, } use_avg_pooling (bool, optional): if ``True``, a nn.AvgPooling layer is used to aggregate the output. Default is ``False``. device (Optional[DEVICE_TYPING]): device to create the module on. """ def __init__( self, action_dim: int, conv_net_kwargs: Optional[dict] = None, mlp_net_kwargs: Optional[dict] = None, use_avg_pooling: bool = False, device: Optional[DEVICE_TYPING] = None, ): super().__init__() conv_net_default_kwargs = { "in_features": None, "num_cells": [32, 64, 64], "kernel_sizes": [8, 4, 3], "strides": [4, 2, 1], "paddings": [0, 0, 1], "activation_class": nn.ELU, "norm_class": None, "aggregator_class": SquashDims if not use_avg_pooling else nn.AdaptiveAvgPool2d, "aggregator_kwargs": {"ndims_in": 3} if not use_avg_pooling else {"output_size": (1, 1)}, "squeeze_output": use_avg_pooling, } conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else {} conv_net_default_kwargs.update(conv_net_kwargs) mlp_net_default_kwargs = { "in_features": None, "out_features": action_dim, "depth": 2, "num_cells": 200, "activation_class": nn.ELU, "bias_last_layer": True, } mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} mlp_net_default_kwargs.update(mlp_net_kwargs) self.convnet = ConvNet(device=device, **conv_net_default_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) ddpg_init_last_layer(self.mlp, 6e-4, device=device)
    def forward(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.convnet(observation)
        action = self.mlp(hidden)
        return action, hidden
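A minimal usage sketch for ``DdpgCnnActor`` (the 64 x 64 x 3 observation and the 6-dim action space are arbitrary, illustrative values); ``forward`` returns both the action and the CNN embedding so the latter can be reused by a Q-value network:

    >>> import torch
    >>> from torchrl.modules import DdpgCnnActor
    >>> actor = DdpgCnnActor(action_dim=6)
    >>> action, hidden = actor(torch.randn(4, 3, 64, 64))
    >>> action.shape
    torch.Size([4, 6])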
[docs]class DdpgMlpActor(nn.Module): """DDPG Actor class. Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Actor takes as input an observation vector and returns an action from it. It is trained to maximise the value returned by the DDPG Q Value network. Args: action_dim (int): length of the action vector mlp_net_kwargs (dict, optional): kwargs for MLP. Default: { 'in_features': None, 'out_features': action_dim, 'depth': 2, 'num_cells': [400, 300], 'activation_class': nn.ELU, 'bias_last_layer': True, } device (Optional[DEVICE_TYPING]): device to create the module on. """ def __init__( self, action_dim: int, mlp_net_kwargs: Optional[dict] = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() mlp_net_default_kwargs = { "in_features": None, "out_features": action_dim, "depth": 2, "num_cells": [400, 300], "activation_class": nn.ELU, "bias_last_layer": True, } mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} mlp_net_default_kwargs.update(mlp_net_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) ddpg_init_last_layer(self.mlp, 6e-3, device=device)
    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        action = self.mlp(observation)
        return action
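A minimal usage sketch for ``DdpgMlpActor``, the vector-observation counterpart (the 17-dim observation and 6-dim action space are arbitrary, illustrative values):

    >>> import torch
    >>> from torchrl.modules import DdpgMlpActor
    >>> actor = DdpgMlpActor(action_dim=6)
    >>> action = actor(torch.randn(4, 17))
    >>> action.shape
    torch.Size([4, 6])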
[docs]class DdpgCnnQNet(nn.Module): """DDPG Convolutional Q-value class. Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. Args: conv_net_kwargs (dict, optional): kwargs for the convolutional network. default: { 'in_features': None, "num_cells": [32, 64, 128], "kernel_sizes": [8, 4, 3], "strides": [4, 2, 1], "paddings": [0, 0, 1], 'activation_class': nn.ELU, 'norm_class': None, 'aggregator_class': nn.AdaptiveAvgPool2d, 'aggregator_kwargs': {}, 'squeeze_output': True, } mlp_net_kwargs (dict, optional): kwargs for MLP. Default: { 'in_features': None, 'out_features': 1, 'depth': 2, 'num_cells': 200, 'activation_class': nn.ELU, 'bias_last_layer': True, } use_avg_pooling (bool, optional): if ``True``, a nn.AvgPooling layer is used to aggregate the output. Default is ``True``. device (Optional[DEVICE_TYPING]): device to create the module on. """ def __init__( self, conv_net_kwargs: Optional[dict] = None, mlp_net_kwargs: Optional[dict] = None, use_avg_pooling: bool = True, device: Optional[DEVICE_TYPING] = None, ): super().__init__() conv_net_default_kwargs = { "in_features": None, "num_cells": [32, 64, 128], "kernel_sizes": [8, 4, 3], "strides": [4, 2, 1], "paddings": [0, 0, 1], "activation_class": nn.ELU, "norm_class": None, "aggregator_class": SquashDims if not use_avg_pooling else nn.AdaptiveAvgPool2d, "aggregator_kwargs": {"ndims_in": 3} if not use_avg_pooling else {"output_size": (1, 1)}, "squeeze_output": use_avg_pooling, } conv_net_kwargs = conv_net_kwargs if conv_net_kwargs is not None else {} conv_net_default_kwargs.update(conv_net_kwargs) mlp_net_default_kwargs = { "in_features": None, "out_features": 1, "depth": 2, "num_cells": 200, "activation_class": nn.ELU, "bias_last_layer": True, } mlp_net_kwargs = mlp_net_kwargs if mlp_net_kwargs is not None else {} mlp_net_default_kwargs.update(mlp_net_kwargs) self.convnet = ConvNet(device=device, **conv_net_default_kwargs) self.mlp = MLP(device=device, **mlp_net_default_kwargs) ddpg_init_last_layer(self.mlp, 6e-4, device=device)
    def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        hidden = torch.cat([self.convnet(observation), action], -1)
        value = self.mlp(hidden)
        return value
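A minimal usage sketch for ``DdpgCnnQNet``: the action is concatenated to the pooled CNN features before the final MLP (the input sizes below are arbitrary, illustrative values):

    >>> import torch
    >>> from torchrl.modules import DdpgCnnQNet
    >>> qnet = DdpgCnnQNet()
    >>> value = qnet(torch.randn(4, 3, 64, 64), torch.randn(4, 6))
    >>> value.shape
    torch.Size([4, 1])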
[docs]class DdpgMlpQNet(nn.Module): """DDPG Q-value MLP class. Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf The DDPG Q-value network takes as input an observation and an action, and returns a scalar from it. Because actions are integrated later than observations, two networks are created. Args: mlp_net_kwargs_net1 (dict, optional): kwargs for MLP. Default: { 'in_features': None, 'out_features': 400, 'depth': 0, 'num_cells': [], 'activation_class': nn.ELU, 'bias_last_layer': True, 'activate_last_layer': True, } mlp_net_kwargs_net2 Default: { 'in_features': None, 'out_features': 1, 'depth': 1, 'num_cells': [300, ], 'activation_class': nn.ELU, 'bias_last_layer': True, } device (Optional[DEVICE_TYPING]): device to create the module on. """ def __init__( self, mlp_net_kwargs_net1: Optional[dict] = None, mlp_net_kwargs_net2: Optional[dict] = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() mlp1_net_default_kwargs = { "in_features": None, "out_features": 400, "depth": 0, "num_cells": [], "activation_class": nn.ELU, "bias_last_layer": True, "activate_last_layer": True, } mlp_net_kwargs_net1: Dict = ( mlp_net_kwargs_net1 if mlp_net_kwargs_net1 is not None else {} ) mlp1_net_default_kwargs.update(mlp_net_kwargs_net1) self.mlp1 = MLP(device=device, **mlp1_net_default_kwargs) mlp2_net_default_kwargs = { "in_features": None, "out_features": 1, "num_cells": [ 300, ], "activation_class": nn.ELU, "bias_last_layer": True, } mlp_net_kwargs_net2 = ( mlp_net_kwargs_net2 if mlp_net_kwargs_net2 is not None else {} ) mlp2_net_default_kwargs.update(mlp_net_kwargs_net2) self.mlp2 = MLP(device=device, **mlp2_net_default_kwargs) ddpg_init_last_layer(self.mlp2, 6e-3, device=device)
    def forward(self, observation: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        value = self.mlp2(torch.cat([self.mlp1(observation), action], -1))
        return value
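A minimal usage sketch for ``DdpgMlpQNet``: the observation goes through ``mlp1`` alone and the action is only injected before ``mlp2``, as described in the docstring (the input sizes below are arbitrary, illustrative values):

    >>> import torch
    >>> from torchrl.modules import DdpgMlpQNet
    >>> qnet = DdpgMlpQNet()
    >>> value = qnet(torch.randn(4, 17), torch.randn(4, 6))
    >>> value.shape
    torch.Size([4, 1])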
[docs]class LSTMNet(nn.Module): """An embedder for an LSTM preceded by an MLP. The forward method returns the hidden states of the current state (input hidden states) and the output, as the environment returns the 'observation' and 'next_observation'. Because the LSTM kernel only returns the last hidden state, hidden states are padded with zeros such that they have the right size to be stored in a TensorDict of size [batch x time_steps]. If a 2D tensor is provided as input, it is assumed that it is a batch of data with only one time step. This means that we explicitely assume that users will unsqueeze inputs of a single batch with multiple time steps. Examples: >>> batch = 7 >>> time_steps = 6 >>> in_features = 4 >>> out_features = 10 >>> hidden_size = 5 >>> net = LSTMNet( ... out_features, ... {"input_size": hidden_size, "hidden_size": hidden_size}, ... {"out_features": hidden_size}, ... ) >>> # test single step vs multi-step >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) >>> x = torch.randn(batch, in_features) # 2 dims = single step >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) """ def __init__( self, out_features: int, lstm_kwargs: Dict, mlp_kwargs: Dict, device: Optional[DEVICE_TYPING] = None, ) -> None: warnings.warn( "LSTMNet is being deprecated in favour of torchrl.modules.LSTMModule, and will be removed soon.", category=DeprecationWarning, ) super().__init__() lstm_kwargs.update({"batch_first": True}) self.mlp = MLP(device=device, **mlp_kwargs) self.lstm = nn.LSTM(device=device, **lstm_kwargs) self.linear = nn.LazyLinear(out_features, device=device) def _lstm( self, input: torch.Tensor, hidden0_in: Optional[torch.Tensor] = None, hidden1_in: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: squeeze0 = False squeeze1 = False if input.ndimension() == 1: squeeze0 = True input = input.unsqueeze(0).contiguous() if input.ndimension() == 2: squeeze1 = True input = input.unsqueeze(1).contiguous() batch, steps = input.shape[:2] if hidden1_in is None and hidden0_in is None: shape = (batch, steps) if not squeeze1 else (batch,) hidden0_in, hidden1_in = [ torch.zeros( *shape, self.lstm.num_layers, self.lstm.hidden_size, device=input.device, dtype=input.dtype, ) for _ in range(2) ] elif hidden1_in is None or hidden0_in is None: raise RuntimeError( f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" ) elif squeeze0: hidden0_in = hidden0_in.unsqueeze(0) hidden1_in = hidden1_in.unsqueeze(0) # we only need the first hidden state if not squeeze1: _hidden0_in = hidden0_in[:, 0] _hidden1_in = hidden1_in[:, 0] else: _hidden0_in = hidden0_in _hidden1_in = hidden1_in hidden = ( _hidden0_in.transpose(-3, -2).contiguous(), _hidden1_in.transpose(-3, -2).contiguous(), ) y0, hidden = self.lstm(input, hidden) # dim 0 in hidden is num_layers, but that will conflict with tensordict hidden = tuple(_h.transpose(0, 1) for _h in hidden) y = self.linear(y0) out = [y, hidden0_in, hidden1_in, *hidden] if squeeze1: # squeezes time out[0] = out[0].squeeze(1) if not squeeze1: # we pad the hidden states with zero to make tensordict happy for i in range(3, 5): out[i] = torch.stack( [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] + [out[i]], 1, ) if squeeze0: out = [_out.squeeze(0) for _out in out] return tuple(out)
    def forward(
        self,
        input: torch.Tensor,
        hidden0_in: Optional[torch.Tensor] = None,
        hidden1_in: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        input = self.mlp(input)
        return self._lstm(input, hidden0_in, hidden1_in)
[docs]class OnlineDTActor(nn.Module): """Online Decision Transformer Actor class. Actor class for the Online Decision Transformer to sample actions from gaussian distribution as presented inresented in `"Online Decision Transformer" <https://arxiv.org/abs/2202.05607.pdf>`. Returns mu and sigma for the gaussian distribution to sample actions from. Args: state_dim (int): state dimension. action_dim (int): action dimension. transformer_config (Dict or :class:`DecisionTransformer.DTConfig`): config for the GPT2 transformer. Defaults to :meth:`~.default_config`. device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. Examples: >>> model = OnlineDTActor(state_dim=4, action_dim=2, ... transformer_config=OnlineDTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) >>> mu, std = model(observation, action, return_to_go) >>> mu.shape torch.Size([32, 10, 2]) >>> std.shape torch.Size([32, 10, 2]) """ def __init__( self, state_dim: int, action_dim: int, transformer_config: Dict | DecisionTransformer.DTConfig = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() if transformer_config is None: transformer_config = self.default_config() if isinstance(transformer_config, DecisionTransformer.DTConfig): transformer_config = dataclasses.asdict(transformer_config) self.transformer = DecisionTransformer( state_dim=state_dim, action_dim=action_dim, config=transformer_config, ) self.action_layer_mean = nn.Linear( transformer_config["n_embd"], action_dim, device=device ) self.action_layer_logstd = nn.Linear( transformer_config["n_embd"], action_dim, device=device ) self.log_std_min, self.log_std_max = -5.0, 2.0 def weight_init(m): """Custom weight init for Conv2D and Linear layers.""" if isinstance(m, torch.nn.Linear): nn.init.orthogonal_(m.weight.data) if hasattr(m.bias, "data"): m.bias.data.fill_(0.0) self.action_layer_mean.apply(weight_init) self.action_layer_logstd.apply(weight_init)
    def forward(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden_state = self.transformer(observation, action, return_to_go)
        mu = self.action_layer_mean(hidden_state)
        log_std = self.action_layer_logstd(hidden_state)
        log_std = torch.tanh(log_std)
        # log_std is the output of tanh so it will be between [-1, 1]
        # map it to be between [log_std_min, log_std_max]
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (
            log_std + 1.0
        )
        std = log_std.exp()
        return mu, std
    @classmethod
    def default_config(cls):
        """Default configuration for :class:`~.OnlineDTActor`."""
        return DecisionTransformer.DTConfig(
            n_embd=512,
            n_layer=4,
            n_head=4,
            n_inner=2048,
            activation="relu",
            n_positions=1024,
            resid_pdrop=0.1,
            attn_pdrop=0.1,
        )
[docs]class DTActor(nn.Module): """Decision Transformer Actor class. Actor class for the Decision Transformer to output deterministic action as presented in `"Decision Transformer" <https://arxiv.org/abs/2202.05607.pdf>`. Returns the deterministic actions. Args: state_dim (int): state dimension. action_dim (int): action dimension. transformer_config (Dict or :class:`DecisionTransformer.DTConfig`, optional): config for the GPT2 transformer. Defaults to :meth:`~.default_config`. device (Optional[DEVICE_TYPING], optional): device to use. Defaults to None. Examples: >>> model = DTActor(state_dim=4, action_dim=2, ... transformer_config=DTActor.default_config()) >>> observation = torch.randn(32, 10, 4) >>> action = torch.randn(32, 10, 2) >>> return_to_go = torch.randn(32, 10, 1) >>> output = model(observation, action, return_to_go) >>> output.shape torch.Size([32, 10, 2]) """ def __init__( self, state_dim: int, action_dim: int, transformer_config: Dict | DecisionTransformer.DTConfig = None, device: Optional[DEVICE_TYPING] = None, ): super().__init__() if transformer_config is None: transformer_config = self.default_config() if isinstance(transformer_config, DecisionTransformer.DTConfig): transformer_config = dataclasses.asdict(transformer_config) self.transformer = DecisionTransformer( state_dim=state_dim, action_dim=action_dim, config=transformer_config, ) self.action_layer = nn.Linear( transformer_config["n_embd"], action_dim, device=device ) def weight_init(m): """Custom weight init for Conv2D and Linear layers.""" if isinstance(m, torch.nn.Linear): nn.init.orthogonal_(m.weight.data) if hasattr(m.bias, "data"): m.bias.data.fill_(0.0) self.action_layer.apply(weight_init)
    def forward(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ) -> torch.Tensor:
        hidden_state = self.transformer(observation, action, return_to_go)
        out = self.action_layer(hidden_state)
        return out
    @classmethod
    def default_config(cls):
        """Default configuration for :class:`~.DTActor`."""
        return DecisionTransformer.DTConfig(
            n_embd=512,
            n_layer=4,
            n_head=4,
            n_inner=2048,
            activation="relu",
            n_positions=1024,
            resid_pdrop=0.1,
            attn_pdrop=0.1,
        )
