Shortcuts

Source code for torchrl.envs.custom.llm

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Any, Callable, Literal

import torch

from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.tensorclass import NonTensorData, NonTensorStack
from tensordict.utils import _zip_strict
from torch.utils.data import DataLoader
from torchrl.data.map.hash import SipHash
from torchrl.data.tensor_specs import (
    Bounded,
    Categorical as CategoricalSpec,
    Composite,
    NonTensor,
    TensorSpec,
    Unbounded,
)
from torchrl.envs import EnvBase
from torchrl.envs.utils import _StepMDP


class LLMEnv(EnvBase):
    """A text generation environment.

    This environment is designed to work with language models, where the observation is a string or a tensor of
    integers representing a sequence of tokens. The action is also a string or a tensor of integers, which is
    concatenated to the previous observation to form the new observation.

    By default, this environment is meant to track history for a prompt. Users can append transforms to tailor
    this to their use case, such as Chain of Thought (CoT) reasoning or other custom processing.

    Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt.

    Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created
    via :meth:`~from_dataloader`.

    Keyword Args:
        token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when
            `str2str=False`). Defaults to ``"tokens"``.
        str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when
            `str2str=True`). Defaults to ``"text"``.
        attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
            Defaults to ``"attention_mask"``.
        action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
            ``"tokens_response"`` (or ``"text_response"`` when `str2str=True`).
        reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if
            `assign_reward=True`. Defaults to ``"reward"``.
        str2str (bool, optional): Whether the environment should expect strings as input and output.
            Defaults to ``False``.
        device (torch.device | None, optional): The device on which the environment should run.
            Defaults to ``None``.
        vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
            unbounded vocabulary. Defaults to ``None``.
        no_stack (bool, optional): If ``False``, the environment concatenates the action with the past
            observation, each action being a new, unseen part of a conversation. If ``True`` (default), the
            action is assumed to be the plain output of the LLM, including the input tokens, and replaces the
            observation. ``no_stack=True`` is only supported with `str2str=False`.
        has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by
            :attr:`attention_key`. Defaults to ``True``.
        assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape is
            written during calls to `step()`. Defaults to ``False``.
        assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the
            action shape is written during calls to `step()`. Defaults to ``False``.

            .. note:: regardless of the value assigned to `assign_done`, root ``"done"`` and ``"terminated"``
              entries are part of the done spec, as this is a requirement for all TorchRL environments.

        batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment
            is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size.
            Defaults to ``None`` (batch-unlocked).
        as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`.
            Currently, ``reset`` and ``step`` raise ``NotImplementedError`` when this is ``True``.
            Defaults to ``False``.

    .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples.

    Methods:
        from_dataloader: Creates an LLMEnv instance from a dataloader.

    """

    _DEFAULT_TOKEN_KEY = "tokens"
    _DEFAULT_STR_KEY = "text"
    _DEFAULT_ATTENTION_KEY = "attention_mask"
    _DEFAULT_ACTION_TOKENS_KEY = "tokens_response"
    _DEFAULT_ACTION_STR_KEY = "text_response"

    # Raised when a feature that needs token tensors (reward / done shapes) is combined with str2str.
    _STR2STR_ERR_MSG = (
        "str2str cannot be True when either of assign_reward / assign_done are True. "
        "Tokens are required to compute the reward shape."
    )

    def __init__(
        self,
        *,
        token_key: NestedKey | None = None,
        str_key: NestedKey | None = None,
        attention_key: NestedKey | None = None,
        action_key: NestedKey | None = None,
        reward_key: NestedKey = "reward",
        str2str: bool = False,
        device: torch.device | None = None,
        vocab_size: int | None = None,
        no_stack: bool = True,
        assign_reward: bool = False,
        assign_done: bool = False,
        batch_size: int | torch.Size | None = None,
        has_attention: bool = True,
        as_llm_data: bool = False,
    ) -> None:
        self.as_llm_data = as_llm_data
        if token_key is None:
            token_key = self._DEFAULT_TOKEN_KEY
        if str_key is None:
            str_key = self._DEFAULT_STR_KEY
        if attention_key is None:
            attention_key = self._DEFAULT_ATTENTION_KEY
        if action_key is None:
            action_key = (
                self._DEFAULT_ACTION_STR_KEY
                if str2str
                else self._DEFAULT_ACTION_TOKENS_KEY
            )
        if batch_size is None:
            # No batch size given: the env is batch-unlocked and accepts any batch size.
            self._batch_locked = False
            batch_size = ()
        else:
            self._batch_locked = True
            if not isinstance(batch_size, (tuple, list)):
                batch_size = (batch_size,)

        super().__init__(
            device=device,
            batch_size=batch_size,
        )
        self.has_attention = has_attention
        self.str2str = str2str
        self.vocab_size = vocab_size
        self.token_key = unravel_key(token_key)
        self.str_key = unravel_key(str_key)
        # attention_key always has a value here (it was defaulted above).
        self.attention_key = unravel_key(attention_key)
        self.no_stack = no_stack
        self.assign_reward = assign_reward
        self.assign_done = assign_done
        # ``self.action_key`` is not assigned in this class; it is expected to be
        # provided by ``EnvBase`` from the action spec built below.

        if str2str:
            self.full_observation_spec_unbatched = Composite(
                {
                    self.str_key: NonTensor(
                        example_data="a string", batched=True, shape=()
                    )
                }
            )
            self.full_action_spec_unbatched = Composite(
                {action_key: NonTensor(example_data="a string", batched=True, shape=())}
            )
        else:

            def make_token_spec() -> TensorSpec:
                # Variable-length (-1) token sequence; bounded by vocab_size when known.
                if vocab_size is None:
                    return Unbounded(shape=(-1,), dtype=torch.int64, device=device)
                return Bounded(
                    shape=(-1,),
                    dtype=torch.int64,
                    low=0,
                    high=vocab_size,
                    device=device,
                )

            observation_spec = {self.token_key: make_token_spec()}
            if self.has_attention:
                # The mask is declared whether or not vocab_size is known. Otherwise it
                # would be missing from the observation spec when a vocabulary is given.
                observation_spec[self.attention_key] = Unbounded(
                    shape=(-1,), dtype=torch.int64, device=device
                )
            self.full_observation_spec_unbatched = Composite(observation_spec)
            self.full_action_spec_unbatched = Composite(
                {action_key: make_token_spec()}
            )

        if self.assign_reward:
            if self.str2str:
                raise ValueError(self._STR2STR_ERR_MSG)
            self.full_reward_spec_unbatched = Composite(
                {reward_key: Unbounded(shape=(-1,), device=device)}
            )
        else:
            self.full_reward_spec_unbatched = Composite(device=device)

        if not self.assign_done:
            # Single root done / terminated entries.
            self.full_done_spec_unbatched = Composite(
                done=Unbounded(shape=(1,), dtype=torch.bool),
                terminated=Unbounded(shape=(1,), dtype=torch.bool),
            )
        elif self.str2str:
            raise ValueError(self._STR2STR_ERR_MSG)
        else:
            # Per-token done / terminated under "tokens_data", plus root aggregates.
            self.full_done_spec_unbatched = Composite(
                tokens_data=Composite(
                    done=Unbounded(shape=(-1,), dtype=torch.bool),
                    terminated=Unbounded(shape=(-1,), dtype=torch.bool),
                ),
                done=Unbounded(shape=(1,), dtype=torch.bool),
                terminated=Unbounded(shape=(1,), dtype=torch.bool),
            )

    @classmethod
    def from_dataloader(
        cls,
        dataloader: DataLoader,
        *,
        token_key: NestedKey | None = None,
        str_key: NestedKey | None = None,
        attention_key: NestedKey | None = None,
        action_key: NestedKey | None = None,
        reward_key: NestedKey = "reward",
        str2str: bool = False,
        device: torch.device | None = None,
        vocab_size: int | None = None,
        no_stack: bool = False,
        as_llm_data: bool = False,
        batch_size: int | torch.Size | None = None,
        has_attention: bool = True,
        assign_reward: bool = False,
        assign_done: bool = False,
        primers: Composite | None = None,
        data_keys: list[NestedKey] | None = None,
        data_specs: list[TensorSpec] | None = None,
        example_data: Any = None,
        stack_method: Callable[[Any], Any]
        | Literal["as_nested_tensor", "as_padded_tensor"] = None,
        repeats: int | None = None,
    ) -> LLMEnv:
        """Creates an LLMEnv instance from a dataloader.

        This method creates an environment of type ``cls`` and appends a DataLoadingPrimer to it. When the
        environment is reset, the primer populates ``data_keys`` with data from the provided dataloader. By
        default, ``data_keys`` is the token key plus the attention key if `has_attention=True`. When
        `str2str=True`, it is the string key only.

        Args:
            dataloader (DataLoader): The dataloader to load data from.

        Keyword Args:
            token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when
                `str2str=False`). Defaults to ``"tokens"``.
            str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when
                `str2str=True`). Defaults to ``"text"``.
            attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored.
                Defaults to ``"attention_mask"``.
            action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to
                ``"tokens_response"`` (or ``"text_response"`` when `str2str=True`).
            reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if
                `assign_reward=True`. Defaults to ``"reward"``.
            str2str (bool, optional): Whether the environment should expect strings as input and output.
                Defaults to ``False``.
            device (torch.device | None, optional): The device on which the environment should run.
                Defaults to ``None``.
            vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume
                an unbounded vocabulary. Defaults to ``None``.
            no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the
                past observation, each action being a new, unseen part of a conversation. Otherwise, the action
                is assumed to be the plain output of the LLM, including the input tokens / strings.
            has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated
                by :attr:`attention_key`. Defaults to ``True``.
            assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape
                is written during calls to `step()`. Defaults to ``False``.
            assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to
                the action shape is written during calls to `step()`. Defaults to ``False``.

                .. note:: regardless of the value assigned to `assign_done`, root ``"done"`` and
                  ``"terminated"`` entries are part of the done spec, as this is a requirement for all TorchRL
                  environments.

            batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the
                environment is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch
                size. Defaults to ``None`` (batch-unlocked).
            primers (Composite | None, optional): The primers to use for each key in the dataloader.
                Defaults to ``None``.
            data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not
                passed, the keys described above are used. Defaults to ``None``.
            data_specs (list[TensorSpec] | None, optional): The specs to use for each item in the dataloader.
                Defaults to ``None``.
            example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``.
            stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
                method to use for stacking the data. Defaults to ``None``.
            repeats (int, optional): How many times the same sample needs to appear successively. This can be
                useful in situations like GRPO where a single prompt is used multiple times to estimate the
                advantage using Monte-Carlo samples (rather than an advantage module).
            as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`.
                Defaults to ``False``.

        Returns:
            The created environment, with the DataLoadingPrimer appended through :meth:`append_transform`.

        """
        from torchrl.envs import DataLoadingPrimer

        if data_keys is None:
            if str2str:
                data_keys = [cls._DEFAULT_STR_KEY if str_key is None else str_key]
            else:
                data_keys = [
                    cls._DEFAULT_TOKEN_KEY if token_key is None else token_key
                ]
                if has_attention:
                    data_keys.append(
                        cls._DEFAULT_ATTENTION_KEY
                        if attention_key is None
                        else attention_key
                    )

        primer = DataLoadingPrimer(
            dataloader=dataloader,
            primers=primers,
            data_keys=data_keys,
            data_specs=data_specs,
            example_data=example_data,
            stack_method=stack_method,
            repeats=repeats,
        )
        # Use ``cls`` so that subclasses get an instance of their own type.
        env = cls(
            str2str=str2str,
            device=device,
            token_key=token_key,
            str_key=str_key,
            attention_key=attention_key,
            action_key=action_key,
            reward_key=reward_key,
            vocab_size=vocab_size,
            no_stack=no_stack,
            assign_reward=assign_reward,
            assign_done=assign_done,
            batch_size=batch_size,
            has_attention=has_attention,
            as_llm_data=as_llm_data,
        )
        return env.append_transform(primer)

    @staticmethod
    def _check_obs_act_and_cat(obs, action):
        """Concatenates a string observation and a string action, raising TypeError on non-strings."""
        if not isinstance(obs, str):
            raise TypeError(f"Observation must be a string, got {type(obs)}.")
        if not isinstance(action, str):
            raise TypeError(f"Action must be a string, got {type(action)}.")
        return obs + action

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Builds the next observation and, if enabled, the zero reward and done entries."""
        next_td = tensordict.empty()
        self._make_next_obs(tensordict, next_td)
        self._maybe_make_reward(tensordict, next_td)
        self._maybe_make_done(tensordict, next_td)
        if self.as_llm_data:
            raise NotImplementedError()
        return next_td

    def _maybe_make_reward(
        self, tensordict: TensorDictBase, next_td: TensorDictBase
    ) -> TensorDictBase:
        """Writes a zero reward shaped like the action into ``next_td`` when ``assign_reward`` is set."""
        if self.assign_reward:
            next_td.set(
                self.reward_key,
                torch.zeros_like(
                    tensordict.get(self.action_key), dtype=self.reward_spec.dtype
                ),
            )
        return next_td

    def _maybe_make_done(
        self, tensordict: TensorDictBase, next_td: TensorDictBase
    ) -> TensorDictBase:
        """Writes zero per-token and root done/terminated entries into ``next_td`` when ``assign_done`` is set.

        The per-token entries live under ``"tokens_data"``. Each root entry is the ``any`` reduction of its
        per-token counterpart over the last dimension.
        """
        if self.assign_done:
            # Explicit ``None`` default: at reset time there is no action in the input.
            action = tensordict.get(self.action_key, None)
            if action is None:
                done = torch.zeros(
                    tensordict.shape + (1,), dtype=torch.bool, device=self.device
                )
            else:
                done = torch.zeros_like(action, dtype=torch.bool)
            next_td.set(("tokens_data", "terminated"), done)
            next_td.set(("tokens_data", "done"), done.clone())
            next_td.set(
                "done", next_td.get(("tokens_data", "done")).any(-1, keepdim=True)
            )
            next_td.set(
                "terminated",
                next_td.get(("tokens_data", "terminated")).any(-1, keepdim=True),
            )
        return next_td

    def _make_next_obs(
        self, tensordict: TensorDictBase, nex_td: TensorDictBase
    ) -> TensorDictBase:
        """Computes the next observation from the current observation and action.

        With ``no_stack`` the action replaces the tokens; otherwise it is appended to them (or to the string).
        """
        if self.no_stack:
            if self.str2str:
                raise NotImplementedError
            action = tensordict.get(self.action_key)
            nex_td.set(self.token_key, action)
            if self.has_attention:
                attention_mask = tensordict.get(self.attention_key)
                n = action.shape[-1] - attention_mask.shape[-1]
                if n > 0:
                    # It can happen that there's only one action (eg rand_action):
                    # pad the mask with ones so it matches the new token length.
                    attention_mask = torch.cat(
                        [
                            attention_mask,
                            attention_mask.new_ones(attention_mask.shape[:-1] + (n,)),
                        ],
                        -1,
                    )
                nex_td.set(self.attention_key, attention_mask)
            return nex_td

        # Cat action entry with prev obs
        if self.str2str:
            obs = tensordict[self.str_key]
            action = tensordict[self.action_key]
            if not tensordict.batch_size:
                if not isinstance(obs, str) or not isinstance(action, str):
                    raise TypeError(
                        "The tensordict is batchless, yet the action and/or observations are not "
                        f"strings but {type(action)} and {type(obs)}, respectively."
                    )
                observation = self._check_obs_act_and_cat(obs, action)
            else:
                observation = NonTensorStack(
                    *[
                        self._check_obs_act_and_cat(_obs, _action)
                        for (_obs, _action) in _zip_strict(obs, action)
                    ]
                )
            return nex_td.set(self.str_key, observation)

        try:
            obs: torch.Tensor = tensordict.get(self.token_key)
            action = tensordict.get(self.action_key)
            if getattr(obs, "is_nested", False):
                # Nested tensors hold ragged sequences: concatenate per sample.
                observation = torch.nested.as_nested_tensor(
                    [
                        torch.cat([_obs, _action], -1)
                        for _obs, _action in _zip_strict(
                            obs.unbind(0), action.unbind(0)
                        )
                    ],
                    layout=obs.layout,
                )
            else:
                observation = torch.cat([obs, action], -1)
        except TypeError as err:
            raise TypeError(
                "Failed to cat action and observation tensors. Check that str2str argument is correctly "
                f"set in {type(self).__name__}."
            ) from err
        return nex_td.set(self.token_key, observation)

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Checks that an observation was provided (typically by a primer) and adds done entries if enabled.

        Raises:
            KeyError: if ``tensordict`` is None or lacks the observation key for the current mode.
        """
        # We should have an observation by this time, if not raise an exception
        obs_key = self.str_key if self.str2str else self.token_key
        if tensordict is None or obs_key not in tensordict.keys(
            isinstance(obs_key, tuple)
        ):
            raise KeyError(
                f"Observation key {obs_key} is not defined. Make sure a TensorDictPrimer (eg, "
                f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
            )
        td_reset = tensordict.copy()
        tensordict = self._maybe_make_done(tensordict, td_reset)
        if self.as_llm_data:
            raise NotImplementedError()
        return tensordict

    def _set_seed(self, seed: int | None):
        # The environment has no randomness of its own.
        return seed
class LLMHashingEnv(EnvBase):
    """A text generation environment that uses a hashing module to identify unique observations.

    The primary goal of this environment is to identify token chains using a hashing function.
    This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
    identifiers, or easily prune repeated token chains in a data structure.

    The following figure gives an overview of this workflow:

    .. figure:: /_static/img/rollout-llm.png
        :alt: Data collection loop with our LLM environment.

    Args:
        vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed, in which case
            ``tokenizer.vocab_size`` is used.

    Keyword Args:
        hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
            A hashing function that takes a tensor as input and returns a hashed tensor.
            Defaults to :class:`~torchrl.data.SipHash` if not provided.
        observation_key (NestedKey, optional): The key for the observation in the TensorDict.
            Defaults to "observation".
        text_output (bool, optional): Whether to include the text output in the observation.
            Defaults to True.
        tokenizer (transformers.Tokenizer | None, optional):
            A tokenizer function that converts text to tensors. It is required whenever text is decoded
            (`text_output=True`) or encoded (:meth:`make_tensordict`), and must then implement `decode`
            and `batch_decode`. Defaults to ``None``.
        text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
            Defaults to "text".

    Examples:
        >>> from tensordict import TensorDict
        >>> from torchrl.envs import LLMHashingEnv
        >>> from transformers import GPT2Tokenizer
        >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        >>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
        >>> env = LLMHashingEnv(tokenizer=tokenizer)
        >>> td = TensorDict(observation=x, batch_size=[1])
        >>> td = env.reset(td)
        >>> print(td)
        TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                hashing: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                text: NonTensorStack(
                    ['Check out TorchRL!'],
                    batch_size=torch.Size([1]),
                    device=None)},
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)

    """

    def __init__(
        self,
        vocab_size: int | None = None,
        *,
        hashing_module: Callable[[torch.Tensor], torch.Tensor] | None = None,
        observation_key: NestedKey = "observation",
        text_output: bool = True,
        tokenizer: Callable[[str | list[str]], torch.Tensor] | None = None,
        text_key: NestedKey | None = "text",
    ):
        super().__init__()
        if vocab_size is None:
            if tokenizer is None:
                raise TypeError(
                    "You must provide a vocab_size integer if tokenizer is `None`."
                )
            vocab_size = tokenizer.vocab_size
        self._batch_locked = False
        if hashing_module is None:
            hashing_module = SipHash()

        self._hashing_module = hashing_module
        self._tokenizer = tokenizer
        self.observation_key = observation_key
        observation_spec = {
            observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
            "hashing": Unbounded(shape=(1,), dtype=torch.int64),
        }
        self.text_output = text_output
        if not text_output:
            text_key = None
        elif text_key is None:
            text_key = "text"
        if text_key is not None:
            observation_spec[text_key] = NonTensor(shape=())
            self.text_key = text_key
        self.text_key = text_key
        self.observation_spec = Composite(observation_spec)
        self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
        # NOTE(review): the result is discarded; presumably this validates that the
        # specs are compatible with step_mdp — confirm the intent.
        _StepMDP(self)

    def _require_tokenizer(self) -> None:
        """Raises a clear error when a tokenizer-dependent feature is used without a tokenizer."""
        if self._tokenizer is None:
            raise RuntimeError(
                f"{type(self).__name__} requires a tokenizer to encode or decode text. "
                "Pass `tokenizer=...` at construction or set `text_output=False`."
            )

    def _decode(self, obs: torch.Tensor) -> NonTensorData | NonTensorStack:
        """Decodes token ids to text: batched (NonTensorStack) if ``obs.ndim > 1``, else a NonTensorData."""
        self._require_tokenizer()
        if obs.ndim > 1:
            return NonTensorStack.from_list(self._tokenizer.batch_decode(obs))
        return NonTensorData(self._tokenizer.decode(obs))

    def _hash(self, tokens: torch.Tensor, batched: bool) -> torch.Tensor:
        """Hashes ``tokens`` along the last dim and returns a hash with a trailing singleton dimension."""
        if batched:
            return self._hashing_module(tokens).unsqueeze(-1)
        # Unbatched input: add a leading dim for the hashing module, then move it last.
        return self._hashing_module(tokens.unsqueeze(0)).transpose(0, -1)

    def make_tensordict(self, input: str | list[str]) -> TensorDict:
        """Converts a string or list of strings in a TensorDict with appropriate shape and device.

        The input is tokenized and the resulting tensordict is passed through :meth:`reset`.
        """
        self._require_tokenizer()
        list_len = len(input) if isinstance(input, list) else 0
        tensordict = TensorDict(
            {self.observation_key: self._tokenizer(input)}, device=self.device
        )
        if list_len:
            tensordict.batch_size = [list_len]
        return self.reset(tensordict)

    def _reset(self, tensordict: TensorDictBase):
        """Initializes the environment with a given observation.

        Args:
            tensordict (TensorDictBase): A TensorDict containing the initial observation.

        Returns:
            A TensorDict containing the initial observation, its hash, and other relevant information.

        Raises:
            RuntimeError: if no prompt is found under ``observation_key``.

        """
        if tensordict is None:
            raise RuntimeError(
                f"Resetting the {type(self).__name__} environment requires a prompt."
            )
        out = tensordict.empty()
        obs = tensordict.get(self.observation_key, None)
        if obs is None:
            raise RuntimeError(
                f"Resetting the {type(self).__name__} environment requires a prompt."
            )
        if self.text_output:
            out.set(self.text_key, self._decode(obs))
        out.set("hashing", self._hash(obs, batched=obs.ndim > 1))

        if not self.full_done_spec.is_empty():
            out.update(self.full_done_spec.zero(tensordict.shape))
        else:
            out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool))
            out.set(
                "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)
            )
        return out

    def _step(self, tensordict):
        """Takes an action (i.e., the next token to generate) and returns the next observation and reward.

        The new hash is computed from the previous hash concatenated with the action, not from the full
        token chain.

        Args:
            tensordict: A TensorDict containing the current observation and action.

        Returns:
            A TensorDict containing the next observation, its hash, and other relevant information.

        """
        out = tensordict.empty()
        action = tensordict.get("action")
        obs = torch.cat([tensordict.get(self.observation_key), action], -1)
        kwargs = {self.observation_key: obs}

        catval = torch.cat([tensordict.get("hashing"), action], -1)
        new_hash = self._hash(catval, batched=obs.ndim > 1)

        if self.text_output:
            kwargs[self.text_key] = self._decode(obs)
        kwargs.update(
            {
                "hashing": new_hash,
                "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
                "terminated": torch.zeros(
                    (*tensordict.batch_size, 1), dtype=torch.bool
                ),
            }
        )
        return out.update(kwargs)

    def _set_seed(self, *args):
        """Sets the seed for the environment's randomness.

        .. note:: This environment has no randomness, so this method does nothing.
        """

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources