TED2Flat

class torchrl.data.TED2Flat(done_key=('next', 'done'), shift_key='shift', is_full_key='is_full', done_keys=('done', 'truncated', 'terminated'), reward_keys=('reward',))[source]

A storage saving hook that serializes TED (TorchRL Episode Data) data in a compact format.

Parameters:
  • done_key (NestedKey, optional) – the key where the done states should be read. Defaults to ("next", "done").

  • shift_key (NestedKey, optional) – the key where the shift will be written. Defaults to "shift".

  • is_full_key (NestedKey, optional) – the key where the is_full attribute will be written. Defaults to "is_full".

  • done_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the done entries. Defaults to ("done", "truncated", "terminated").

  • reward_keys (Tuple[NestedKey], optional) – a tuple of nested keys indicating the reward entries. Defaults to ("reward",).
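
Any of these keys can be remapped when the stored data departs from the defaults. A minimal sketch, assuming a hypothetical environment that writes an extra reward entry named "success":

>>> from torchrl.data import TED2Flat
>>> # "success" is a hypothetical second reward entry, used for illustration only
>>> hook = TED2Flat(reward_keys=("reward", "success"))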

Examples

>>> import tempfile
>>>
>>> from tensordict import TensorDict
>>>
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage
>>> from torchrl.envs import GymEnv
>>> import torch
>>>
>>> env = GymEnv("CartPole-v1")
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
>>> rb = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb.register_save_hook(TED2Flat())
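>>> # note: the hook is applied when the buffer is dumped, not when data is extended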
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     for i, data in enumerate(collector):
...         rb.extend(data)
...         rb.dumps(tmpdir)
...     # load the data to represent it
...     td = TensorDict.load(tmpdir + "/storage/")
...     print(td)
TensorDict(
    fields={
        action: MemoryMappedTensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=True),
        collector: TensorDict(
            fields={
                traj_ids: MemoryMappedTensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        done: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
        observation: MemoryMappedTensor(shape=torch.Size([220, 4]), device=cpu, dtype=torch.float32, is_shared=True),
        reward: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=True),
        terminated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True),
        truncated: MemoryMappedTensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
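
Note that observation holds 220 rows while the other entries hold 200: in the flat format, observations are deduplicated and stored as one sequence per trajectory, so T transitions yield T+1 frames; the 220 rows above thus correspond to 200 transitions spread over 20 trajectories. To restore the original TED layout, the companion Flat2TED loading hook can be registered on the buffer that reads the data back. A minimal sketch, reusing tmpdir from the example above:

>>> from torchrl.data import Flat2TED
>>> rb_load = ReplayBuffer(storage=LazyMemmapStorage(200))
>>> rb_load.register_load_hook(Flat2TED())
>>> rb_load.loads(tmpdir)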
