# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections import deque
from collections.abc import Mapping
from copy import copy, deepcopy
from typing import Any, Callable, Iterable, Literal
import torch
from tensordict import (
maybe_dense_stack,
NestedKey,
TensorDict,
TensorDictBase,
unravel_key,
)
from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
from tensordict.utils import _zip_strict, is_seq_of_nested_key
from torch import nn
from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
from torchrl.envs.utils import make_composite_from_td
def as_nested_tensor(list_of_tensordicts: list[TensorDictBase]) -> TensorDictBase:
"""Stacks a list of tensordicts into a single tensordict with nested tensors.
Args:
list_of_tensordicts (list[TensorDictBase]): A list of tensordicts to stack.
Returns:
TensorDictBase: A tensordict with nested tensors.
"""
def _as_nested_tensor(*list_of_tensors):
return torch.nested.as_nested_tensor(list_of_tensors, layout=torch.jagged)
batch_size = list(list_of_tensordicts[0].shape)
batch_size.insert(0, len(list_of_tensordicts))
return list_of_tensordicts[0].apply(
_as_nested_tensor, *list_of_tensordicts[1:], batch_size=batch_size
)
def as_padded_tensor(
    list_of_tensordicts: list[TensorDictBase], dim: int = 0, stack_dim: int = 0
) -> TensorDictBase:
    """Stacks a list of tensordicts into a single tensordict with padded tensors.

    For every leaf, the tensors are left-padded with zeros along ``dim`` until
    they all reach the largest size found along that dimension, then stacked
    along ``stack_dim``.

    Args:
        list_of_tensordicts (list[TensorDictBase]): A list of tensordicts to stack.
        dim (int, optional): The dimension along which to pad. Must be
            non-negative. Defaults to 0.
        stack_dim (int, optional): The dimension along which to stack. Defaults to 0.

    Returns:
        TensorDictBase: A tensordict with padded tensors, whose batch size is
        the batch size of the first element with ``len(list_of_tensordicts)``
        inserted at ``stack_dim``.

    Raises:
        ValueError: If ``dim`` is negative.
    """
    # Validate once up front rather than once per leaf.
    if dim < 0:
        raise ValueError("dim must be >= 0")

    def _stack_tensors(*list_of_tensors):
        max_length = max(t.size(dim) for t in list_of_tensors)

        def pad_tensor(tensor):
            # Zeros are prepended, so the original content ends up right-aligned.
            pad_shape = list(tensor.shape)
            pad_shape[dim] = max_length - tensor.size(dim)
            return torch.cat((tensor.new_zeros(pad_shape), tensor), dim=dim)

        return torch.stack([pad_tensor(t) for t in list_of_tensors], dim=stack_dim)

    batch_size = list(list_of_tensordicts[0].shape)
    # The new batch dimension is created by ``torch.stack`` at ``stack_dim``;
    # ``dim`` only designates the padded dimension and must not be used here.
    batch_size.insert(stack_dim, len(list_of_tensordicts))
    return list_of_tensordicts[0].apply(
        _stack_tensors, *list_of_tensordicts[1:], batch_size=batch_size
    )
class DataLoadingPrimer(TensorDictPrimer):
"""A primer that loads data from a dataloader and converts it into a tensordict using ``stack_method``.
Args:
dataloader (Iterable[Any]): The dataloader to load data from.
Keyword Args:
primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
use_buffer (bool, optional): Whether to use a buffer to load the batches. When an environment has a batch-size
that differs from the dataloader's, or when partial resets are to be expected, using a buffer to store data
ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
are loaded in order.
Defaults to ``True`` whenever the batch-size of the dataloader is greater than 1.
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
tensordict returned by the transform will be automatically determined assuming that there is a single batch
dimension.
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
samples (rather than an advantage module).
Attributes:
dataloader (Iterable[Any]): The dataloader to load data from.
endless_dataloader (Iterable[Any]): An endless iterator over the dataloader.
data_keys (List[NestedKey]): The keys to use for each item in the dataloader.
stack_method (Callable[[Any], Any]): The method to use for stacking the data.
.. seealso:: :class:`~torchrl.envs.LLMEnv` and :class:`~torchrl.envs.LLMEnv.from_dataloader`.
Example of a dataloader yielding strings:
>>> import random
>>> import string
>>> import tensordict as td
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded
>>> from torchrl.envs import DataLoadingPrimer, LLMEnv
>>> td.set_capture_non_tensor_stack(False).set()
>>> class DummyDataLoader:
... '''A dummy dataloader that generates random strings.'''
... def __init__(self, batch_size: int = 0):
... self.batch_size = batch_size
... def generate_random_string(self, length: int = 10) -> str:
... '''Generate a random string of a given length.'''
... return ''.join(random.choice(string.ascii_lowercase) for _ in range(length))
... def __iter__(self):
... return self
... def __next__(self):
... if self.batch_size == 0:
... return self.generate_random_string()
... else:
... return [self.generate_random_string() for _ in range(self.batch_size)]
>>> # Create an LLM environment with string-to-string input/output.
>>> env = LLMEnv(str2str=True)
>>> # Append a DataLoadingPrimer to the environment.
>>> env = env.append_transform(
...     DataLoadingPrimer(
...         dataloader=DummyDataLoader(),
...         data_keys=["observation"],
...         example_data="a string!",
...     )
... )
>>> # Test the environment.
>>> print(env.rand_action(TensorDict()))
TensorDict(
fields={
action: NonTensorData(data=a string, batch_size=torch.Size([]), device=None)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(env.rollout(3))
TensorDict(
fields={
action: NonTensorStack(
['a string', 'a string', 'a string'],
batch_size=torch.Size([3]),
device=None),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: NonTensorStack(
['zxwvupirska string', 'zxwvupirska stringa string...,
batch_size=torch.Size([3]),
device=None),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False),
observation: NonTensorStack(
['zxwvupirsk', 'zxwvupirska string', 'zxwvupirska ...,
batch_size=torch.Size([3]),
device=None),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # Roll out the environment with a specific initial state.
>>> init_state = env.reset(TensorDict(batch_size=[3]))
>>> print(env.rollout(3, auto_reset=False, tensordict=init_state))
TensorDict(
fields={
action: NonTensorStack(
[['a string', 'a string', 'a string'], ['a string'...,
batch_size=torch.Size([3, 3]),
device=None),
done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: NonTensorStack(
[[array(['nngcmflsana string', 'vrrbnhzpmga string...,
batch_size=torch.Size([3, 3]),
device=None),
terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 3]),
device=None,
is_shared=False),
observation: NonTensorStack(
[['nngcmflsan', array(['nngcmflsana string', 'vrrb...,
batch_size=torch.Size([3, 3]),
device=None),
terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 3]),
device=None,
is_shared=False)
Example of dataloader yielding tensors:
>>> import random
>>> import string
>>>
>>> import tensordict as td
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded
>>> from torchrl.envs import DataLoadingPrimer, LLMEnv
>>>
>>> td.set_capture_non_tensor_stack(False).set()
>>>
>>>
>>> class DummyTensorDataLoader:
... '''A dummy dataloader that generates tensors of random int64 values.'''
...
... def __init__(self, batch_size: int = 0, max_length: int = 10, padding: bool = False):
... '''
... Args:
... batch_size (int, optional): The batch size of the generated tensors. Defaults to 0.
... max_length (int, optional): The maximum length of the generated tensors. Defaults to 10.
... padding (bool, optional): Whether to pad the tensors to the maximum length. Defaults to False.
... '''
... self.batch_size = batch_size
... self.max_length = max_length
... self.padding = padding
...
... def generate_random_tensor(self) -> torch.Tensor:
... '''Generate a tensor of random int64 values.'''
... length = random.randint(1, self.max_length)
... return torch.tensor([random.randint(0, 100) for _ in range(length)], dtype=torch.int64)
...
... def pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
... '''Pad a tensor to the maximum length.'''
... padding_length = self.max_length - len(tensor)
... return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))
...
... def __iter__(self):
... return self
...
... def __next__(self):
... if self.batch_size == 0:
... tensor = self.generate_random_tensor()
... return self.pad_tensor(tensor) if self.padding else tensor
... else:
... tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
... if self.padding:
... tensors = [self.pad_tensor(tensor) for tensor in tensors]
... return torch.stack(tensors)
... else:
... return tensors
>>>
>>> # Create an LLM environment with non-string input/output and append a DataLoadingPrimer.
>>> env = LLMEnv(str2str=False)
>>> env = env.append_transform(
...     DataLoadingPrimer(
...         dataloader=DummyTensorDataLoader(),
...         data_keys=["observation"],
...         data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
...     )
... )
>>> print(env.rand_action(TensorDict()))
TensorDict(
fields={
action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(env.rollout(3))
LazyStackedTensorDict(
fields={
action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: LazyStackedTensorDict(
fields={
done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([3]),
device=None,
is_shared=False,
stack_dim=0),
observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([3]),
device=None,
is_shared=False,
stack_dim=0)
>>> # Create an LLM environment with padded tensor input/output and append a DataLoadingPrimer.
>>> env = LLMEnv(str2str=False)
>>> env = env.append_transform(
...     DataLoadingPrimer(
...         dataloader=DummyTensorDataLoader(padding=True),
...         data_keys=["observation"],
...         data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
...         stack_method="as_padded_tensor",
...     )
... )
>>> print(env.rollout(3, auto_reset=False, tensordict=env.reset(TensorDict(batch_size=[3]))))
LazyStackedTensorDict(
fields={
action: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: LazyStackedTensorDict(
fields={
done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([3, 3]),
device=None,
is_shared=False,
stack_dim=1),
observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([3, 3]),
device=None,
is_shared=False,
stack_dim=1)
"""
def __init__(
    self,
    dataloader: Iterable[Any],
    *,
    primers: Composite | None = None,
    data_keys: list[NestedKey] | None = None,
    data_specs: list[TensorSpec] | None = None,
    example_data: Any = None,
    stack_method: Callable[[Any], Any]
    | Literal["as_nested_tensor", "as_padded_tensor"]
    | None = None,
    use_buffer: bool | None = None,
    auto_batch_size: bool = True,
    repeats: int | None = None,
):
    # Everything read by ``_load_from_dataloader`` is set before
    # ``super().__init__`` because the buffered branch below already pulls a
    # first sample from the dataloader to infer the primers.
    self.dataloader = dataloader
    if repeats is None:
        repeats = 0
    self.repeats = repeats

    # ``batch_size`` may be absent or explicitly ``None``; comparing ``None``
    # with an int raises a TypeError, so both cases fall back to 1.
    # NOTE(review): a ``None`` batch size is treated like a missing
    # attribute -- confirm this matches the dataloaders used in practice.
    batch_size = getattr(dataloader, "batch_size", 1)
    if batch_size is None:
        batch_size = 1

    # Buffering is turned on automatically for batched dataloaders (unless
    # the caller made an explicit choice) and always when samples repeat.
    if (batch_size > 1 and use_buffer is None) or repeats > 0:
        use_buffer = True
    self.use_buffer = use_buffer
    if self.use_buffer:
        self._queue = deque()

    # No auto_batch_size if we know we have a single element
    self.auto_batch_size = auto_batch_size and (batch_size > 0)
    self.endless_dataloader = self._endless_iter(self.dataloader)

    # Resolve the stacking function from its default, alias or callable.
    if stack_method is None:
        stack_method = maybe_dense_stack
    elif stack_method == "as_nested_tensor":
        stack_method = as_nested_tensor
    elif stack_method == "as_padded_tensor":
        stack_method = as_padded_tensor
    elif not callable(stack_method):
        raise ValueError(f"Unknown stack_method={stack_method}")
    self.stack_method = stack_method

    if primers is None and not self.use_buffer:
        # Without a buffer the specs are built from the arguments alone.
        if data_keys is None:
            data_keys = ["data"]
        if data_specs is None:
            data_specs = [NonTensor(example_data=example_data, shape=())]
        primers = Composite(
            {
                data_key: data_spec
                for data_key, data_spec in _zip_strict(data_keys, data_specs)
            }
        )
        self.data_keys = data_keys
    elif primers is None:
        self.data_keys = data_keys
        # We can get the primer from the dataloader itself
        data = self._load_from_dataloader()
        primers = make_composite_from_td(data, dynamic_shape=True)
        # Put the sample back at the head of the queue so it is not lost.
        self._queue.appendleft(data)
        if data_keys is None:
            self.data_keys = list(primers.keys(True, True))
    else:
        # User-provided primers dictate the keys.
        self.data_keys = list(primers.keys(True, True))

    super().__init__(
        primers=primers,
        default_value=self._load_from_dataloader,
        reset_key=None,
        expand_specs=None,
        single_default_value=True,
        call_before_env_reset=True,
    )
    self._reset_key = "_reset"
@classmethod
def _endless_iter(self, obj):
while True:
yield from obj
def _load_from_dataloader(self, reset: torch.Tensor | None = None):
"""Loads a single element from the dataloader, or alternatively from the buffer.
If `reset` is passed, the one element per reset will be loaded.
"""
if reset is not None:
if not reset.any():
raise RuntimeError("reset must have at least one True value.")
if reset.ndim > 0:
loaded = [self._load_from_dataloader() for i in range(reset.sum())]
return self.stack_method(loaded)
if self.use_buffer and len(self._queue) > 0:
result = self._queue.popleft()
return result
data = next(self.endless_dataloader)
# Some heuristic here:
# if data is a map, assume its keys match the keys in spec
# TODO: one could rename the keys too
if isinstance(data, Mapping):
out = TensorDict.from_dict(
data, auto_batch_size=self.auto_batch_size, batch_dims=1
)
elif self.data_keys is None:
raise RuntimeError(
f"Cannot lazily instantiate the {type(self).__name__} as the data_keys was "
f"not passed but the data is not a Mapping, therefore the keys cannot be retrieved "
f"automatically. Please pass the data_keys to the constructor."
)
elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)):
out = TensorDict.from_dict(
{k: val for k, val in _zip_strict(self.data_keys, data)},
auto_batch_size=self.auto_batch_size,
batch_dims=1,
)
elif len(self.data_keys) == 1:
out = TensorDict.from_dict(
{self.data_keys[0]: data},
auto_batch_size=self.auto_batch_size,
batch_dims=1,
)
else:
raise ValueError(
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
)
if self.use_buffer:
if not out.ndim:
out = out.unsqueeze(0)
self._queue.extend(
[d for d in out.unbind(0) for _ in range(max(1, self.repeats))]
)
return self._queue.popleft()
return out