DataLoadingPrimer

class torchrl.envs.transforms.DataLoadingPrimer(dataloader: Iterable[Any], *, primers: Composite | None = None, data_keys: list[NestedKey] | None = None, data_specs: list[TensorSpec] | None = None, example_data: Any = None, stack_method: Callable[[Any], Any] | Literal['as_nested_tensor', 'as_padded_tensor'] = None, use_buffer: bool | None = None, auto_batch_size: bool = True, repeats: int | None = None)

A primer that loads data from a dataloader and converts it into a tensordict using stack_method.

Parameters:

dataloader (Iterable[Any]) – The dataloader to load data from.

Keyword Arguments:
  • primers (Composite | None, optional) – The primers to use for each key in the dataloader. Defaults to None.

  • data_keys (List[NestedKey] | None, optional) – The keys to use for each item in the dataloader. Defaults to None.

  • data_specs (List[TensorSpec] | None, optional) – The specs to use for each item in the dataloader. Defaults to None.

  • example_data (Any, optional) – Example data to use for initializing the primer. Defaults to None.

  • stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional) – The method to use for stacking the data. Defaults to maybe_dense_stack.

  • use_buffer (bool, optional) – Whether to use a buffer to load the batches. When an environment has a batch-size that differs from the dataloader’s, or when partial resets are to be expected, using a buffer to store data ensures that next() is called on the dataloader only when necessary, and that elements of the dataset are loaded in order. Defaults to True whenever the batch-size of the dataloader is greater than 1.

  • auto_batch_size (bool, optional) – If True (default if dataloader.batch_size > 0), the batch size of the tensordict returned by the transform will be automatically determined assuming that there is a single batch dimension.

  • repeats (int, optional) – How many times the same sample needs to appear successively. This can be useful in situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo samples (rather than an advantage module).
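The effect described for repeats can be pictured with a small, library-free sketch. The helper repeat_successively below is hypothetical and only illustrates the idea of each sample appearing several times in a row; it is not how DataLoadingPrimer is implemented.

```python
from typing import Any, Iterable, Iterator


def repeat_successively(dataloader: Iterable[Any], repeats: int) -> Iterator[Any]:
    """Yield every item of ``dataloader`` ``repeats`` times in a row.

    Hypothetical helper for illustration only; not part of torchrl.
    """
    if repeats < 1:
        raise ValueError("repeats must be a positive integer")
    for item in dataloader:
        # Emit the same sample back to back before moving to the next one.
        for _ in range(repeats):
            yield item


prompts = ["prompt A", "prompt B"]
repeated = list(repeat_successively(prompts, repeats=3))
print(repeated)
# ['prompt A', 'prompt A', 'prompt A', 'prompt B', 'prompt B', 'prompt B']
```

In a GRPO-style setup, the three consecutive copies of a prompt would each receive a different sampled completion, and the group of completions would serve as the Monte-Carlo samples mentioned above.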

Variables:
  • dataloader (Iterable[Any]) – The dataloader to load data from.

  • endless_dataloader (Iterable[Any]) – An endless iterator over the dataloader.

  • data_keys (List[NestedKey]) – The keys to use for each item in the dataloader.

  • stack_method (Callable[[Any], Any]) – The method to use for stacking the data.
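The interplay between an endless iterator over the dataloader and the buffer described under use_buffer can be sketched as follows. The names endless and BufferedLoader are assumptions made for this illustration, and the logic is a simplified model rather than the transform's actual code: a new batch is fetched only when the buffer cannot serve the request, and elements come out in dataset order.

```python
from collections import deque
from typing import Any, Iterable, Iterator, List


def endless(dataloader: Iterable[Any]) -> Iterator[Any]:
    """Iterate over ``dataloader`` forever, restarting it when exhausted."""
    while True:
        yielded = False
        for batch in dataloader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError("dataloader is empty")


class BufferedLoader:
    """Hypothetical buffer that hands out elements in order, n at a time."""

    def __init__(self, dataloader: Iterable[List[Any]]):
        self._batches = endless(dataloader)
        self._buffer: deque = deque()
        self.fetches = 0  # number of times next() was called on the dataloader

    def take(self, n: int) -> List[Any]:
        # Only pull a new batch when the buffer cannot satisfy the request.
        while len(self._buffer) < n:
            self._buffer.extend(next(self._batches))
            self.fetches += 1
        return [self._buffer.popleft() for _ in range(n)]


# The dataloader yields batches of 4, while the consumer asks for 3 or 2 items.
loader = BufferedLoader([["a", "b", "c", "d"], ["e", "f", "g", "h"]])
first = loader.take(3)   # fetches one batch, leaves "d" in the buffer
second = loader.take(3)  # reuses "d", fetches the second batch
third = loader.take(2)   # served entirely from the buffer
print(first, second, third, loader.fetches)
# ['a', 'b', 'c'] ['d', 'e', 'f'] ['g', 'h'] 2
```

The same mechanism covers partial resets: resetting a single sub-environment corresponds to take(1), which consumes one buffered element instead of a whole batch.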

See also

LLMEnv and from_dataloader.

Example of a dataloader yielding strings:
>>> import random
>>> import string
>>> import tensordict as td
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded
>>> from torchrl.envs import DataLoadingPrimer, LLMEnv
>>> td.set_capture_non_tensor_stack(False).set()
>>> class DummyDataLoader:
...     '''A dummy dataloader that generates random strings.'''
...     def __init__(self, batch_size: int = 0):
...         self.batch_size = batch_size
...     def generate_random_string(self, length: int = 10) -> str:
...         '''Generate a random string of a given length.'''
...         return ''.join(random.choice(string.ascii_lowercase) for _ in range(length))
...     def __iter__(self):
...         return self
...     def __next__(self):
...         if self.batch_size == 0:
...             return self.generate_random_string()
...         else:
...             return [self.generate_random_string() for _ in range(self.batch_size)]
>>> # Create an LLM environment with string-to-string input/output.
>>> env = LLMEnv(str2str=True)
>>> # Append a DataLoadingPrimer to the environment.
>>> env = env.append_transform(
...     DataLoadingPrimer(
...         dataloader=DummyDataLoader(),
...         data_keys=["observation"],
...         example_data="a string!",
...     )
... )
>>> # Test the environment.
>>> print(env.rand_action(TensorDict()))
TensorDict(
    fields={
        action: NonTensorData(data=a string, batch_size=torch.Size([]), device=None)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(env.rollout(3))
TensorDict(
    fields={
        action: NonTensorStack(
            ['a string', 'a string', 'a string'],
            batch_size=torch.Size([3]),
            device=None),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: NonTensorStack(
                    ['zxwvupirska string', 'zxwvupirska stringa string...,
                    batch_size=torch.Size([3]),
                    device=None),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        observation: NonTensorStack(
            ['zxwvupirsk', 'zxwvupirska string', 'zxwvupirska ...,
            batch_size=torch.Size([3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # Roll out the environment with a specific initial state.
>>> init_state = env.reset(TensorDict(batch_size=[3]))
>>> print(env.rollout(3, auto_reset=False, tensordict=init_state))
TensorDict(
    fields={
        action: NonTensorStack(
            [['a string', 'a string', 'a string'], ['a string'...,
            batch_size=torch.Size([3, 3]),
            device=None),
        done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: NonTensorStack(
                    [[array(['nngcmflsana string', 'vrrbnhzpmga string...,
                    batch_size=torch.Size([3, 3]),
                    device=None),
                terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3, 3]),
            device=None,
            is_shared=False),
        observation: NonTensorStack(
            [['nngcmflsan', array(['nngcmflsana string', 'vrrb...,
            batch_size=torch.Size([3, 3]),
            device=None),
        terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3, 3]),
    device=None,
    is_shared=False)
Example of a dataloader yielding tensors:
>>> import random
>>> import string
>>>
>>> import tensordict as td
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded
>>> from torchrl.envs import DataLoadingPrimer, LLMEnv
>>>
>>> td.set_capture_non_tensor_stack(False).set()
>>>
>>>
>>> class DummyTensorDataLoader:
...     '''A dummy dataloader that generates tensors of random int64 values.'''
...
...     def __init__(self, batch_size: int = 0, max_length: int = 10, padding: bool = False):
...         '''
...         Args:
...             batch_size (int, optional): The batch size of the generated tensors. Defaults to 0.
...             max_length (int, optional): The maximum length of the generated tensors. Defaults to 10.
...             padding (bool, optional): Whether to pad the tensors to the maximum length. Defaults to False.
...         '''
...         self.batch_size = batch_size
...         self.max_length = max_length
...         self.padding = padding
...
...     def generate_random_tensor(self) -> torch.Tensor:
...         '''Generate a tensor of random int64 values.'''
...         length = random.randint(1, self.max_length)
...         return torch.tensor([random.randint(0, 100) for _ in range(length)], dtype=torch.int64)
...
...     def pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
...         '''Pad a tensor to the maximum length.'''
...         padding_length = self.max_length - len(tensor)
...         return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))
...
...     def __iter__(self):
...         return self
...
...     def __next__(self):
...         if self.batch_size == 0:
...             tensor = self.generate_random_tensor()
...             return self.pad_tensor(tensor) if self.padding else tensor
...         else:
...             tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
...             if self.padding:
...                 tensors = [self.pad_tensor(tensor) for tensor in tensors]
...                 return torch.stack(tensors)
...             else:
...                 return tensors
>>>
>>> # Create an LLM environment with non-string input/output and append a DataLoadingPrimer.
>>> env = LLMEnv(str2str=False)
>>> env = env.append_transform(
...     DataLoadingPrimer(
...         dataloader=DummyTensorDataLoader(),
...         data_keys=["observation"],
...         data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
...     )
... )
>>> print(env.rand_action(TensorDict()))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(env.rollout(3))
LazyStackedTensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: LazyStackedTensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False,
            stack_dim=0),
        observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False,
    stack_dim=0)
>>> # Create an LLM environment with padded tensor input/output and append a DataLoadingPrimer.
>>> env = LLMEnv(str2str=False)
>>> env = env.append_transform(
...     DataLoadingPrimer(
...         dataloader=DummyTensorDataLoader(padding=True),
...         data_keys=["observation"],
...         data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
...         stack_method="as_padded_tensor",
...     )
... )
>>> print(env.rollout(3, auto_reset=False, tensordict=env.reset(TensorDict(batch_size=[3]))))
LazyStackedTensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: LazyStackedTensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([3, 3]),
            device=None,
            is_shared=False,
            stack_dim=1),
        observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([3, 3]),
    device=None,
    is_shared=False,
    stack_dim=1)
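
The padding step used by DummyTensorDataLoader.pad_tensor in the tensor example can be illustrated without torch. The sketch below left-pads plain Python lists with zeros so that they share a common length; unlike pad_tensor, it pads to the longest sequence in the batch rather than to a fixed max_length. The function name left_pad is hypothetical, and this is a conceptual illustration only, not a description of what stack_method="as_padded_tensor" does internally.

```python
from typing import List, Sequence


def left_pad(sequences: Sequence[Sequence[int]], pad_value: int = 0) -> List[List[int]]:
    """Left-pad every sequence to the length of the longest one.

    Hypothetical helper for illustration only; not part of torchrl.
    """
    if not sequences:
        return []
    target = max(len(seq) for seq in sequences)
    # Prepend as many pad values as needed to reach the target length.
    return [[pad_value] * (target - len(seq)) + list(seq) for seq in sequences]


batch = left_pad([[5, 7], [1, 2, 3, 4], [9]])
print(batch)
# [[0, 0, 5, 7], [1, 2, 3, 4], [0, 0, 0, 9]]
```

Once every row has the same length, the batch can be stored as a single dense tensor, which is why the padded example above reports regular shapes where padding is applied by the dataloader.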
