Get started with data collection and storage¶
Author: Vincent Moens
Note
To run this tutorial in a notebook, add an installation cell at the beginning containing:
!pip install tensordict
!pip install torchrl
There is no learning without data. In supervised learning, users are
accustomed to using DataLoader and the like to integrate data into their
training loop. Dataloaders are iterable objects that provide the data you
will use to train your model.
TorchRL approaches the problem of dataloading in a similar manner, although
this design is uncommon in the ecosystem of RL libraries. TorchRL’s
dataloaders are referred to as DataCollectors. Most of the time,
data collection does not stop at the collection of raw data,
as the data needs to be stored temporarily in a buffer
(or an equivalent structure for on-policy algorithms) before being consumed
by the loss module. This tutorial explores these two classes: data
collectors and replay buffers.
Data collectors¶
This section focuses on SyncDataCollector, the primary data collector in
TorchRL. At a fundamental level, a collector is a straightforward
class responsible for executing your policy within the environment,
resetting the environment when necessary, and providing batches of a
predefined size. Unlike the rollout() method demonstrated in the env
tutorial, collectors do not reset between consecutive batches of data.
Consequently, two successive batches of data may contain elements from the
same trajectory.
The basic arguments you need to pass to your collector are the size of the
batches you want to collect (frames_per_batch), the length (possibly
infinite) of the iterator, the policy and the environment. For simplicity,
we will use a dummy, random policy in this example.
import torch
torch.manual_seed(0)
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import RandomPolicy
env = GymEnv("CartPole-v1")
env.set_seed(0)
policy = RandomPolicy(env.action_spec)
collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1)
We now expect that our collector will deliver batches of size 200 no
matter what happens during collection. In other words, a single batch may
contain multiple trajectories! The total_frames argument sets the total
number of frames the collector delivers before it stops. A value of -1
produces a never-ending collector.
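To make the batching behaviour concrete, here is a minimal pure-Python sketch of the idea. It is not how SyncDataCollector is implemented: FixedLengthEnv and collect are hypothetical names invented for this illustration, and only the frames_per_batch and total_frames parameters mirror the arguments described above. The toy environment ends every episode after four steps, so the way trajectories spill over from one batch into the next is easy to follow.

```python
from typing import Callable, Iterator, List, Tuple


class FixedLengthEnv:
    """Hypothetical toy environment: every episode ends after `episode_length` steps."""

    def __init__(self, episode_length: int = 4) -> None:
        self.episode_length = episode_length
        self._t = 0

    def reset(self) -> int:
        self._t = 0
        return self._t

    def step(self, action: int) -> Tuple[int, bool]:
        self._t += 1
        return self._t, self._t == self.episode_length


def collect(
    env: FixedLengthEnv,
    policy: Callable[[int], int],
    frames_per_batch: int,
    total_frames: int = -1,
) -> Iterator[List[Tuple[int, int, int]]]:
    """Yield batches of (traj_id, observation, action) with a fixed length."""
    traj_id, frames = 0, 0
    obs = env.reset()  # reset once up front, then only when an episode ends
    while total_frames < 0 or frames < total_frames:
        batch = []
        for _ in range(frames_per_batch):
            action = policy(obs)
            next_obs, done = env.step(action)
            batch.append((traj_id, obs, action))
            obs = next_obs
            if done:
                # The environment is reset mid-batch; the batch keeps filling up.
                obs = env.reset()
                traj_id += 1
        frames += frames_per_batch
        yield batch


batches = list(
    collect(FixedLengthEnv(), lambda obs: 0, frames_per_batch=10, total_frames=30)
)
traj_ids_per_batch = [[frame[0] for frame in batch] for batch in batches]
print(traj_ids_per_batch[0])  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
print(traj_ids_per_batch[1])  # [2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
```

Trajectory 2 starts at the end of the first batch and finishes at the start of the second one, which is the situation described above. With total_frames=30 and frames_per_batch=10 the generator stops after three batches; with total_frames=-1 it would never stop on its own.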
Let’s iterate over the collector to get a sense of what this data looks like:
for data in collector:
    print(data)
    break
TensorDict(
fields={
action: Tensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
collector: TensorDict(
fields={
traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([200]),
device=None,
is_shared=False),
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([200]),
device=None,
is_shared=False)
As you can see, our data is augmented with some collector-specific metadata
grouped in a "collector" sub-tensordict that we did not see during
environment rollouts. This is useful to keep track of
the trajectory ids. In the following tensor, each item is the id of the
trajectory the corresponding transition belongs to:
print(data["collector", "traj_ids"])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9,
9, 9, 9, 9, 9, 9, 9, 9])
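Trajectory ids like the ones printed above can be turned into per-trajectory frame ranges. The sketch below uses only the standard library and a short made-up list of ids; with the batch collected above, a list of the same kind could be obtained with data["collector", "traj_ids"].tolist(). The variable names are illustrative.

```python
from itertools import groupby

# Illustrative ids only, not the values printed above.
traj_ids = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3]

segments = []  # (traj_id, first_frame, one_past_last_frame)
start = 0
for traj_id, group in groupby(traj_ids):
    length = len(list(group))
    segments.append((traj_id, start, start + length))
    start += length

for traj_id, begin, end in segments:
    print(f"trajectory {traj_id}: frames [{begin}, {end})")
```

Keep in mind that the first and last segments of a batch may be partial trajectories, since the collector does not reset between batches.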
Data collectors are very useful when it comes to coding state-of-the-art
algorithms, as performance is usually measured by the capability of a
specific technique to solve a problem in a given number of interactions with
the environment (the total_frames argument in the collector).
For this reason, most training loops in our examples look like this:
>>> for data in collector:
...     # your algorithm here
Replay Buffers¶
Now that we have explored how to collect data, we would like to know how to store it. In RL, the typical setting is that data is collected, stored temporarily, and cleared after a while according to some heuristic, such as first-in first-out. Typical pseudo-code would look like this:
>>> for data in collector:
...     storage.store(data)
...     for i in range(n_optim):
...         sample = storage.sample()
...         loss_val = loss_fn(sample)
...         loss_val.backward()
...         optim.step() # etc
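The store-then-sample pattern in the pseudo-code above can be tried out without any RL library. The following is a rough sketch that assumes a first-in first-out heuristic: FifoBuffer is a hypothetical stand-in written for this illustration, not a TorchRL class, and the strings it stores merely stand in for transitions.

```python
import random
from collections import deque


class FifoBuffer:
    """Hypothetical buffer with a fixed capacity: the oldest items are dropped first."""

    def __init__(self, max_size: int, seed: int = 0) -> None:
        self._items = deque(maxlen=max_size)  # deque discards old items when full
        self._rng = random.Random(seed)

    def extend(self, batch) -> None:
        self._items.extend(batch)

    def sample(self, batch_size: int) -> list:
        # Draws are independent, so the same item may be returned more than once.
        return self._rng.choices(list(self._items), k=batch_size)

    def __len__(self) -> int:
        return len(self._items)


buffer = FifoBuffer(max_size=5)
for step in range(3):  # stands in for "for data in collector"
    batch = [f"frame-{step}-{i}" for i in range(3)]
    buffer.extend(batch)
    for _ in range(2):  # stands in for the inner optimization loop
        sample = buffer.sample(batch_size=4)
        # The loss computation and the optimizer step would go here.

print(len(buffer))           # 5
print(list(buffer._items))   # the five most recent frames
```

Nine frames were written in total, but only the five most recent ones remain, which is the first-in first-out behaviour mentioned above.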
The parent class that stores the data in TorchRL
is referred to as ReplayBuffer. TorchRL’s replay
buffers are composable: you can edit the storage type, the sampling
technique, the writing heuristic or the transforms applied to them. We will
leave the fancy stuff for a dedicated in-depth tutorial. The generic replay
buffer only needs to know what storage it has to use. In general, we
recommend a TensorStorage subclass, which will work
fine in most cases. We’ll be using LazyMemmapStorage
in this tutorial, which enjoys two nice properties: first, being “lazy”,
you don’t need to explicitly tell it what your data looks like in advance.
Second, it uses MemoryMappedTensor as a backend to save
your data on disk in an efficient way. The only thing you need to know is
how big you want your buffer to be.
from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer
buffer = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))
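One way to picture what “lazy” means here is a storage that only learns the layout of its data on the first write. The sketch below is an assumption-laden illustration in plain Python, not the way LazyMemmapStorage is implemented: LazyDictStorage is a hypothetical class, it keeps everything in memory rather than on disk, and it only borrows the idea of an extend method that returns the written indices.

```python
from typing import Any, Dict, List, Optional


class LazyDictStorage:
    """Hypothetical lazy storage: columns are allocated on the first write."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._columns: Optional[Dict[str, List[Any]]] = None  # unknown until first write
        self._cursor = 0
        self._len = 0

    def extend(self, items: List[Dict[str, Any]]) -> List[int]:
        if self._columns is None:
            # Infer the layout from the first item and preallocate every column.
            self._columns = {key: [None] * self.max_size for key in items[0]}
        indices = []
        for item in items:
            for key, value in item.items():
                self._columns[key][self._cursor] = value
            indices.append(self._cursor)
            self._cursor = (self._cursor + 1) % self.max_size  # wrap around when full
            self._len = min(self._len + 1, self.max_size)
        return indices

    def __len__(self) -> int:
        return self._len


storage = LazyDictStorage(max_size=4)
print(storage._columns)  # None: nothing is allocated before the first write
indices = storage.extend([{"observation": [0.1 * i], "reward": 1.0} for i in range(3)])
print(indices)                   # [0, 1, 2]
print(sorted(storage._columns))  # ['observation', 'reward']
```

Only the maximum size is given up front; the field names are discovered from the first element written, which matches the property described above.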
Populating the buffer can be done via the add() (single element) or
extend() (multiple elements) methods. Using
the data we just collected, we initialize and populate the buffer in one go:
indices = buffer.extend(data)
We can check that the buffer now holds the same number of elements as we got from the collector:
assert len(buffer) == collector.frames_per_batch
The only thing left to know is how to gather data from the buffer.
Naturally, this relies on the sample() method. Because we did not specify
that sampling had to be done without repetitions, it is not guaranteed that
the samples gathered from our buffer will be unique:
sample = buffer.sample(batch_size=30)
print(sample)
TensorDict(
fields={
action: Tensor(shape=torch.Size([30, 2]), device=cpu, dtype=torch.int64, is_shared=False),
collector: TensorDict(
fields={
traj_ids: Tensor(shape=torch.Size([30]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([30]),
device=cpu,
is_shared=False),
done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([30, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([30]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([30, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([30]),
device=cpu,
is_shared=False)
Again, our sample has the same structure as the data we gathered from the collector, only with a batch size of 30!
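The remark about repetitions made before the sampling call can be illustrated with Python's random module alone. This is only an analogy for the two sampling regimes, not TorchRL code: the list of integers stands in for ten stored transitions, and the variable names are illustrative.

```python
import random

rng = random.Random(0)
stored = list(range(10))  # stand-in for 10 stored transitions

# Independent draws: the same element may come up several times.
with_repeats = rng.choices(stored, k=8)

# Draws without repetition: each element appears at most once.
without_repeats = rng.sample(stored, k=8)

print(len(with_repeats), len(set(with_repeats)))        # the second number can be below 8
print(len(without_repeats), len(set(without_repeats)))  # both numbers are 8
```

Sampling with repetitions is cheap and never runs out of data, while sampling without repetitions guarantees that each stored element is seen at most once per pass.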
Next steps¶
You can have a look at other multiprocessed collectors such as MultiSyncDataCollector or MultiaSyncDataCollector.
TorchRL also offers distributed collectors if you have multiple nodes to use for inference. Check them out in the API reference.
Check the dedicated Replay Buffer tutorial to learn more about the options you have when building a buffer, or the API reference, which covers all the features in detail. Replay buffers offer many features such as multithreaded sampling, prioritized experience replay, and more.
We left out the capacity of replay buffers to be iterated over for simplicity. Try it out for yourself: build a buffer and indicate its batch-size in the constructor, then try to iterate over it. This is equivalent to calling rb.sample() within a loop!