• Docs >
  • Task-specific policy in multi-task environments
Shortcuts

Task-specific policy in multi-task environments

This tutorial details how multi-task policies and batched environments can be used.

At the end of this tutorial, you will be capable of writing policies that can compute actions in diverse settings using a distinct set of weights. You will also be able to execute diverse environments in parallel.

import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.modules import MLP

We design two environments, one humanoid that must complete the stand task and another that must learn to walk.

def _make_task_env(base_env, obs_keys, task_obs_key):
    # Wrap a dm_control env so that all observation entries are concatenated
    # twice: once under a task-specific key (kept alongside the originals)
    # and once under the shared "observation" key; double tensors are cast
    # to float so they can feed standard nn modules.
    transforms = Compose(
        CatTensors(obs_keys, task_obs_key, del_keys=False),
        CatTensors(obs_keys, "observation"),
        DoubleToFloat(
            in_keys=[task_obs_key, "observation"],
            in_keys_inv=["action"],
        ),
    )
    return TransformedEnv(base_env, transforms)


env1 = DMControlEnv("humanoid", "stand")
env1_obs_keys = list(env1.observation_spec.keys())
env1 = _make_task_env(env1, env1_obs_keys, "observation_stand")

env2 = DMControlEnv("humanoid", "walk")
env2_obs_keys = list(env2.observation_spec.keys())
env2 = _make_task_env(env2, env2_obs_keys, "observation_walk")

tdreset1 = env1.reset()
tdreset2 = env2.reset()

# Stacking tensordicts is lazy in TorchRL: indexing the stacked result
# hands back the very objects that were stacked, not copies.
tdreset = torch.stack([tdreset1, tdreset2], dim=0)
assert tdreset[0] is tdreset1
print(tdreset[0])
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_stand: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

Policy

We will design a policy where a backbone reads the “observation” key. Then specific sub-components will read the “observation_stand” and “observation_walk” keys of the stacked tensordicts, if they are present, and pass them through the dedicated sub-network.

action_dim = env1.action_spec.shape[-1]

# Shared backbone: maps the common "observation" entry to a hidden code.
policy_common = TensorDictModule(
    nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"]
)


def _make_task_head(task_obs_key):
    # Task-specific head: consumes the task observation together with the
    # backbone's hidden code and writes the "action" entry.
    return TensorDictModule(
        MLP(67 + 64, action_dim, depth=2),
        in_keys=[task_obs_key, "hidden"],
        out_keys=["action"],
    )


policy_stand = _make_task_head("observation_stand")
policy_walk = _make_task_head("observation_walk")

# partial_tolerant=True: presumably lets the sequence skip modules whose
# in_keys are absent from a given tensordict (needed for the lazy stack,
# where task keys are exclusive to one element).
seq = TensorDictSequential(
    policy_common, policy_stand, policy_walk, partial_tolerant=True
)

Let’s check that our sequence outputs actions for a single env (stand).

# Run the policy on a fresh "stand" reset: an "action" entry is produced.
seq(env1.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([21]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        hidden: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_stand: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

Let’s check that our sequence outputs actions for a single env (walk).

# Run the policy on a fresh "walk" reset: an "action" entry is produced.
seq(env2.reset())
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([21]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        hidden: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        observation_walk: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

This also works with the stack: now the stand and walk keys have disappeared, because they’re not shared by all tensordicts. But the TensorDictSequential still performed the operations. Note that the backbone was executed in a vectorized way - not in a loop - which is more efficient.

LazyStackedTensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 21]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        hidden: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 67]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    exclusive_fields={
        0 ->
            observation_stand: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False),
        1 ->
            observation_walk: Tensor(shape=torch.Size([67]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False,
    stack_dim=0)

Executing diverse tasks in parallel

We can parallelize the operations if the common keys-value pairs share the same specs (in particular their shape and dtype must match: you can’t do the following if the observation shapes are different but are pointed to by the same key).

If ParallelEnv receives a single env making function, it will assume that a single task has to be performed. If a list of functions is provided, then it will assume that we are in a multi-task setting.

def env1_maker():
    """Build a fresh transformed 'stand' env (worker-side factory for ParallelEnv)."""
    base = DMControlEnv("humanoid", "stand")
    transforms = Compose(
        CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
        CatTensors(env1_obs_keys, "observation"),
        DoubleToFloat(
            in_keys=["observation_stand", "observation"],
            in_keys_inv=["action"],
        ),
    )
    return TransformedEnv(base, transforms)


def env2_maker():
    """Build a fresh transformed 'walk' env (worker-side factory for ParallelEnv)."""
    base = DMControlEnv("humanoid", "walk")
    transforms = Compose(
        CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
        CatTensors(env2_obs_keys, "observation"),
        DoubleToFloat(
            in_keys=["observation_walk", "observation"],
            in_keys_inv=["action"],
        ),
    )
    return TransformedEnv(base, transforms)


# Passing a list of maker functions (rather than a single one) puts
# ParallelEnv in multi-task mode.
env = ParallelEnv(2, [env1_maker, env2_maker])
# NOTE(review): _single_task is a private attribute; relying on it is fragile
# and may break across torchrl versions.
assert not env._single_task

tdreset = env.reset()
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # should be different
Traceback (most recent call last):
  File "/pytorch/rl/docs/source/reference/generated/tutorials/multi_task.py", line 182, in <module>
    tdreset = env.reset()
  File "/pytorch/rl/torchrl/envs/common.py", line 2069, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/pytorch/rl/env/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/pytorch/rl/torchrl/envs/batched_envs.py", line 56, in decorated_fun
    self._start_workers()
  File "/pytorch/rl/torchrl/envs/batched_envs.py", line 1135, in _start_workers
    process.start()
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/pytorch/rl/env/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "/pytorch/rl/env/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 295, in reduce_tensor
    from torch.nested._internal.nested_tensor import NestedTensor
  File "/pytorch/rl/env/lib/python3.8/site-packages/torch/nested/_internal/nested_tensor.py", line 6, in <module>
    from torch.fx.experimental.symbolic_shapes import has_free_symbols
  File "/pytorch/rl/env/lib/python3.8/site-packages/torch/fx/experimental/symbolic_shapes.py", line 63, in <module>
    from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator
  File "/pytorch/rl/env/lib/python3.8/site-packages/torch/utils/_sympy/functions.py", line 1, in <module>
    import sympy
  File "/pytorch/rl/env/lib/python3.8/site-packages/sympy/__init__.py", line 52, in <module>
    from .core import (sympify, SympifyError, cacheit, Basic, Atom,
  File "/pytorch/rl/env/lib/python3.8/site-packages/sympy/core/__init__.py", line 9, in <module>
    from .expr import Expr, AtomicExpr, UnevaluatedExpr
  File "/pytorch/rl/env/lib/python3.8/site-packages/sympy/core/expr.py", line 4159, in <module>
    from .mul import Mul
  File "/pytorch/rl/env/lib/python3.8/site-packages/sympy/core/mul.py", line 2193, in <module>
    from .numbers import Rational
  File "/pytorch/rl/env/lib/python3.8/site-packages/sympy/core/numbers.py", line 4567, in <module>
    _sympy_converter[type(mpmath.rational.mpq(1, 2))] = sympify_mpmath_mpq
AttributeError: module 'mpmath' has no attribute 'rational'

Let’s pass the output through our network.

# Run the multi-task policy on the stacked reset tensordict.
tdreset = seq(tdreset)
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # should be different but all have an "action" key


env.step(tdreset)  # executes the steps in parallel, consuming the "action" entries computed by seq
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # next_observation has now been written

Rollout

# return_contiguous=False keeps the rollout as a lazy stack — presumably
# required here since the two envs carry task-exclusive observation keys
# that cannot be stacked contiguously (verify against torchrl docs).
td_rollout = env.rollout(100, policy=seq, return_contiguous=False)
td_rollout[:, 0]  # tensordict of the first step: only the common keys are shown
td_rollout[0]  # tensordict of the first env: the stand obs is present

env.close()
del env

Total running time of the script: (0 minutes 31.178 seconds)

Estimated memory usage: 564 MB

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources