
Using pretrained models

This tutorial explains how to use pretrained models in TorchRL.

By the end of this tutorial, you will be able to use pretrained models for efficient image representation and to fine-tune them.

TorchRL provides pretrained models that can be used either as transforms or as components of the policy. Because the semantics are the same in both cases, they can be used interchangeably in either context. In this tutorial, we will be using R3M (https://arxiv.org/abs/2203.12601), but other models (e.g. VIP) will work equally well.

import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import multiprocessing, nn
from torchrl.envs import R3MTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor

# CUDA cannot be re-initialized in forked subprocesses, so fall back to the
# CPU when the "fork" start method is in use.
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device("cuda:0")
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

Let us first create an environment. For the sake of simplicity, we will use a common Gym environment. In practice, the same approach works in more challenging, embodied AI contexts (e.g. have a look at our Habitat wrappers).

base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

Let us fetch our pretrained model. We request the pretrained weights through the download=True flag; downloading is turned off by default. Next, we append our transform to the environment. In practice, each batch of collected data goes through the transform and is mapped to an “r3m_vec” entry in the output tensordict. Our policy, an MLP with a single hidden layer, then reads this vector and computes the corresponding action.

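# Pretrained ResNet-50 R3M encoder: reads "pixels", writes "r3m_vec"
r3m = R3MTransform("resnet50", in_keys=["pixels"], download=True)
env_transformed = TransformedEnv(base_env, r3m)
# MLP with one hidden layer; LazyLinear infers its input size on first use
net = nn.Sequential(
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.Linear(128, base_env.action_spec.shape[-1], device=device),
)
policy = Actor(net, in_keys=["r3m_vec"])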
Downloading: "https://pytorch.s3.amazonaws.com/models/rl/r3m/r3m_50.pt" to /root/.cache/torch/hub/checkpoints/r3m_50.pt

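To make this data flow concrete, here is a minimal sketch in plain PyTorch. It assumes a small, randomly initialized convolutional network as a hypothetical stand-in for the pretrained ResNet-50 (so nothing is downloaded) and uses a plain Python dict in place of the tensordict. The stand-in produces 2048 features, the size of the R3M ResNet-50 embedding, and 8 is the action dimension of Ant.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained R3M backbone: any module mapping
# images of shape [B, 3, H, W] to [B, 2048] would play the same role here.
encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 2048),
)

# Policy head: an MLP with one hidden layer of 128 units and 8 outputs.
head = nn.Sequential(
    nn.LazyLinear(128),
    nn.Tanh(),
    nn.Linear(128, 8),
)

# A plain dict stands in for the tensordict produced by the environment.
data = {"pixels": torch.rand(32, 3, 64, 64)}

# "Transform" step: read "pixels", write "r3m_vec". No gradients are needed
# when the encoder is only used as a fixed feature extractor.
with torch.no_grad():
    data["r3m_vec"] = encoder(data["pixels"])

# "Policy" step: read "r3m_vec", write "action".
data["action"] = head(data["r3m_vec"])

print({key: tuple(value.shape) for key, value in data.items()})
```

In TorchRL, the transform and the Actor perform these two steps on a tensordict, each reading and writing the entries named by its keys.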

Let’s count the parameter tensors of the policy (each weight matrix and each bias vector counts as one tensor, not as individual scalar values):

print("number of params:", len(list(policy.parameters())))
number of params: 4
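The count above is the number of parameter tensors: two weight matrices and two bias vectors. The sketch below builds a head of the same layout in plain PyTorch and counts both tensors and scalars, assuming a 2048-dimensional input (the R3M ResNet-50 embedding size) and 8 actions (the Ant action dimension). The lazy layer must first be materialized by a forward pass, because its weight shape is unknown until then.

```python
import torch
from torch import nn

head = nn.Sequential(
    nn.LazyLinear(128),
    nn.Tanh(),
    nn.Linear(128, 8),
)

# Lazy layers infer their input size on the first call; a dummy batch with
# 2048 features materializes the weights of the first layer.
head(torch.zeros(1, 2048))

num_tensors = len(list(head.parameters()))               # weights and biases
num_scalars = sum(p.numel() for p in head.parameters())  # individual values
# 2048 * 128 + 128 + 128 * 8 + 8 = 263304 scalar parameters
print("parameter tensors:", num_tensors)
print("scalar parameters:", num_scalars)
```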

We collect a rollout of 32 steps and print its output:

rollout = env_transformed.rollout(32, policy)
print("rollout with transform:", rollout)
rollout with transform: TensorDict(
        action: Tensor(shape=torch.Size([32, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
                done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                r3m_vec: Tensor(shape=torch.Size([32, 2048]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
        r3m_vec: Tensor(shape=torch.Size([32, 2048]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},

For fine-tuning, we integrate the transform into the policy, so that the R3M parameters become part of the policy parameters. In practice, it may be wiser to fine-tune only a subset of these parameters (for instance, the last layers).

policy = TensorDictSequential(r3m, policy)
print("number of params after r3m is integrated:", len(list(policy.parameters())))
number of params after r3m is integrated: 163
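Restricting fine-tuning to a subset of parameters can be done by toggling requires_grad. Here is a minimal sketch in plain PyTorch, assuming a hypothetical three-layer backbone in place of R3M: the whole backbone is frozen with requires_grad_(False), its last layer is unfrozen, and the optimizer only receives the parameters that still require gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained backbone followed by a policy head.
backbone = nn.Sequential(
    nn.Linear(64, 32),  # early layer: kept frozen
    nn.ReLU(),
    nn.Linear(32, 16),  # last layer: fine-tuned
)
head = nn.Linear(16, 8)
model = nn.Sequential(backbone, head)

# Freeze the whole backbone, then unfreeze only its last layer.
backbone.requires_grad_(False)
backbone[-1].requires_grad_(True)

# Only trainable parameters are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)

# One illustrative update on random data.
x = torch.randn(4, 64)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("trainable tensors:", len(trainable), "of", len(list(model.parameters())))
```

Frozen parameters receive no gradient, so they are neither updated nor charged for gradient storage.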

Again, we collect a rollout, this time with R3M as part of the policy. The structure of the output changes slightly: the environment now returns pixels (and not an embedding), and the embedding “r3m_vec” is an intermediate result computed by our policy.

rollout = base_env.rollout(32, policy)
print("rollout, fine tuning:", rollout)

We can move the transform from the environment to the policy this easily because both behave like a TensorDictModule: they have a set of “in_keys” and “out_keys” that make it easy to read and write data in different contexts.

To conclude this tutorial, let’s have a look at how we could use R3M to read images stored in a replay buffer (e.g. in an offline RL context). First, let’s build our dataset:

from torchrl.data import LazyMemmapStorage, ReplayBuffer

storage = LazyMemmapStorage(1000)
rb = ReplayBuffer(storage=storage, transform=r3m)

We can now collect the data (random rollouts for our purpose) and fill the replay buffer with it:

total = 0
while total < 1000:
    # No policy is passed, so actions are sampled at random.
    tensordict = base_env.rollout(1000)
    rb.extend(tensordict)
    total += tensordict.numel()

Let’s check what our replay buffer storage looks like. It should not contain the “r3m_vec” entry, since the transform is only applied when data is sampled:

print("stored data:", storage._storage)

When sampling, the data will go through the R3M transform, giving us the processed data that we wanted. In this way, we can train an algorithm offline on a dataset made of images:

batch = rb.sample(32)
print("data after sampling:", batch)
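The same mechanism can be mimicked in a few lines of plain PyTorch. The TransformOnSampleBuffer class below is hypothetical (it is not TorchRL's ReplayBuffer): it stores raw pixel observations and only applies its transform when a batch is sampled, so the stored data never contains the embedding. A flatten-and-project network stands in for R3M.

```python
import torch
from torch import nn


class TransformOnSampleBuffer:
    """Hypothetical buffer: stores raw data, transforms it lazily on sampling."""

    def __init__(self, transform):
        self.transform = transform
        self.storage = []  # raw "pixels" tensors, one per transition

    def extend(self, pixels):
        # Split a [N, C, H, W] batch into N stored transitions.
        self.storage.extend(pixels.unbind(0))

    def sample(self, batch_size):
        idx = torch.randint(len(self.storage), (batch_size,))
        pixels = torch.stack([self.storage[i] for i in idx.tolist()])
        # The embedding is computed here, at sampling time, not when storing.
        with torch.no_grad():
            return {"pixels": pixels, "r3m_vec": self.transform(pixels)}


# Stand-in for R3M: flatten 3x16x16 images and project them to 2048 features.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 2048))
rb = TransformOnSampleBuffer(encoder)
rb.extend(torch.rand(100, 3, 16, 16))

batch = rb.sample(32)
print(len(rb.storage), tuple(batch["r3m_vec"].shape))
```

Storing raw images and embedding them on demand keeps the dataset independent of the encoder, at the cost of running the encoder on every sampled batch.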

Total running time of the script: (0 minutes 25.996 seconds)

Estimated memory usage: 939 MB
