
Source code for torchrl.modules.models.decision_transformer

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import dataclasses

import importlib
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

# Detect the optional `transformers` package without importing it; the actual
# import is deferred until a DecisionTransformer is constructed, which raises
# an ImportError if the package is missing.
_has_transformers = importlib.util.find_spec("transformers") is not None

class DecisionTransformer(nn.Module):
    """Online Decision Transformer.

    Described in https://arxiv.org/abs/2202.05607 .

    The transformer utilizes a default config to create the GPT2 model if the
    user does not provide a specific config. The default configuration is::

        {
            "n_embd": 256,
            "n_layer": 4,
            "n_head": 4,
            "n_inner": 1024,
            "activation": "relu",
            "n_positions": 1024,
            "resid_pdrop": 0.1,
            "attn_pdrop": 0.1,
        }

    Args:
        state_dim (int): dimension of the state space
        action_dim (int): dimension of the action space
        config (:obj:`~.DTConfig` or dict, optional): transformer architecture
            configuration, used to create the GPT2Config from transformers.
            Defaults to :obj:`~.default_config`.
        device (torch.device, optional): if provided, the GPT2 model and the
            embedding layers are created under this device context.

    Raises:
        ImportError: if the ``transformers`` package is not installed.
        TypeError: if ``config`` is neither a :obj:`~.DTConfig`, a dict, nor
            convertible to a dict.

    Example:
        >>> config = DecisionTransformer.default_config()
        >>> config.n_embd = 128
        >>> print(config)
        DTConfig(n_embd: 128, n_layer: 4, n_head: 4, n_inner: 1024, activation: relu, n_positions: 1024, resid_pdrop: 0.1, attn_pdrop: 0.1)
        >>> # alternatively
        >>> config = DecisionTransformer.DTConfig(n_embd=128)
        >>> model = DecisionTransformer(state_dim=4, action_dim=2, config=config)
        >>> batch_size = [3, 32]
        >>> length = 10
        >>> observation = torch.randn(*batch_size, length, 4)
        >>> action = torch.randn(*batch_size, length, 2)
        >>> return_to_go = torch.randn(*batch_size, length, 1)
        >>> output = model(observation, action, return_to_go)
        >>> output.shape
        torch.Size([3, 32, 10, 128])

    """

    @dataclass
    class DTConfig:
        """Default configuration for DecisionTransformer."""

        n_embd: Any = 256
        n_layer: Any = 4
        n_head: Any = 4
        n_inner: Any = 1024
        activation: Any = "relu"
        n_positions: Any = 1024
        resid_pdrop: Any = 0.1
        attn_pdrop: Any = 0.1

        def __repr__(self):
            # Render every dataclass field as "name: value", in declaration order.
            fields = ", ".join(
                f"{f.name}: {getattr(self, f.name)}" for f in dataclasses.fields(self)
            )
            return f"{self.__class__.__name__}({fields})"

    @classmethod
    def default_config(cls):
        """Return a fresh :class:`DTConfig` holding the default hyperparameters."""
        return cls.DTConfig()

    def __init__(
        self,
        state_dim,
        action_dim,
        config: dict | DTConfig | None = None,
        device: torch.device | None = None,
    ):
        if not _has_transformers:
            raise ImportError(
                "transformers is not installed. Please install it with `pip install transformers`."
            )
        # Imported lazily so the module can be imported without transformers.
        import transformers
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model

        # Normalize every accepted config flavour to a plain dict.
        if config is None:
            config = self.default_config()
        if isinstance(config, self.DTConfig):
            config = dataclasses.asdict(config)
        if not isinstance(config, dict):
            try:
                config = dict(config)
            except Exception as err:
                raise TypeError(
                    f"Config of type {type(config)} is not supported."
                ) from err

        super().__init__()

        # Build all parameters under the requested device, if any.
        with torch.device(device) if device is not None else nullcontext():
            gpt_config = transformers.GPT2Config(
                n_embd=config["n_embd"],
                n_layer=config["n_layer"],
                n_head=config["n_head"],
                n_inner=config["n_inner"],
                activation_function=config["activation"],
                n_positions=config["n_positions"],
                resid_pdrop=config["resid_pdrop"],
                attn_pdrop=config["attn_pdrop"],
                # inputs are fed as embeddings, so the token vocabulary is unused
                vocab_size=1,
            )
            self.state_dim = state_dim
            self.action_dim = action_dim
            self.hidden_size = config["n_embd"]

            self.transformer = GPT2Model(config=gpt_config)

            self.embed_return = torch.nn.Linear(1, self.hidden_size)
            self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
            self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_size)

            self.embed_ln = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        observation: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ):
        """Run the transformer over interleaved (return, state, action) tokens.

        Args:
            observation (torch.Tensor): tensor whose last two dimensions are
                ``(length, state_dim)``; any leading dimensions are batch dims.
            action (torch.Tensor): tensor whose last two dimensions are
                ``(length, action_dim)``, with the same leading dimensions.
            return_to_go (torch.Tensor): tensor whose last two dimensions are
                ``(length, 1)``, with the same leading dimensions.

        Returns:
            torch.Tensor: the last hidden state of the transformer taken at the
            state-token positions, of shape ``(*batch, length, n_embd)``.
        """
        batch_size, seq_length = observation.shape[:-2], observation.shape[-2]
        batch_size_orig = batch_size
        needs_flatten = len(batch_size) != 1
        if needs_flatten:
            # TODO: vmap over transformer once this is possible
            # reshape (rather than view) so non-contiguous inputs are accepted
            observation = observation.reshape(-1, *observation.shape[-2:])
            action = action.reshape(-1, *action.shape[-2:])
            return_to_go = return_to_go.reshape(-1, *return_to_go.shape[-2:])
            batch_size = torch.Size([batch_size.numel()])

        # embed each modality with a different head
        state_embeddings = self.embed_state(observation)
        action_embeddings = self.embed_action(action)
        returns_embeddings = self.embed_return(return_to_go)

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=-2
        ).reshape(*batch_size, 3 * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
        )
        x = transformer_outputs["last_hidden_state"]

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(*batch_size, seq_length, 3, self.hidden_size).transpose(-3, -2)
        state_tokens = x[..., 1, :, :]  # only state tokens
        if not needs_flatten:
            return state_tokens
        # restore the caller's original leading batch dimensions
        return state_tokens.reshape(*batch_size_orig, *x.shape[-2:])


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources