# Source code for torchrl.envs.common
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import abc
import functools
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
from tensordict.utils import NestedKey
from torchrl._utils import (
_ends_with,
_make_ordinal_device,
_replace_last,
implement_for,
prod,
seed_generator,
)
from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
_repr_by_depth,
_StepMDP,
_terminated_or_truncated,
_update_during_reset,
get_available_libraries,
)
# Result of ``torchrl.envs.utils.get_available_libraries()``, computed once
# when this module is imported.
LIBRARIES = get_available_libraries()
def _tensor_to_np(t):
    """Return ``t`` as a NumPy array, detached from autograd and moved to CPU first."""
    cpu_tensor = t.detach().cpu()
    return cpu_tensor.numpy()
# Torch dtype -> NumPy (or builtin) dtype used for the equivalent arrays.
dtype_map = dict(
    [
        (torch.float, np.float32),
        (torch.double, np.float64),
        (torch.bool, bool),
    ]
)
class EnvMetaData:
    """A class for environment meta-data storage and passing in multiprocessed settings.

    The ``tensordict`` and ``specs`` are stored on CPU by their property
    setters and cast to :attr:`device` every time they are read.

    Keyword Args:
        tensordict (TensorDictBase): a template tensordict for the environment
            (:meth:`metadata_from_env` builds it from ``env.fake_tensordict()``).
        specs (Composite): the environment specs.
        batch_size (torch.Size): the environment batch size.
        env_str (str): the string representation of the environment.
        device (torch.device): the device of the environment.
        batch_locked (bool): whether the environment is batch-locked.
        device_map (dict): maps each (nested) tensordict key to the device of
            that entry. It is needed because the stored tensordict lives on CPU.
    """

    def __init__(
        self,
        *,
        tensordict: TensorDictBase,
        specs: Composite,
        batch_size: torch.Size,
        env_str: str,
        device: torch.device,
        batch_locked: bool,
        device_map: dict,
    ):
        self.device = device
        self.tensordict = tensordict  # the setter stores a CPU copy
        self.specs = specs  # the setter stores a CPU copy
        self.batch_size = batch_size
        self.env_str = env_str
        self.batch_locked = batch_locked
        self.device_map = device_map
        self.has_dynamic_specs = _has_dynamic_specs(specs)

    @property
    def tensordict(self):
        """The stored tensordict, cast to :attr:`device`."""
        return self._tensordict.to(self.device)

    @property
    def specs(self):
        """The stored specs, cast to :attr:`device`."""
        return self._specs.to(self.device)

    @tensordict.setter
    def tensordict(self, value: TensorDictBase):
        self._tensordict = value.to("cpu")

    @specs.setter
    def specs(self, value: Composite):
        self._specs = value.to("cpu")

    @staticmethod
    def metadata_from_env(env) -> EnvMetaData:
        """Build an :class:`EnvMetaData` from an environment instance.

        For every done key, an entry obtained with
        ``_replace_last(done_key, "_reset")`` is added to the template
        tensordict. It is filled with zeros shaped like ``("next", done_key)``.
        """
        tensordict = env.fake_tensordict().clone()
        for done_key in env.done_keys:
            tensordict.set(
                _replace_last(done_key, "_reset"),
                torch.zeros_like(tensordict.get(("next", done_key))),
            )
        # A single CPU cast is enough (the original cast twice).
        specs = env.specs.to("cpu")
        batch_size = env.batch_size
        env_str = str(env)
        device = env.device
        batch_locked = env.batch_locked
        # we need to save the device map, as the tensordict will be placed on cpu
        device_map = {}

        def fill_device_map(name, val, device_map=device_map):
            device_map[name] = val.device

        tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
        return EnvMetaData(
            tensordict=tensordict,
            specs=specs,
            batch_size=batch_size,
            env_str=env_str,
            device=device,
            batch_locked=batch_locked,
            device_map=device_map,
        )

    def expand(self, *size: int) -> EnvMetaData:
        """Return a copy of the metadata expanded to the batch size ``size``.

        Accepts both ``expand(2, 3)`` and ``expand((2, 3))`` /
        ``expand(torch.Size([2, 3]))``.
        """
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            # torch.Size is a tuple subclass, so it is covered here too.
            size = tuple(size[0])
        tensordict = self.tensordict.expand(*size).clone()
        batch_size = torch.Size(size)
        return EnvMetaData(
            tensordict=tensordict,
            specs=self.specs.expand(*size),
            batch_size=batch_size,
            env_str=self.env_str,
            device=self.device,
            batch_locked=self.batch_locked,
            # copy so the two metadata objects do not share a mutable dict
            device_map=dict(self.device_map),
        )

    def clone(self):
        """Return an independent copy of the metadata."""
        return EnvMetaData(
            tensordict=self.tensordict.clone(),
            specs=self.specs.clone(),
            batch_size=torch.Size([*self.batch_size]),
            env_str=self.env_str,  # str is immutable; no copy needed
            device=self.device,
            batch_locked=self.batch_locked,
            # copy so mutating the clone's map does not affect the original
            device_map=dict(self.device_map),
        )

    def to(self, device: DEVICE_TYPING) -> EnvMetaData:
        """Return a copy with tensordict, specs and every device-map entry moved to ``device``."""
        if device is not None:
            device = _make_ordinal_device(torch.device(device))
        device_map = {key: device for key in self.device_map}
        tensordict = self.tensordict.contiguous().to(device)
        specs = self.specs.to(device)
        return EnvMetaData(
            tensordict=tensordict,
            specs=specs,
            batch_size=self.batch_size,
            env_str=self.env_str,
            device=device,
            batch_locked=self.batch_locked,
            device_map=device_map,
        )
class _EnvPostInit(abc.ABCMeta):
    """Metaclass running post-construction steps on every environment instance.

    After ``__init__`` returns, it:

    - creates missing done specs;
    - touches the lazily-built spec properties;
    - checks that spec keys do not collide;
    - if ``auto_reset=True`` was passed, wraps the instance in an ``AutoResetEnv``.

    The ``auto_reset`` and ``auto_reset_replace`` keyword arguments are
    consumed here and never reach ``__init__``.
    """

    def __call__(cls, *args, **kwargs):
        auto_reset = kwargs.pop("auto_reset", False)
        auto_reset_replace = kwargs.pop("auto_reset_replace", True)
        instance: EnvBase = super().__call__(*args, **kwargs)
        # we create the done spec by adding a done/terminated entry if one is missing
        instance._create_done_specs()
        # we access lazy attributes to make sure they're built properly.
        # This isn't done in `__init__` because we don't know if super().__init__
        # will be called before or after the specs, batch size etc are set.
        _ = instance.done_spec
        _ = instance.reward_spec
        _ = instance.state_spec

        # Key-collision checks run BEFORE the optional auto-reset wrapping so that
        # passing ``auto_reset=True`` cannot bypass them.
        done_keys = set(instance.full_done_spec.keys(True, True))
        obs_keys = set(instance.full_observation_spec.keys(True, True))
        reward_keys = set(instance.full_reward_spec.keys(True, True))
        action_keys = set(instance.full_action_spec.keys(True, True))
        state_keys = set(instance.full_state_spec.keys(True, True))
        # Within each group the key sets must be pairwise disjoint.
        # state_keys can match obs_keys, so states are only checked against actions.
        for group in ((done_keys, obs_keys, reward_keys), (state_keys, action_keys)):
            seen = set()
            for keyset in group:
                overlap = seen.intersection(keyset)
                if overlap:
                    raise RuntimeError(
                        f"The set of keys of one spec collides (culprit: {overlap}) with another."
                    )
                seen = seen.union(keyset)

        if auto_reset:
            from torchrl.envs.transforms.transforms import (
                AutoResetEnv,
                AutoResetTransform,
            )

            return AutoResetEnv(
                instance, AutoResetTransform(replace=auto_reset_replace)
            )
        return instance
[docs]class EnvBase(nn.Module, metaclass=_EnvPostInit):
"""Abstract environment parent class.
Keyword Args:
device (torch.device): The device of the environment. Deviceless environments
are allowed (device=None). If not ``None``, all specs will be cast
on that device and it is expected that all inputs and outputs will
live on that device.
Defaults to ``None``.
batch_size (torch.Size or equivalent, optional): batch-size of the environment.
Corresponds to the leading dimension of all the input and output
tensordicts the environment reads and writes. Defaults to an empty batch-size.
run_type_checks (bool, optional): If ``True``, type-checks will occur
at every reset and every step. Defaults to ``False``.
allow_done_after_reset (bool, optional): if ``True``, an environment can
be done after a call to :meth:`~.reset` is made. Defaults to ``False``.
Attributes:
done_spec (Composite): equivalent to ``full_done_spec`` as all
``done_specs`` contain at least a ``"done"`` and a ``"terminated"`` entry
action_spec (TensorSpec): the spec of the action. Links to the spec of the leaf
action if only one action tensor is to be expected. Otherwise links to
``full_action_spec``.
observation_spec (Composite): equivalent to ``full_observation_spec``.
reward_spec (TensorSpec): the spec of the reward. Links to the spec of the leaf
reward if only one reward tensor is to be expected. Otherwise links to
``full_reward_spec``.
state_spec (Composite): equivalent to ``full_state_spec``.
full_done_spec (Composite): a composite spec such that ``full_done_spec.zero()``
returns a tensordict containing only the leaves encoding the done status of the
environment.
full_action_spec (Composite): a composite spec such that ``full_action_spec.zero()``
returns a tensordict containing only the leaves encoding the action of the
environment.
full_observation_spec (Composite): a composite spec such that ``full_observation_spec.zero()``
returns a tensordict containing only the leaves encoding the observation of the
environment.
full_reward_spec (Composite): a composite spec such that ``full_reward_spec.zero()``
returns a tensordict containing only the leaves encoding the reward of the
environment.
full_state_spec (Composite): a composite spec such that ``full_state_spec.zero()``
returns a tensordict containing only the leaves encoding the inputs (actions
excluded) of the environment.
batch_size (torch.Size): The batch-size of the environment.
device (torch.device): the device where the input/outputs of the environment
are to be expected. Can be ``None``.
Methods:
step (TensorDictBase -> TensorDictBase): step in the environment
reset (TensorDictBase, optional -> TensorDictBase): reset the environment
set_seed (int -> int): sets the seed of the environment
rand_step (TensorDictBase, optional -> TensorDictBase): random step given the action spec
rollout (Callable, ... -> TensorDictBase): executes a rollout in the environment with the given policy (or random
steps if no policy is provided)
Examples:
>>> from torchrl.envs import EnvBase
>>> class CounterEnv(EnvBase):
... def __init__(self, batch_size=(), device=None, **kwargs):
... self.observation_spec = Composite(
... count=Unbounded(batch_size, device=device, dtype=torch.int64))
... self.action_spec = Unbounded(batch_size, device=device, dtype=torch.int8)
... # done spec and reward spec are set automatically
... def _step(self, tensordict):
...
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.batch_size # how many envs are run at once
torch.Size([])
>>> env.input_spec
Composite(
full_state_spec: None,
full_action_spec: Composite(
action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
>>> env.action_spec
BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
>>> env.observation_spec
Composite(
observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([]))
>>> env.reward_spec
UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous)
>>> env.done_spec
Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.bool,
domain=discrete)
>>> # the output_spec contains all the expected outputs
>>> env.output_spec
Composite(
full_reward_spec: Composite(
reward: UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
full_observation_spec: Composite(
observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
full_done_spec: Composite(
done: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.bool,
domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
.. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
"""
def __init__(
    self,
    *,
    device: DEVICE_TYPING = None,
    batch_size: Optional[torch.Size] = None,
    run_type_checks: bool = False,
    allow_done_after_reset: bool = False,
):
    """Set up the device, batch size and runtime flags of the environment.

    Specs may already be in ``self.__dict__`` at this point (placed there by
    ``__new__`` or by a subclass before calling ``super().__init__()``).
    They are cast to the environment device when one is known.
    """
    state = self.__dict__
    state.setdefault("_batch_size", None)
    if device is not None:
        state["_device"] = _make_ordinal_device(torch.device(device))
    for spec_name in ("_output_spec", "_input_spec"):
        spec = state.get(spec_name)
        if spec is not None and self.device is not None:
            state[spec_name] = spec.to(self.device)
    super().__init__()
    if "is_closed" not in dir(self):
        self.is_closed = True
    if batch_size is not None:
        # Routed through the ``batch_size`` setter so the input/output spec
        # shapes are kept in sync with the new batch size.
        self.batch_size = torch.Size(batch_size)
    self._run_type_checks = run_type_checks
    self._allow_done_after_reset = allow_done_after_reset
@classmethod
def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
    """Allocate an environment instance and reset the per-class defaults.

    Because ``__new__`` is wrapped in ``classmethod``, ``cls`` is bound to the
    class being instantiated. The class object that the instantiation protocol
    passes explicitly lands in ``args`` and is ignored.

    Specs assigned as class attributes (``_input_spec`` / ``_output_spec``) are
    cloned, removed from the class, and placed in the new instance's
    ``__dict__``.
    """
    # inplace update will write tensors in-place on the provided tensordict.
    # This is risky, especially if gradients need to be passed (in-place copy
    # for tensors that are part of computational graphs will result in an error).
    # It can also lead to inconsistencies when calling rollout.
    # NOTE(review): the defaults below are written on the *class*, so every new
    # instantiation overwrites them for all existing instances of ``cls`` that
    # have not shadowed them with an instance attribute. Confirm this sharing
    # is intended before relying on per-instance values.
    cls._inplace_update = _inplace_update
    cls._batch_locked = _batch_locked
    cls._device = None
    # cached in_keys to be excluded from update when calling step
    cls._cache_in_keys = None
    # We may assign _input_spec to the cls, but it must be assigned to the instance
    # we pull it off, and place it back where it belongs
    _input_spec = None
    if hasattr(cls, "_input_spec"):
        _input_spec = cls._input_spec.clone()
        delattr(cls, "_input_spec")
    _output_spec = None
    if hasattr(cls, "_output_spec"):
        _output_spec = cls._output_spec.clone()
        delattr(cls, "_output_spec")
    env = super().__new__(cls)
    if _input_spec is not None:
        env.__dict__["_input_spec"] = _input_spec
    if _output_spec is not None:
        env.__dict__["_output_spec"] = _output_spec
    # (an unreachable second ``return super().__new__(cls)`` was removed here)
    return env
def __setattr__(self, key, value):
if key in (
"_input_spec",
"_observation_spec",
"_action_spec",
"_reward_spec",
"_output_spec",
"_state_spec",
"_done_spec",
):
raise AttributeError(
"To set an environment spec, please use `env.observation_spec = obs_spec` (without the leading"
" underscore)."
)
return super().__setattr__(key, value)
@property
def batch_locked(self) -> bool:
"""Whether the environment can be used with a batch size different from the one it was initialized with or not.
If True, the env needs to be used with a tensordict having the same batch size as the env.
batch_locked is an immutable property.
"""
return self._batch_locked
@batch_locked.setter
def batch_locked(self, value: bool) -> None:
raise RuntimeError("batch_locked is a read-only property")
@property
def run_type_checks(self) -> bool:
return self._run_type_checks
@run_type_checks.setter
def run_type_checks(self, run_type_checks: bool) -> None:
self._run_type_checks = run_type_checks
@property
def batch_size(self) -> torch.Size:
"""Number of envs batched in this environment instance organised in a `torch.Size()` object.
Environment may be similar or different but it is assumed that they have little if
not no interactions between them (e.g., multi-task or batched execution
in parallel).
"""
_batch_size = self.__dict__["_batch_size"]
if _batch_size is None:
_batch_size = self._batch_size = torch.Size([])
return _batch_size
@batch_size.setter
def batch_size(self, value: torch.Size) -> None:
self._batch_size = torch.Size(value)
if (
hasattr(self, "output_spec")
and self.output_spec.shape[: len(value)] != value
):
self.output_spec.unlock_()
self.output_spec.shape = value
self.output_spec.lock_()
if hasattr(self, "input_spec") and self.input_spec.shape[: len(value)] != value:
self.input_spec.unlock_()
self.input_spec.shape = value
self.input_spec.lock_()
@property
def shape(self):
"""Equivalent to :attr:`~.batch_size`."""
return self.batch_size
@property
def device(self) -> torch.device:
device = self.__dict__.get("_device")
return device
@device.setter
def device(self, value: torch.device) -> None:
device = self.__dict__.get("_device")
if device is None:
self.__dict__["_device"] = value
return
raise RuntimeError("device cannot be set. Call env.to(device) instead.")
def ndimension(self):
return len(self.batch_size)
@property
def ndim(self):
return self.ndimension()
def append_transform(
    self,
    transform: "Transform"  # noqa: F821
    | Callable[[TensorDictBase], TensorDictBase],
) -> EnvBase:
    """Wrap this environment in a :class:`TransformedEnv` applying ``transform``.

    Args:
        transform (Transform or Callable[[TensorDictBase], TensorDictBase]): the
            transform (or plain callable) to apply to the environment.

    Returns:
        a new ``TransformedEnv`` wrapping ``self``.

    Examples:
        >>> from torchrl.envs import GymEnv
        >>> import torch
        >>> env = GymEnv("CartPole-v1")
        >>> loc = 0.5
        >>> scale = 1.0
        >>> transform = lambda data: data.set("observation", (data.get("observation") - loc)/scale)
        >>> env = env.append_transform(transform=transform)
        >>> print(env)
        TransformedEnv(
            env=GymEnv(env=CartPole-v1, batch_size=torch.Size([]), device=cpu),
            transform=_CallableTransform(keys=[]))
    """
    # Imported lazily here rather than at module level.
    from torchrl.envs.transforms.transforms import TransformedEnv

    transformed = TransformedEnv(self, transform)
    return transformed
# Parent specs: input and output spec.
@property
def input_spec(self) -> TensorSpec:
"""Input spec.
The composite spec containing all specs for data input to the environments.
It contains:
- "full_action_spec": the spec of the input actions
- "full_state_spec": the spec of all other environment inputs
This attibute is locked and should be read-only.
Instead, to set the specs contained in it, use the respective properties.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.input_spec
Composite(
full_state_spec: None,
full_action_spec: Composite(
action: BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
"""
input_spec = self.__dict__.get("_input_spec")
if input_spec is None:
input_spec = Composite(
full_state_spec=None,
shape=self.batch_size,
device=self.device,
).lock_()
self.__dict__["_input_spec"] = input_spec
return input_spec
@input_spec.setter
def input_spec(self, value: TensorSpec) -> None:
raise RuntimeError("input_spec is protected.")
@property
def output_spec(self) -> TensorSpec:
"""Output spec.
The composite spec containing all specs for data output from the environments.
It contains:
- "full_reward_spec": the spec of reward
- "full_done_spec": the spec of done
- "full_observation_spec": the spec of all other environment outputs
This attibute is locked and should be read-only.
Instead, to set the specs contained in it, use the respective properties.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.output_spec
Composite(
full_reward_spec: Composite(
reward: UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
full_observation_spec: Composite(
observation: BoundedContinuous(
shape=torch.Size([3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([])),
full_done_spec: Composite(
done: Categorical(
shape=torch.Size([1]),
space=DiscreteBox(n=2),
device=cpu,
dtype=torch.bool,
domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
"""
output_spec = self.__dict__.get("_output_spec")
if output_spec is None:
output_spec = Composite(
shape=self.batch_size,
device=self.device,
).lock_()
self.__dict__["_output_spec"] = output_spec
return output_spec
@output_spec.setter
def output_spec(self, value: TensorSpec) -> None:
raise RuntimeError("output_spec is protected.")
@property
def action_keys(self) -> List[NestedKey]:
"""The action keys of an environment.
By default, there will only be one key named "action".
Keys are sorted by depth in the data tree.
"""
action_keys = self.__dict__.get("_action_keys")
if action_keys is not None:
return action_keys
keys = self.input_spec["full_action_spec"].keys(True, True)
if not len(keys):
raise AttributeError("Could not find action spec")
keys = sorted(keys, key=_repr_by_depth)
self.__dict__["_action_keys"] = keys
return keys
@property
def state_keys(self) -> List[NestedKey]:
"""The state keys of an environment.
By default, there will only be one key named "state".
Keys are sorted by depth in the data tree.
"""
state_keys = self.__dict__.get("_state_keys")
if state_keys is not None:
return state_keys
keys = self.input_spec["full_state_spec"].keys(True, True)
keys = sorted(keys, key=_repr_by_depth)
self.__dict__["_state_keys"] = keys
return keys
@property
def action_key(self) -> NestedKey:
"""The action key of an environment.
By default, this will be "action".
If there is more than one action key in the environment, this function will raise an exception.
"""
if len(self.action_keys) > 1:
raise KeyError(
"action_key requested but more than one key present in the environment"
)
return self.action_keys[0]
# Action spec: action specs belong to input_spec
@property
def action_spec(self) -> TensorSpec:
"""The ``action`` spec.
The ``action_spec`` is always stored as a composite spec.
If the action spec is provided as a simple spec, this will be returned.
>>> env.action_spec = Unbounded(1)
>>> env.action_spec
UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
If the action spec is provided as a composite spec and contains only one leaf,
this function will return just the leaf.
>>> env.action_spec = Composite({"nested": {"action": Unbounded(1)}})
>>> env.action_spec
UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
If the action spec is provided as a composite spec and has more than one leaf,
this function will return the whole spec.
>>> env.action_spec = Composite({"nested": {"action": Unbounded(1), "another_action": Categorical(1)}})
>>> env.action_spec
Composite(
nested: Composite(
action: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
another_action: Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=1),
device=cpu,
dtype=torch.int64,
domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
To retrieve the full spec passed, use:
>>> env.input_spec["full_action_spec"]
This property is mutable.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.action_spec
BoundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
"""
try:
action_spec = self.input_spec["full_action_spec"]
except (KeyError, AttributeError):
raise KeyError("Failed to find the action_spec.")
if len(self.action_keys) > 1:
out = action_spec
else:
try:
out = action_spec[self.action_key]
except KeyError:
# the key may have changed
raise KeyError(
"The action_key attribute seems to have changed. "
"This occurs when a action_spec is updated without "
"calling `env.action_spec = new_spec`. "
"Make sure you rely on this type of command "
"to set the action and other specs."
)
return out
@action_spec.setter
def action_spec(self, value: TensorSpec) -> None:
try:
self.input_spec.unlock_()
device = self.input_spec._device
try:
delattr(self, "_action_keys")
except AttributeError:
pass
if not hasattr(value, "shape"):
raise TypeError(
f"action_spec of type {type(value)} do not have a shape attribute."
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
"An empty Composite was passed for the action spec. "
"This is currently not permitted."
)
else:
value = Composite(
action=value.to(device), shape=self.batch_size, device=device
)
self.input_spec["full_action_spec"] = value.to(device)
finally:
self.input_spec.lock_()
@property
def full_action_spec(self) -> Composite:
"""The full action spec.
``full_action_spec`` is a :class:`~torchrl.data.Composite`` instance
that contains all the action entries.
Examples:
>>> from torchrl.envs import BraxEnv
>>> for envname in BraxEnv.available_envs:
... break
>>> env = BraxEnv(envname)
>>> env.full_action_spec
Composite(
action: BoundedContinuous(
shape=torch.Size([8]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([]))
"""
full_action_spec = self.input_spec.get("full_action_spec", None)
if full_action_spec is None:
full_action_spec = Composite(shape=self.batch_size, device=self.device)
self.input_spec.unlock_()
self.input_spec["full_action_spec"] = full_action_spec
self.input_spec.lock_()
return full_action_spec
@full_action_spec.setter
def full_action_spec(self, spec: Composite) -> None:
self.action_spec = spec
# Reward spec
@property
def reward_keys(self) -> List[NestedKey]:
"""The reward keys of an environment.
By default, there will only be one key named "reward".
Keys are sorted by depth in the data tree.
"""
reward_keys = self.__dict__.get("_reward_keys")
if reward_keys is not None:
return reward_keys
reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth)
self.__dict__["_reward_keys"] = reward_keys
return reward_keys
@property
def reward_key(self):
"""The reward key of an environment.
By default, this will be "reward".
If there is more than one reward key in the environment, this function will raise an exception.
"""
if len(self.reward_keys) > 1:
raise KeyError(
"reward_key requested but more than one key present in the environment"
)
return self.reward_keys[0]
# Reward spec: reward specs belong to output_spec
@property
def reward_spec(self) -> TensorSpec:
"""The ``reward`` spec.
The ``reward_spec`` is always stored as a composite spec.
If the reward spec is provided as a simple spec, this will be returned.
>>> env.reward_spec = Unbounded(1)
>>> env.reward_spec
UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
If the reward spec is provided as a composite spec and contains only one leaf,
this function will return just the leaf.
>>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1)}})
>>> env.reward_spec
UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous)
If the reward spec is provided as a composite spec and has more than one leaf,
this function will return the whole spec.
>>> env.reward_spec = Composite({"nested": {"reward": Unbounded(1), "another_reward": Categorical(1)}})
>>> env.reward_spec
Composite(
nested: Composite(
reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous),
another_reward: Categorical(
shape=torch.Size([]),
space=DiscreteBox(n=1),
device=cpu,
dtype=torch.int64,
domain=discrete), device=cpu, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
To retrieve the full spec passed, use:
>>> env.output_spec["full_reward_spec"]
This property is mutable.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> env.reward_spec
UnboundedContinuous(
shape=torch.Size([1]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous)
"""
try:
reward_spec = self.output_spec["full_reward_spec"]
except (KeyError, AttributeError):
# populate the "reward" entry
# this will be raised if there is not full_reward_spec (unlikely) or no reward_key
# Since output_spec is lazily populated with an empty composite spec for
# reward_spec, the second case is much more likely to occur.
self.reward_spec = Unbounded(
shape=(*self.batch_size, 1),
device=self.device,
)
reward_spec = self.output_spec["full_reward_spec"]
reward_keys = self.reward_keys
if len(reward_keys) > 1 or not len(reward_keys):
return reward_spec
else:
return reward_spec[self.reward_keys[0]]
@reward_spec.setter
def reward_spec(self, value: TensorSpec) -> None:
try:
self.output_spec.unlock_()
device = self.output_spec._device
try:
delattr(self, "_reward_keys")
except AttributeError:
pass
if not hasattr(value, "shape"):
raise TypeError(
f"reward_spec of type {type(value)} do not have a shape "
f"attribute."
)
if value.shape[: len(self.batch_size)] != self.batch_size:
raise ValueError(
f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
)
if isinstance(value, Composite):
for _ in value.values(True, True): # noqa: B007
break
else:
raise RuntimeError(
"An empty Composite was passed for the reward spec. "
"This is currently not permitted."
)
else:
value = Composite(
reward=value.to(device), shape=self.batch_size, device=device
)
for leaf in value.values(True, True):
if len(leaf.shape) == 0:
raise RuntimeError(
"the reward_spec's leafs shape cannot be empty (this error"
" usually comes from trying to set a reward_spec"
" with a null number of dimensions. Try using a multidimensional"
" spec instead, for instance with a singleton dimension at the tail)."
)
self.output_spec["full_reward_spec"] = value.to(device)
finally:
self.output_spec.lock_()
@property
def full_reward_spec(self) -> Composite:
"""The full reward spec.
``full_reward_spec`` is a :class:`~torchrl.data.Composite`` instance
that contains all the reward entries.
Examples:
>>> import gymnasium
>>> from torchrl.envs import GymWrapper, TransformedEnv, RenameTransform
>>> base_env = GymWrapper(gymnasium.make("Pendulum-v1"))
>>> env = TransformedEnv(base_env, RenameTransform("reward", ("nested", "reward")))
>>> env.full_reward_spec
Composite(
nested: Composite(
reward: UnboundedContinuous(
shape=torch.Size([1]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=None, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
"""
return self.output_spec["full_reward_spec"]
@full_reward_spec.setter
def full_reward_spec(self, spec: Composite) -> None:
self.reward_spec = spec.to(self.device) if self.device is not None else spec
# done spec
@property
def done_keys(self) -> List[NestedKey]:
"""The done keys of an environment.
By default, there will only be one key named "done".
Keys are sorted by depth in the data tree.
"""
done_keys = self.__dict__.get("_done_keys")
if done_keys is not None:
return done_keys
done_keys = sorted(self.full_done_spec.keys(True, True), key=_repr_by_depth)
self.__dict__["_done_keys"] = done_keys
return done_keys
@property
def done_key(self):
    """The single done key of the environment.

    By default, this will be "done".

    Raises:
        KeyError: if the environment exposes more than one done key.
    """
    keys = self.done_keys
    if len(keys) <= 1:
        return keys[0]
    raise KeyError(
        "done_key requested but more than one key present in the environment"
    )
@property
def full_done_spec(self) -> Composite:
    """Composite spec gathering every done entry of the environment.

    This is the :class:`~torchrl.data.Composite` stored under
    ``output_spec["full_done_spec"]``. It can be used to generate fake data
    whose structure mimics the one obtained at runtime.

    Examples:
        >>> import gymnasium
        >>> from torchrl.envs import GymWrapper
        >>> env = GymWrapper(gymnasium.make("Pendulum-v1"))
        >>> env.full_done_spec  # Composite of boolean Categorical leaves (e.g. done, truncated)
    """
    output_spec = self.output_spec
    return output_spec["full_done_spec"]
@full_done_spec.setter
def full_done_spec(self, spec: Composite) -> None:
    """Assign the full done spec, moving it to the env device when one is set."""
    if self.device is not None:
        spec = spec.to(self.device)
    self.done_spec = spec
# Done spec: done specs belong to output_spec
@property
def done_spec(self) -> TensorSpec:
    """The ``done`` spec.

    The done spec is always stored as a :class:`~torchrl.data.Composite`
    under ``output_spec["full_done_spec"]`` and this property returns that
    composite as-is. It does not unwrap composites that contain a single
    leaf, so ``env.done_spec`` and ``env.output_spec["full_done_spec"]``
    give the same spec.

    When a non-composite spec is assigned, the setter wraps it in a
    :class:`~torchrl.data.Composite` with both a ``"done"`` and a
    ``"terminated"`` entry. Every leaf must have at least one dimension:
    zero-dimensional leaves are rejected. For example, use
    ``Categorical(n=2, shape=(*env.batch_size, 1), dtype=torch.bool)``.

    This property is mutable.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> env = GymEnv("Pendulum-v1")
        >>> env.done_spec  # a Composite holding the done entries
        >>> env.done_spec = Categorical(n=2, shape=(1,), dtype=torch.bool)
        >>> # the spec now contains at least "done" and "terminated" entries
    """
    done_spec = self.output_spec["full_done_spec"]
    return done_spec
def _create_done_specs(self):
    """Reads through the done specs and makes sure they are complete.

    - If ``output_spec`` has no ``"full_done_spec"`` entry, one is created
      with boolean ``"done"`` and ``"terminated"`` leaves of shape
      ``(*output_spec.shape, 1)``.
    - Otherwise the spec is walked recursively. In every (sub-)composite:
      if only ``"done"`` is present, a clone is added as ``"terminated"``,
      and the reverse applies if only ``"terminated"`` is present.
    - Sub-composites that contain neither key but do contain other leaves
      are left unchanged, except that those other leaves must share one shape.
    - Empty (sub-)composites receive fresh ``"done"``/``"terminated"`` leaves.

    ``output_spec`` is unlocked while it is modified and is always re-locked,
    even if an error is raised.

    Raises:
        ValueError: if the non-done leaves of a composite have different shapes.
    """
    try:
        full_done_spec = self.output_spec["full_done_spec"]
    except KeyError:
        full_done_spec = Composite(
            shape=self.output_spec.shape, device=self.output_spec.device
        )
        full_done_spec["done"] = Categorical(
            n=2,
            shape=(*full_done_spec.shape, 1),
            dtype=torch.bool,
            device=self.device,
        )
        full_done_spec["terminated"] = Categorical(
            n=2,
            shape=(*full_done_spec.shape, 1),
            dtype=torch.bool,
            device=self.device,
        )
        self.output_spec.unlock_()
        try:
            self.output_spec["full_done_spec"] = full_done_spec
        finally:
            self.output_spec.lock_()
        return

    def check_local_done(spec):
        shape = None
        # list() because entries may be added to ``spec`` inside the loop
        for key, item in list(spec.items()):
            # in the case where the spec is non-empty and there is no done and no terminated, we do nothing
            if key == "done" and "terminated" not in spec.keys():
                spec["terminated"] = item.clone()
            elif key == "terminated" and "done" not in spec.keys():
                spec["done"] = item.clone()
            elif isinstance(item, Composite):
                check_local_done(item)
            else:
                if shape is None:
                    shape = item.shape
                    continue
                # checks that all shape match
                if shape != item.shape:
                    raise ValueError(
                        f"All shapes should match in done_spec {spec} (shape={shape}, key={key})."
                    )
        # if the spec is empty, we need to add a done and terminated manually
        if spec.is_empty():
            spec["done"] = Categorical(
                n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
            )
            spec["terminated"] = Categorical(
                n=2, shape=(*spec.shape, 1), dtype=torch.bool, device=self.device
            )

    # the spec must be unlocked before it is mutated in place by check_local_done
    self.output_spec.unlock_()
    try:
        check_local_done(full_done_spec)
        self.output_spec["full_done_spec"] = full_done_spec
    finally:
        # re-lock even if a shape mismatch was detected
        self.output_spec.lock_()
    return
@done_spec.setter
def done_spec(self, value: TensorSpec) -> None:
    """Set ``output_spec["full_done_spec"]`` and complete it.

    A non-composite ``value`` is wrapped in a Composite with ``"done"`` and
    ``"terminated"`` entries. The result is moved to ``output_spec``'s
    device and then completed by ``_create_done_specs``. The cached
    ``_done_keys`` is cleared, and ``output_spec`` is always re-locked.

    Raises:
        TypeError: if ``value`` has no ``shape`` attribute.
        ValueError: if ``value``'s leading dims do not match the env batch size.
        RuntimeError: if ``value`` is an empty Composite or has a 0-dim leaf.
    """
    try:
        self.output_spec.unlock_()
        device = self.output_spec.device
        # invalidate the cache populated by ``done_keys``
        try:
            delattr(self, "_done_keys")
        except AttributeError:
            pass
        if not hasattr(value, "shape"):
            raise TypeError(
                f"done_spec of type {type(value)} do not have a shape "
                f"attribute."
            )
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
            )
        if isinstance(value, Composite):
            # for/else: the ``else`` branch only runs when there is no leaf at all
            for _ in value.values(True, True):  # noqa: B007
                break
            else:
                raise RuntimeError(
                    "An empty Composite was passed for the done spec. "
                    "This is currently not permitted."
                )
        else:
            value = Composite(
                done=value.to(device),
                terminated=value.to(device),
                shape=self.batch_size,
                device=device,
            )
        for leaf in value.values(True, True):
            if len(leaf.shape) == 0:
                raise RuntimeError(
                    "the done_spec's leafs shape cannot be empty (this error"
                    " usually comes from trying to set a done_spec"
                    " with a null number of dimensions. Try using a multidimensional"
                    " spec instead, for instance with a singleton dimension at the tail)."
                )
        self.output_spec["full_done_spec"] = value.to(device)
        self._create_done_specs()
    finally:
        self.output_spec.lock_()
# observation spec: observation specs belong to output_spec
@property
def observation_spec(self) -> Composite:
    """Observation spec (a :class:`torchrl.data.Composite`).

    Its keys are directly accessible after reset and step. Besides proper
    observations, it is the generic container for every environment output
    that is neither done nor reward data (info, states, transform outputs...).

    If ``output_spec`` holds no ``"full_observation_spec"`` yet, an empty
    Composite with the env batch size and device is created and stored.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> env = GymEnv("Pendulum-v1")
        >>> env.observation_spec  # Composite(observation: BoundedContinuous(shape=torch.Size([3]), ...))
    """
    spec = self.output_spec.get("full_observation_spec", default=None)
    if spec is not None:
        return spec
    spec = Composite(shape=self.batch_size, device=self.device)
    self.output_spec.unlock_()
    self.output_spec["full_observation_spec"] = spec
    self.output_spec.lock_()
    return spec
@observation_spec.setter
def observation_spec(self, value: TensorSpec) -> None:
    """Set ``output_spec["full_observation_spec"]``.

    ``value`` is moved to ``output_spec``'s device when that device is set.
    ``output_spec`` is always re-locked, even when validation fails.

    Raises:
        TypeError: if ``value`` is not a Composite.
        ValueError: if ``value``'s leading dims do not match the env batch size.
    """
    try:
        self.output_spec.unlock_()
        if not isinstance(value, Composite):
            raise TypeError("The type of an observation_spec must be Composite.")
        # single batch-size check (the previous version duplicated it)
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise ValueError(
                f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
            )
        device = self.output_spec._device
        self.output_spec["full_observation_spec"] = (
            value.to(device) if device is not None else value
        )
    finally:
        self.output_spec.lock_()
@property
def full_observation_spec(self) -> Composite:
    """Alias of :attr:`observation_spec`."""
    spec = self.observation_spec
    return spec
@full_observation_spec.setter
def full_observation_spec(self, spec: Composite):
    """Alias setter: delegates to the :attr:`observation_spec` setter."""
    setattr(self, "observation_spec", spec)
# state spec: state specs belong to input_spec
@property
def state_spec(self) -> Composite:
    """State spec (a :class:`torchrl.data.Composite`).

    The keys listed here should be provided as input alongside actions to the
    environment. Besides proper states, it is the generic container for every
    environment input that is not action data.

    If ``input_spec`` holds no ``"full_state_spec"`` entry (missing or
    ``None``), an empty Composite with the env batch size and device is
    created and stored. ``input_spec`` is re-locked afterwards.

    Examples:
        >>> from torchrl.envs import BraxEnv
        >>> for envname in BraxEnv.available_envs:
        ...     break
        >>> env = BraxEnv(envname)
        >>> env.state_spec  # Composite(state: Composite(pipeline_state: ...))
    """
    # ``.get`` (as in ``observation_spec``) so a missing key reaches the
    # lazy-creation branch instead of raising KeyError
    state_spec = self.input_spec.get("full_state_spec", default=None)
    if state_spec is None:
        state_spec = Composite(shape=self.batch_size, device=self.device)
        self.input_spec.unlock_()
        try:
            self.input_spec["full_state_spec"] = state_spec
        finally:
            self.input_spec.lock_()
    return state_spec
@state_spec.setter
def state_spec(self, value: Composite) -> None:
    """Set ``input_spec["full_state_spec"]``.

    ``None`` installs an empty Composite with the env batch size and device.
    Otherwise ``value`` is moved to ``input_spec``'s device when that device
    is set. The cached ``_state_keys`` is cleared, and ``input_spec`` is
    always re-locked.

    Raises:
        TypeError: if ``value`` is neither ``None`` nor a Composite.
        ValueError: if ``value``'s leading dims do not match the env batch size.
    """
    try:
        self.input_spec.unlock_()
        try:
            delattr(self, "_state_keys")
        except AttributeError:
            pass
        if value is None:
            self.input_spec["full_state_spec"] = Composite(
                device=self.device, shape=self.batch_size
            )
        else:
            device = self.input_spec.device
            if not isinstance(value, Composite):
                raise TypeError("The type of an state_spec must be Composite.")
            # single batch-size check (the previous version duplicated it)
            if value.shape[: len(self.batch_size)] != self.batch_size:
                raise ValueError(
                    f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})."
                )
            self.input_spec["full_state_spec"] = (
                value.to(device) if device is not None else value
            )
    finally:
        self.input_spec.lock_()
@property
def full_state_spec(self) -> Composite:
    """Alias of :attr:`state_spec`.

    This is the :class:`~torchrl.data.Composite` holding every state entry,
    i.e. the input data that is not an action.

    Examples:
        >>> from torchrl.envs import BraxEnv
        >>> for envname in BraxEnv.available_envs:
        ...     break
        >>> env = BraxEnv(envname)
        >>> env.full_state_spec  # Composite(state: Composite(pipeline_state: ...))
    """
    spec = self.state_spec
    return spec
@full_state_spec.setter
def full_state_spec(self, spec: Composite) -> None:
    """Alias setter: delegates to the :attr:`state_spec` setter."""
    setattr(self, "state_spec", spec)
# Single-env specs can be used to remove the batch size from the spec
@property
def batch_dims(self) -> int:
    """Number of batch dimensions of the env (``len(self.batch_size)``)."""
    num_dims = len(self.batch_size)
    return num_dims
def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
    """Index ``spec`` at 0 along each env batch dim; return it unchanged if there are none."""
    n_dims = self.batch_dims
    if n_dims:
        return spec[(0,) * n_dims]
    return spec
@property
def single_full_action_spec(self) -> Composite:
    """Full action spec with the env batch dimensions indexed away."""
    spec = self.full_action_spec
    return self._make_single_env_spec(spec)
@property
def single_action_spec(self) -> TensorSpec:
    """Action spec with the env batch dimensions indexed away."""
    spec = self.action_spec
    return self._make_single_env_spec(spec)
@property
def single_full_observation_spec(self) -> Composite:
    """Returns the full observation spec of the env as if it had no batch dimensions."""
    # previously (wrongly) derived from ``full_action_spec``
    return self._make_single_env_spec(self.full_observation_spec)
@property
def single_observation_spec(self) -> Composite:
    """Observation spec with the env batch dimensions indexed away."""
    spec = self.observation_spec
    return self._make_single_env_spec(spec)
@property
def single_full_reward_spec(self) -> Composite:
    """Returns the full reward spec of the env as if it had no batch dimensions."""
    # previously (wrongly) derived from ``full_action_spec``
    return self._make_single_env_spec(self.full_reward_spec)
@property
def single_reward_spec(self) -> TensorSpec:
    """Reward spec with the env batch dimensions indexed away."""
    spec = self.reward_spec
    return self._make_single_env_spec(spec)
@property
def single_full_done_spec(self) -> Composite:
    """Returns the full done spec of the env as if it had no batch dimensions."""
    # previously (wrongly) derived from ``full_action_spec``
    return self._make_single_env_spec(self.full_done_spec)
@property
def single_done_spec(self) -> TensorSpec:
    """Done spec with the env batch dimensions indexed away."""
    spec = self.done_spec
    return self._make_single_env_spec(spec)
@property
def single_output_spec(self) -> Composite:
    """Output spec with the env batch dimensions indexed away."""
    spec = self.output_spec
    return self._make_single_env_spec(spec)
@property
def single_input_spec(self) -> Composite:
    """Input spec with the env batch dimensions indexed away."""
    spec = self.input_spec
    return self._make_single_env_spec(spec)
@property
def single_full_state_spec(self) -> Composite:
    """Full state spec with the env batch dimensions indexed away."""
    spec = self.full_state_spec
    return self._make_single_env_spec(spec)
@property
def single_state_spec(self) -> TensorSpec:
    """State spec with the env batch dimensions indexed away."""
    spec = self.state_spec
    return self._make_single_env_spec(spec)
def step(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Makes a step in the environment.

    Step accepts a single argument, tensordict, which usually carries an
    ``"action"`` key indicating the action to be taken. The computation is
    delegated to the private ``_step`` method (the one EnvBase subclasses
    re-write). Its output is post-processed by ``_step_proc_data`` and
    written under the ``"next"`` key.

    Args:
        tensordict (TensorDictBase): Tensordict containing the action to be taken.
            If it already contains a ``"next"`` entry, the leaves of that
            entry that ``_step`` did not produce are copied into the new
            ``"next"`` tensordict. For keys that ``_step`` did produce, the
            freshly computed values are kept.

    Returns:
        the input tensordict, with a ``"next"`` entry holding the resulting
        observations, done state and reward (+ others if needed).

        There is one exception. When the env is not batch-locked and the input
        carries a ``"_step"`` mask that is not all ``True``, only the masked
        sub-batch is stepped. In that case a new tensordict built with
        ``new_zeros`` at the input batch size is returned, with the stepped
        sub-batch written at the masked positions.
    """
    # sanity check
    self._assert_tensordict_shape(tensordict)
    partial_steps = None
    if not self.batch_locked:
        # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
        partial_steps = tensordict.get("_step", None)
        if partial_steps is not None:
            if partial_steps.all():
                # every env is stepped: nothing partial about it
                partial_steps = None
            else:
                tensordict_batch_size = tensordict.batch_size
                partial_steps = partial_steps.view(tensordict_batch_size)
                tensordict = tensordict[partial_steps]

    next_preset = tensordict.get("next", None)

    next_tensordict = self._step(tensordict)
    next_tensordict = self._step_proc_data(next_tensordict)
    if next_preset is not None:
        # tensordict could already have a "next" key
        # this could be done more efficiently by not excluding but just passing
        # the necessary keys
        next_tensordict.update(
            next_preset.exclude(*next_tensordict.keys(True, True))
        )
    tensordict.set("next", next_tensordict)
    if partial_steps is not None:
        result = tensordict.new_zeros(tensordict_batch_size)
        result[partial_steps] = tensordict
        return result
    return tensordict
@classmethod
def _complete_done(
    cls, done_spec: Composite, data: TensorDictBase
) -> TensorDictBase:
    """Completes the data structure at step time to put missing done keys.

    Walks ``done_spec`` alongside ``data``:

    - Sub-composites are completed recursively, on the matching sub-tensordict
      when it exists.
    - Leaves present in ``data`` are reshaped to
      ``(*leading_dim, *spec.shape)`` when their shape differs.
    - If fewer leaf values than top-level spec entries were found, the
      missing entries are completed:

      * a present ``"done"`` is copied to a missing ``"terminated"``
        (an error is raised if ``"truncated"`` is present in that case);
      * a present ``"terminated"`` is copied to a missing ``"done"``, OR-ed
        with ``"truncated"`` when that entry exists;
      * any other missing leaf is filled with ``spec.zero(leading_dim)``.

    Args:
        done_spec: the (sub-)composite of done specs to complete against.
        data: the (sub-)tensordict produced at step time; modified in place.

    Returns:
        ``data``, completed in place.

    Raises:
        RuntimeError: if ``"done"`` and ``"truncated"`` are present but
            ``"terminated"`` is missing and expected by the spec.
    """
    # by default, if a done key is missing, it is assumed that it is False
    # except in 2 cases: (1) there is a "done" but no "terminated" or (2)
    # there is a "terminated" but no "done".
    # leading_dim: dims of ``data`` that precede the spec's own dims
    if done_spec.ndim:
        leading_dim = data.shape[: -done_spec.ndim]
    else:
        leading_dim = data.shape
    # vals: leaf done values actually found in ``data`` (after reshaping)
    vals = {}
    # i ends as (number of top-level spec entries - 1); -1 if the spec is empty
    i = -1
    for i, (key, item) in enumerate(done_spec.items()):  # noqa: B007
        val = data.get(key, None)
        if isinstance(item, Composite):
            if val is not None:
                cls._complete_done(item, val)
            continue
        shape = (*leading_dim, *item.shape)
        if val is not None:
            if val.shape != shape:
                val = val.reshape(shape)
                data.set(key, val)
            vals[key] = val

    if len(vals) < i + 1:
        # complete missing dones: we only want to do that if we don't have enough done values
        data_keys = set(data.keys())
        done_spec_keys = set(done_spec.keys())
        for key, item in done_spec.items(False, True):
            val = vals.get(key, None)
            if (
                key == "done"
                and val is not None
                and "terminated" in done_spec_keys
                and "terminated" not in data_keys
            ):
                if "truncated" in data_keys:
                    raise RuntimeError(
                        "Cannot infer the value of terminated when only done and truncated are present."
                    )
                data.set("terminated", val)
                data_keys.add("terminated")
            elif (
                key == "terminated"
                and val is not None
                and "done" in done_spec_keys
                and "done" not in data_keys
            ):
                if "truncated" in data_keys:
                    # done = terminated OR truncated
                    val = val | data.get("truncated")
                data.set("done", val)
                data_keys.add("done")
            elif val is None and key not in data_keys:
                # we must keep this here: we only want to fill with 0s if we're sure
                # done should not be copied to terminated or terminated to done
                # in this case, just fill with 0s
                data.set(key, item.zero(leading_dim))
    return data
def _step_proc_data(self, next_tensordict_out):
    """Post-process the tensordict returned by ``_step``.

    - Each reward entry is reshaped (via ``view``) to
      ``(*leading_batch, *reward_spec.shape)`` when its shape differs.
    - Missing done entries are completed with ``_complete_done``.
    - If ``self.run_type_checks`` is set, the observations are type-checked
      against their specs, and the reward/done dtypes must be identical to
      the spec dtypes.

    Args:
        next_tensordict_out: the tensordict produced by ``_step``; modified in place.

    Returns:
        ``next_tensordict_out``.

    Raises:
        TypeError: on a reward or done dtype mismatch (only with ``run_type_checks``).
    """
    batch_size = self.batch_size
    dims = len(batch_size)
    # dims of the output that precede the env's own batch dims
    leading_batch_size = (
        next_tensordict_out.batch_size[:-dims]
        if dims
        else next_tensordict_out.shape
    )
    for reward_key in self.reward_keys:
        reward = next_tensordict_out.get(reward_key)
        expected_reward_shape = torch.Size(
            [
                *leading_batch_size,
                *self.output_spec["full_reward_spec"][reward_key].shape,
            ]
        )
        if reward.shape != expected_reward_shape:
            reward = reward.view(expected_reward_shape)
            next_tensordict_out.set(reward_key, reward)

    self._complete_done(self.full_done_spec, next_tensordict_out)

    if self.run_type_checks:
        for key, spec in self.observation_spec.items():
            obs = next_tensordict_out.get(key)
            spec.type_check(obs)

        for reward_key in self.reward_keys:
            expected_dtype = self.output_spec[
                unravel_key(("full_reward_spec", reward_key))
            ].dtype
            actual_dtype = next_tensordict_out.get(reward_key).dtype
            if actual_dtype is not expected_dtype:
                # report the dtype (previously the whole spec was printed)
                raise TypeError(
                    f"expected reward.dtype to be {expected_dtype} "
                    f"but got {actual_dtype}"
                )

        for done_key in self.done_keys:
            expected_dtype = self.output_spec["full_done_spec", done_key].dtype
            actual_dtype = next_tensordict_out.get(done_key).dtype
            if actual_dtype is not expected_dtype:
                raise TypeError(
                    f"expected done.dtype to be {expected_dtype} but got {actual_dtype}"
                )
    return next_tensordict_out
def _get_in_keys_to_exclude(self, tensordict):
    """Keys shared by ``input_spec`` (all keys) and ``tensordict`` (leaf keys).

    The list is computed on the first call and cached in ``self._cache_in_keys``.
    Later calls return the cached list regardless of ``tensordict``.
    """
    cached = self._cache_in_keys
    if cached is None:
        spec_keys = set(self.input_spec.keys(True))
        cached = list(spec_keys.intersection(tensordict.keys(True, True)))
        self._cache_in_keys = cached
    return cached
@classmethod
def register_gym(
    cls,
    id: str,
    *,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    backend: str | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Registers an environment in gym(nasium).

    This method is designed with the following scopes in mind:

    - Incorporate a TorchRL-first environment in a framework that uses Gym;
    - Incorporate another environment (eg, DeepMind Control, Brax, Jumanji, ...)
      in a framework that uses Gym.

    Args:
        id (str): the name of the environment. Should follow the
            `gym naming convention <https://www.gymlibrary.dev/content/environment_creation/#registering-envs>`_.

    Keyword Args:
        entry_point (callable, optional): the entry point to build the environment.
            If none is passed, the class on which this method is called is
            used as entry point. Typically, an explicit entry point is used to
            register an environment that does not necessarily inherit from the
            base being used:

                >>> from torchrl.envs import DMControlEnv
                >>> DMControlEnv.register_gym("DMC-cheetah-v0", env_name="cheetah", task_name="run")
                >>> # equivalently
                >>> EnvBase.register_gym("DMC-cheetah-v0", entry_point=DMControlEnv, env_name="cheetah", task_name="run")

        transform (torchrl.envs.Transform): a transform (or list of transforms
            within a :class:`torchrl.envs.Compose` instance) to be used with the env.
            This arg can be passed during a call to :func:`~gym.make` (see
            example below).
        info_keys (List[NestedKey], optional): if provided, these keys will
            be used to build the info dictionary and will be excluded from
            the observation keys.
            This arg can be passed during a call to :func:`~gym.make` (see
            example below).

            .. warning::
              It may be the case that using ``info_keys`` makes a spec empty
              because the content has been moved to the info dictionary.
              Gym does not like empty ``Dict`` in the specs, so this empty
              content should be removed with :class:`~torchrl.envs.transforms.RemoveEmptySpecs`.

        backend (str, optional): the backend. Can be either `"gym"` or `"gymnasium"`
            or any other backend compatible with :class:`~torchrl.envs.libs.gym.set_gym_backend`.
            Defaults to the backend returned by ``gym_backend()``.
        to_numpy (bool, optional): if ``True``, the result of calls to `step` and
            `reset` will be mapped to numpy arrays. Defaults to ``False``
            (results are tensors).
            This arg can be passed during a call to :func:`~gym.make` (see
            example below).
        reward_threshold (:obj:`float`, optional): [Gym kwarg] The reward threshold
            considered to have learnt an environment.
        nondeterministic (bool, optional): [Gym kwarg] If the environment is nondeterministic
            (even with knowledge of the initial seed and all actions). Defaults to
            ``False``.
        max_episode_steps (int, optional): [Gym kwarg] The maximum number
            of episode steps before truncation. Used by the Time Limit wrapper.
        order_enforce (bool, optional): [Gym >= 0.14] Whether the order
            enforcer wrapper should be applied to ensure users run functions
            in the correct order.
            Defaults to ``True``.
        autoreset (bool, optional): [Gym >= 0.14] Whether the autoreset wrapper
            should be added such that reset does not need to be called.
            Defaults to ``False``.
        disable_env_checker: [Gym >= 0.14] Whether the environment
            checker should be disabled for the environment. Defaults to ``False``.
        apply_api_compatibility: [Gym >= 0.26] Whether to apply the `StepAPICompatibility` wrapper.
            Defaults to ``False``.
        **kwargs: arbitrary keyword arguments which are passed to the environment constructor.

    Keyword arguments that the installed gym version does not support raise a
    ``TypeError`` when set to a non-default value.

    .. note::
      TorchRL environments do not have the concept of an ``"info"`` dictionary,
      as ``TensorDict`` offers all the storage requirements deemed necessary
      in most training settings. Still, you can use the ``info_keys`` argument to
      have a fine grained control over what is deemed to be considered
      as an observation and what should be seen as info.

    Examples:
        >>> # Register the "cheetah" env from DMControl with the "run" task
        >>> from torchrl.envs import DMControlEnv
        >>> import torch
        >>> DMControlEnv.register_gym("DMC-cheetah-v0", to_numpy=False, backend="gym", env_name="cheetah", task_name="run")
        >>> import gym
        >>> envgym = gym.make("DMC-cheetah-v0")
        >>> envgym.seed(0)
        >>> torch.manual_seed(0)
        >>> envgym.reset()
        ({'position': tensor([-0.0855,  0.0215, -0.0881, -0.0412, -0.1101,  0.0080,  0.0254,  0.0424],
               dtype=torch.float64), 'velocity': tensor([ 1.9609e-02, -1.9776e-04, -1.6347e-03,  3.3842e-02,  2.5338e-02,
                 3.3064e-02,  1.0381e-04,  7.6656e-05,  1.0204e-02],
               dtype=torch.float64)}, {})
        >>> envgym.step(envgym.action_space.sample())
        ({'position': tensor([-0.0833,  0.0275, -0.0612, -0.0770, -0.1256,  0.0082,  0.0186,  0.0476],
               dtype=torch.float64), 'velocity': tensor([ 0.2221,  0.2256,  0.5930,  2.6937, -3.5865, -1.5479,  0.0187, -0.6825,
                 0.5224], dtype=torch.float64)}, tensor([0.0018], dtype=torch.float64), tensor([False]), tensor([False]), {})
        >>> # same environment with observation stacked
        >>> from torchrl.envs import CatTensors
        >>> envgym = gym.make("DMC-cheetah-v0", transform=CatTensors(in_keys=["position", "velocity"], out_key="observation"))
        >>> envgym.reset()
        ({'observation': tensor([-0.1005,  0.0335, -0.0268,  0.0133, -0.0627,  0.0074, -0.0488, -0.0353,
                -0.0075, -0.0069,  0.0098, -0.0058,  0.0033, -0.0157, -0.0004, -0.0381,
                -0.0452], dtype=torch.float64)}, {})
        >>> # same environment with numpy observations
        >>> envgym = gym.make("DMC-cheetah-v0", transform=CatTensors(in_keys=["position", "velocity"], out_key="observation"), to_numpy=True)
        >>> envgym.reset()
        ({'observation': array([-0.11355747,  0.04257728,  0.00408397,  0.04155852, -0.0389733 ,
               -0.01409826, -0.0978704 , -0.08808327,  0.03970837,  0.00535434,
               -0.02353762,  0.05116226,  0.02788907,  0.06848346,  0.05154399,
                0.0371798 ,  0.05128025])}, {})
        >>> # If gymnasium is installed, we can register the environment there too.
        >>> DMControlEnv.register_gym("DMC-cheetah-v0", to_numpy=False, backend="gymnasium", env_name="cheetah", task_name="run")
        >>> import gymnasium
        >>> envgym = gymnasium.make("DMC-cheetah-v0")
        >>> envgym.seed(0)
        >>> torch.manual_seed(0)
        >>> envgym.reset()
        ({'position': tensor([-0.0855,  0.0215, -0.0881, -0.0412, -0.1101,  0.0080,  0.0254,  0.0424],
               dtype=torch.float64), 'velocity': tensor([ 1.9609e-02, -1.9776e-04, -1.6347e-03,  3.3842e-02,  2.5338e-02,
                 3.3064e-02,  1.0381e-04,  7.6656e-05,  1.0204e-02],
               dtype=torch.float64)}, {})

    .. note::
      This feature also works for stateless environments (eg, :class:`~torchrl.envs.BraxEnv`).

        >>> import gymnasium
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.envs import BraxEnv, SelectTransform
        >>>
        >>> # get action for didactic purposes
        >>> env = BraxEnv("ant", batch_size=[2])
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> td = env.rollout(10)
        >>>
        >>> actions = td.get("action")
        >>>
        >>> # register env
        >>> env.register_gym("Brax-Ant-v0", env_name="ant", batch_size=[2], info_keys=["state"])
        >>> gym_env = gymnasium.make("Brax-Ant-v0")
        >>> gym_env.seed(0)
        >>> torch.manual_seed(0)
        >>>
        >>> gym_env.reset()
        >>> obs = []
        >>> for i in range(10):
        ...     obs, reward, terminated, truncated, info = gym_env.step(td[..., i].get("action"))
    """
    from torchrl.envs.libs.gym import gym_backend, set_gym_backend

    if backend is None:
        # fall back on the backend currently selected in torchrl
        backend = gym_backend()
    with set_gym_backend(backend):
        # the ``_register_gym`` implementation is selected via ``implement_for``
        # according to the active backend and its version
        return cls._register_gym(
            id=id,
            entry_point=entry_point,
            transform=transform,
            info_keys=info_keys,
            to_numpy=to_numpy,
            reward_threshold=reward_threshold,
            nondeterministic=nondeterministic,
            max_episode_steps=max_episode_steps,
            order_enforce=order_enforce,
            autoreset=autoreset,
            disable_env_checker=disable_env_checker,
            apply_api_compatibility=apply_api_compatibility,
            **kwargs,
        )
# Message template (kwarg name, gym version) for the TypeError raised when a
# registration keyword is not supported by the installed gym version.
_GYM_UNRECOGNIZED_KWARG = "The keyword argument {} is not compatible with gym version {}"
@implement_for("gym", "0.26", None, class_method=True)
def _register_gym(
    cls,
    id,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register with gym >= 0.26; every registration keyword is forwarded to ``gym.register``."""
    import gym

    from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

    wrapped_entry_point = functools.partial(
        _TorchRLGymWrapper,
        entry_point=cls if entry_point is None else entry_point,
        info_keys=info_keys,
        to_numpy=to_numpy,
        transform=transform,
        **kwargs,
    )
    return gym.register(
        id=id,
        entry_point=wrapped_entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
        disable_env_checker=disable_env_checker,
        apply_api_compatibility=apply_api_compatibility,
    )
@implement_for("gym", "0.25", "0.26", class_method=True)
def _register_gym(  # noqa: F811
    cls,
    id,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register with gym [0.25, 0.26).

    ``apply_api_compatibility`` is not forwarded to ``gym.register`` here and
    raises ``TypeError`` unless it is left at ``False``.
    """
    import gym

    unsupported = (("apply_api_compatibility", apply_api_compatibility, False),)
    for kwarg_name, given, default in unsupported:
        if given is not default:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(kwarg_name, gym.__version__)
            )
    from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

    wrapped_entry_point = functools.partial(
        _TorchRLGymWrapper,
        entry_point=cls if entry_point is None else entry_point,
        info_keys=info_keys,
        to_numpy=to_numpy,
        transform=transform,
        **kwargs,
    )
    return gym.register(
        id=id,
        entry_point=wrapped_entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
        disable_env_checker=disable_env_checker,
    )
@implement_for("gym", "0.24", "0.25", class_method=True)
def _register_gym(  # noqa: F811
    cls,
    id,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register with gym [0.24, 0.25).

    ``apply_api_compatibility`` and ``disable_env_checker`` (checked in that
    order) are not forwarded to ``gym.register`` here and raise ``TypeError``
    unless they are left at ``False``.
    """
    import gym

    unsupported = (
        ("apply_api_compatibility", apply_api_compatibility, False),
        ("disable_env_checker", disable_env_checker, False),
    )
    for kwarg_name, given, default in unsupported:
        if given is not default:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(kwarg_name, gym.__version__)
            )
    from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

    wrapped_entry_point = functools.partial(
        _TorchRLGymWrapper,
        entry_point=cls if entry_point is None else entry_point,
        info_keys=info_keys,
        to_numpy=to_numpy,
        transform=transform,
        **kwargs,
    )
    return gym.register(
        id=id,
        entry_point=wrapped_entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
    )
@implement_for("gym", "0.21", "0.24", class_method=True)
def _register_gym(  # noqa: F811
    cls,
    id,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register with gym [0.21, 0.24).

    ``apply_api_compatibility``, ``disable_env_checker`` and ``autoreset``
    (checked in that order) are not forwarded to ``gym.register`` here and
    raise ``TypeError`` unless they are left at ``False``.
    """
    import gym

    unsupported = (
        ("apply_api_compatibility", apply_api_compatibility, False),
        ("disable_env_checker", disable_env_checker, False),
        ("autoreset", autoreset, False),
    )
    for kwarg_name, given, default in unsupported:
        if given is not default:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(kwarg_name, gym.__version__)
            )
    from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

    wrapped_entry_point = functools.partial(
        _TorchRLGymWrapper,
        entry_point=cls if entry_point is None else entry_point,
        info_keys=info_keys,
        to_numpy=to_numpy,
        transform=transform,
        **kwargs,
    )
    return gym.register(
        id=id,
        entry_point=wrapped_entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
    )
@implement_for("gym", None, "0.21", class_method=True)
def _register_gym(  # noqa: F811
    cls,
    id,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Register with gym < 0.21.

    Only ``reward_threshold``, ``nondeterministic`` and ``max_episode_steps``
    are forwarded to ``gym.register``. The following must keep their
    defaults, otherwise ``TypeError`` is raised; they are checked in this order:
    ``order_enforce`` (``True``), then ``disable_env_checker``, ``autoreset``
    and ``apply_api_compatibility`` (``False``).
    """
    import gym

    from torchrl.envs.libs._gym_utils import _TorchRLGymWrapper

    unsupported = (
        ("order_enforce", order_enforce, True),
        ("disable_env_checker", disable_env_checker, False),
        ("autoreset", autoreset, False),
        ("apply_api_compatibility", apply_api_compatibility, False),
    )
    for kwarg_name, given, default in unsupported:
        if given is not default:
            raise TypeError(
                cls._GYM_UNRECOGNIZED_KWARG.format(kwarg_name, gym.__version__)
            )
    wrapped_entry_point = functools.partial(
        _TorchRLGymWrapper,
        entry_point=cls if entry_point is None else entry_point,
        info_keys=info_keys,
        to_numpy=to_numpy,
        transform=transform,
        **kwargs,
    )
    return gym.register(
        id=id,
        entry_point=wrapped_entry_point,
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
    )
@implement_for("gymnasium", class_method=True)
def _register_gym(  # noqa: F811
    cls,
    id,
    entry_point: Callable | None = None,
    transform: "Transform" | None = None,  # noqa: F821
    info_keys: List[NestedKey] | None = None,
    to_numpy: bool = False,
    reward_threshold: float | None = None,
    nondeterministic: bool = False,
    max_episode_steps: int | None = None,
    order_enforce: bool = True,
    autoreset: bool = False,
    disable_env_checker: bool = False,
    apply_api_compatibility: bool = False,
    **kwargs,
):
    """Registers a TorchRL env in the gymnasium registry.

    The environment class (``entry_point``, or ``cls`` when omitted) is wrapped in
    ``_TorchRLGymnasiumWrapper`` together with ``info_keys``, ``to_numpy``,
    ``transform`` and any extra ``kwargs``. All registration options are
    forwarded unchanged to :func:`gymnasium.register`, whose result is returned.
    """
    import gymnasium
    from torchrl.envs.libs._gym_utils import _TorchRLGymnasiumWrapper

    wrapped_entry_point = functools.partial(
        _TorchRLGymnasiumWrapper,
        entry_point=cls if entry_point is None else entry_point,
        info_keys=info_keys,
        to_numpy=to_numpy,
        transform=transform,
        **kwargs,
    )
    registration_options = dict(
        reward_threshold=reward_threshold,
        nondeterministic=nondeterministic,
        max_episode_steps=max_episode_steps,
        order_enforce=order_enforce,
        autoreset=autoreset,
        disable_env_checker=disable_env_checker,
        apply_api_compatibility=apply_api_compatibility,
    )
    return gymnasium.register(
        id=id, entry_point=wrapped_entry_point, **registration_options
    )
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Not supported by ``EnvBase``.

    Raises:
        NotImplementedError: always.
    """
    message = "EnvBase.forward is not implemented"
    raise NotImplementedError(message)
@abc.abstractmethod
def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
raise NotImplementedError
@abc.abstractmethod
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
raise NotImplementedError
def reset(
    self,
    tensordict: Optional[TensorDictBase] = None,
    **kwargs,
) -> TensorDictBase:
    """Resets the environment.

    As for step and _step, only the private method :obj:`_reset` should be overwritten by EnvBase subclasses.

    Args:
        tensordict (TensorDictBase, optional): tensordict to be used to contain the resulting new observation.
            In some cases, this input can also be used to pass argument to the reset function.
        kwargs (optional): other arguments to be passed to the native
            reset function.

    Returns:
        a tensordict (or the input tensordict, if any), modified in place with the resulting observations.

    Raises:
        RuntimeError: if ``_reset`` hands back the very object it received, or
            returns something that is not a ``TensorDictBase``.
    """
    if tensordict is not None:
        self._assert_tensordict_shape(tensordict)

    reset_output = self._reset(tensordict, **kwargs)
    # Device placement of ``reset_output`` is trusted to ``_reset``; no cast here.

    # ``_reset`` must build a new container rather than write into its input.
    if reset_output is tensordict:
        raise RuntimeError(
            "EnvBase._reset should return outplace changes to the input "
            "tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty()) "
            "inside _reset before writing new tensors onto this new instance."
        )
    if not isinstance(reset_output, TensorDictBase):
        raise RuntimeError(
            f"env._reset returned an object of type {type(reset_output)} but a TensorDict was expected."
        )
    return self._reset_proc_data(tensordict, reset_output)
def _reset_proc_data(self, tensordict, tensordict_reset):
self._complete_done(self.full_done_spec, tensordict_reset)
self._reset_check_done(tensordict, tensordict_reset)
if tensordict is not None:
return _update_during_reset(tensordict_reset, tensordict, self.reset_keys)
return tensordict_reset
def _reset_check_done(self, tensordict, tensordict_reset):
    """Checks the done status after reset.

    If _reset signals were passed, we check that the env is not done for these
    indices.
    We also check that the input tensordict contained ``"done"``s if the
    reset is partial and incomplete.

    Args:
        tensordict: the tensordict passed to ``reset`` (may be ``None``). Its
            ``"_reset"`` entries are read as masks. If a partial reset leaves a
            done entry missing from it while some non-reset positions are
            done, that done entry is written into it in place.
        tensordict_reset: the tensordict returned by ``_reset``.

    Raises:
        RuntimeError: if ``self._allow_done_after_reset`` is falsy and a done
            entry of ``tensordict_reset`` is ``True`` at a position selected by
            the ``"_reset"`` mask, or anywhere when no mask was given.

    Warns:
        UserWarning: when a ``"_reset"`` mask is given but the matching done
            entry is absent from ``tensordict`` while some non-reset positions
            of ``tensordict_reset`` are done.
    """
    # we iterate over (reset_key, (done_key, truncated_key)) and check that all
    # values where reset was true now have a done set to False.
    # If no reset was present, all done and truncated must be False
    # NOTE(review): reset keys and done groups are paired positionally by zip;
    # reset_keys is depth-sorted, so this assumes done_keys_groups follows the
    # same order (i.e. done_keys is depth-sorted too) — confirm.
    for reset_key, done_key_group in zip(self.reset_keys, self.done_keys_groups):
        reset_value = (
            tensordict.get(reset_key, default=None)
            if tensordict is not None
            else None
        )
        if reset_value is not None:
            # Partial reset: only the positions flagged by the mask must be not-done.
            for done_key in done_key_group:
                done_val = tensordict_reset.get(done_key)
                if done_val[reset_value].any() and not self._allow_done_after_reset:
                    raise RuntimeError(
                        f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                    )
                # Positions that were NOT reset should keep the caller's done
                # values; if the caller did not provide them, warn and copy them
                # over from the reset output.
                if (
                    done_key not in tensordict.keys(True)
                    and done_val[~reset_value].any()
                ):
                    warnings.warn(
                        f"A partial `'_reset'` key has been passed to `reset` ({reset_key}), "
                        f"but the corresponding done_key ({done_key}) was not present in the input "
                        f"tensordict. "
                        f"This is discouraged, since the input tensordict should contain "
                        f"all the data not being reset."
                    )
                    # we set the done val to tensordict, to make sure that
                    # _update_during_reset does not pad the value
                    tensordict.set(done_key, done_val)
        elif not self._allow_done_after_reset:
            # Full reset (no mask): every done entry must be entirely False.
            for done_key in done_key_group:
                if tensordict_reset.get(done_key).any():
                    raise RuntimeError(
                        f"The done entry '{done_key}' was (partially) True after a call to reset() in env {self}."
                    )
def numel(self) -> int:
    """Returns the product of the entries of ``self.batch_size``."""
    batch_dims = self.batch_size
    return prod(batch_dims)
def set_seed(
    self, seed: Optional[int] = None, static_seed: bool = False
) -> Optional[int]:
    """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).

    Args:
        seed (int): seed to be set. The seed is set only locally in the environment. To handle the global seed,
            see :func:`~torch.manual_seed`.
        static_seed (bool, optional): if ``True``, the seed is not incremented.
            Defaults to ``False``.

    Returns:
        integer representing the "next seed": i.e. the seed that should be
        used for another environment if created concomitantly to this environment.
        ``seed`` itself is returned when it is ``None`` or ``static_seed`` is set.
    """
    self._set_seed(seed)
    # Nothing to advance: hand the seed back untouched.
    if seed is None or static_seed:
        return seed
    return seed_generator(seed)
@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError
def set_state(self):
    """Not implemented in ``EnvBase``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
if (
self.batch_locked or self.batch_size != ()
) and tensordict.batch_size != self.batch_size:
raise RuntimeError(
f"Expected a tensordict with shape==env.batch_size, "
f"got {tensordict.batch_size} and {self.batch_size}"
)
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
    """Performs a random action given the action_spec attribute.

    Args:
        tensordict (TensorDictBase, optional): tensordict where the resulting action should be written.

    Returns:
        a tensordict object with the "action" entry updated with a random
        sample from the action-spec. When ``tensordict`` is ``None`` the freshly
        sampled container is returned directly.

    Raises:
        RuntimeError: if the environment is not batch-locked, has a non-empty
            batch size and ``tensordict`` is given with a different batch size.
    """
    shape = torch.Size([])
    if not self.batch_locked:
        if not self.batch_size:
            # If we can't infer the batch-size from the env, take it from the
            # tensordict; without a tensordict we assume an empty batch size.
            if tensordict is not None:
                shape = tensordict.shape
        elif tensordict is not None and tensordict.shape != self.batch_size:
            # If a tensordict is provided and the env has a batch size, their
            # shapes must match. (Previously this branch was also reached with
            # tensordict=None and crashed on ``None.shape``.)
            raise RuntimeError(
                "The input tensordict and the env have a different batch size: "
                f"env.batch_size={self.batch_size} and tensordict.batch_size={tensordict.shape}. "
                f"Non batch-locked environment require the env batch-size to be either empty or to"
                f" match the tensordict one."
            )
    # We generate the action from the full_action_spec
    r = self.input_spec["full_action_spec"].rand(shape)
    if tensordict is None:
        return r
    tensordict.update(r)
    return tensordict
def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
    """Performs a random step in the environment given the action_spec attribute.

    A random action is drawn with :meth:`rand_action` (written into
    ``tensordict`` when one is given) and then fed to :meth:`step`.

    Args:
        tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.

    Returns:
        a tensordict object with the new observation after a random step in the environment. The action will
        be stored with the "action" key.
    """
    return self.step(self.rand_action(tensordict))
@property
def specs(self) -> Composite:
    """Returns a Composite container where all the environment specs are present.

    The container holds ``output_spec`` and ``input_spec`` under those names,
    has shape ``self.batch_size`` and is locked (``lock_()``) before being
    returned. This feature allows one to create an environment, retrieve all
    of the specs in a single data container and then erase the environment
    from the workspace.
    """
    bundle = Composite(
        output_spec=self.output_spec,
        input_spec=self.input_spec,
        shape=self.batch_size,
    )
    return bundle.lock_()
@property
def _has_dynamic_specs(self) -> bool:
    """Whether any of the env specs is dynamic, as judged by the module-level ``_has_dynamic_specs`` helper."""
    all_specs = self.specs
    return _has_dynamic_specs(all_specs)
[docs] def rollout(
self,
max_steps: int,
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
*,
auto_reset: bool = True,
auto_cast_to_device: bool = False,
break_when_any_done: bool | None = None,
break_when_all_done: bool | None = None,
return_contiguous: bool = True,
tensordict: Optional[TensorDictBase] = None,
set_truncated: bool = False,
out=None,
trust_policy: bool = False,
) -> TensorDictBase:
"""Executes a rollout in the environment.
The function will return as soon as any of the contained environments
reaches any of the done states.
Args:
max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
the environment reaches a done state before max_steps have been executed.
policy (callable, optional): callable to be called to compute the desired action.
If no policy is provided, actions will be called using :obj:`env.rand_step()`.
The policy can be any callable that reads either a tensordict or
the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``.
Defaults to `None`.
callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
responsibility to save any result within the callback call if data needs to be carried over beyond
the call to ``rollout``.
Keyword Args:
auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
rollout. If ``False``, then the rollout will continue from a previous state, which requires the
``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
policy device before the policy is used. Default is ``False``.
break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
of the done states. If ``False``, break if at least one environment reaches any of the done states.
Default is ``False``.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
environment in those dimensions (if needed).
This normally should not occur if ``tensordict`` is the output of a reset, but can occur
if ``tensordict`` is the last step of a previous rollout.
A ``tensordict`` can also be provided when ``auto_reset=True`` if metadata need to be passed
to the ``reset`` method, such as a batch-size or a device for stateless environments.
set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to
``True`` after completion of the rollout. If no ``"truncated"`` is found within the
``done_spec``, an exception is raised.
Truncated keys can be set through ``env.add_truncated_keys``.
Defaults to ``False``.
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
and ``False`` otherwise.
Returns:
TensorDict object containing the resulting trajectory.
The data returned will be marked with a "time" dimension name for the last
dimension of the tensordict (at the ``env.ndim`` index).
``rollout`` is quite handy to display what the data structure of the
environment looks like.
Examples:
>>> # Using rollout without a policy
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import TransformedEnv, StepCounter
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
>>> rollout = env.rollout(max_steps=1000)
>>> print(rollout)
TensorDict(
fields={
action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([20]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([20]),
device=cpu,
is_shared=False)
>>> print(rollout.names)
['time']
>>> # with envs that contain more dimensions
>>> from torchrl.envs import SerialEnv
>>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
>>> rollout = env.rollout(max_steps=1000)
>>> print(rollout)
TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 20]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 20]),
device=cpu,
is_shared=False)
>>> print(rollout.names)
[None, 'time']
Using a policy (a regular :class:`~torch.nn.Module` or a :class:`~tensordict.nn.TensorDictModule`)
is also easy:
Examples:
>>> from torch import nn
>>> env = GymEnv("CartPole-v1", categorical_action_encoding=True)
>>> class ArgMaxModule(nn.Module):
... def forward(self, values):
... return values.argmax(-1)
>>> n_obs = env.observation_spec["observation"].shape[-1]
>>> n_act = env.action_spec.n
>>> # A deterministic policy
>>> policy = nn.Sequential(
... nn.Linear(n_obs, n_act),
... ArgMaxModule())
>>> env.rollout(max_steps=10, policy=policy)
TensorDict(
fields={
action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
>>> # Under the hood, rollout will wrap the policy in a TensorDictModule
>>> # To speed things up we can do that ourselves
>>> from tensordict.nn import TensorDictModule
>>> policy = TensorDictModule(policy, in_keys=list(env.observation_spec.keys()), out_keys=["action"])
>>> env.rollout(max_steps=10, policy=policy)
TensorDict(
fields={
action: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([10]),
device=cpu,
is_shared=False)
In some instances, contiguous tensordict cannot be obtained because
they cannot be stacked. This can happen when the data returned at
each step may have a different shape, or when different environments
are executed together. In that case, ``return_contiguous=False``
will cause the returned tensordict to be a lazy stack of tensordicts:
Examples of non-contiguous rollout:
>>> rollout = env.rollout(4, return_contiguous=False)
>>> print(rollout)
LazyStackedTensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: LazyStackedTensorDict(
fields={
done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=cpu,
is_shared=False)
>>> print(rollout.names)
[None, 'time']
Rollouts can be used in a loop to emulate data collection.
To do so, you need to pass as input the last tensordict coming from the previous rollout after calling
:func:`~torchrl.envs.utils.step_mdp` on it.
Examples of data collection rollouts:
>>> from torchrl.envs import GymEnv, step_mdp
>>> env = GymEnv("CartPole-v1")
>>> epochs = 10
>>> input_td = env.reset()
>>> for i in range(epochs):
... rollout_td = env.rollout(
... max_steps=100,
... break_when_any_done=False,
... auto_reset=False,
... tensordict=input_td,
... )
... input_td = step_mdp(
... rollout_td[..., -1],
... )
"""
if break_when_any_done is None: # True by default
if break_when_all_done: # all overrides
break_when_any_done = False
else:
break_when_any_done = True
if break_when_all_done is None:
# There is no case where break_when_all_done is True by default
break_when_all_done = False
if break_when_all_done and break_when_any_done:
raise TypeError(
"Cannot have both break_when_all_done and break_when_any_done True at the same time."
)
if policy is not None:
policy = _make_compatible_policy(
policy,
self.observation_spec,
env=self,
fast_wrap=True,
trust_policy=trust_policy,
)
if auto_cast_to_device:
try:
policy_device = next(policy.parameters()).device
except (StopIteration, AttributeError):
policy_device = None
else:
policy_device = None
else:
policy = self.rand_action
policy_device = None
env_device = self.device
if auto_reset:
tensordict = self.reset(tensordict)
elif tensordict is None:
raise RuntimeError("tensordict must be provided when auto_reset is False")
else:
tensordict = self.maybe_reset(tensordict)
kwargs = {
"tensordict": tensordict,
"auto_cast_to_device": auto_cast_to_device,
"max_steps": max_steps,
"policy": policy,
"policy_device": policy_device,
"env_device": env_device,
"callback": callback,
}
if break_when_any_done or break_when_all_done:
tensordicts = self._rollout_stop_early(
break_when_all_done=break_when_all_done,
break_when_any_done=break_when_any_done,
**kwargs,
)
else:
tensordicts = self._rollout_nonstop(**kwargs)
batch_size = self.batch_size if tensordict is None else tensordict.batch_size
if return_contiguous:
try:
out_td = torch.stack(tensordicts, len(batch_size), out=out)
except RuntimeError as err:
if (
"The shapes of the tensors to stack is incompatible" in str(err)
and self._has_dynamic_specs
):
raise RuntimeError(
"The environment specs are dynamic. Call rollout with return_contiguous=False."
)
raise
else:
out_td = LazyStackedTensorDict.maybe_dense_stack(
tensordicts, len(batch_size), out=out
)
if set_truncated:
found_truncated = False
for key in self.done_keys:
if _ends_with(key, "truncated"):
val = out_td.get(("next", key))
done = out_td.get(("next", _replace_last(key, "done")))
val[(slice(None),) * (out_td.ndim - 1) + (-1,)] = True
out_td.set(("next", key), val)
out_td.set(("next", _replace_last(key, "done")), val | done)
found_truncated = True
if not found_truncated:
raise RuntimeError(
"set_truncated was set to True but no truncated key could be found. "
"Make sure a 'truncated' entry was set in the environment "
"full_done_keys using `env.add_truncated_keys()`."
)
out_td.refine_names(..., "time")
return out_td
def add_truncated_keys(self) -> EnvBase:
    """Adds truncated keys to the environment.

    For every done key ``(*prefix, name)`` a ``(*prefix, "truncated")`` entry is
    written to ``full_done_spec``, using a clone of that done key's spec.

    The cached key lists derived from the done spec are then dropped so that
    they are rebuilt on next access. These are the done keys, the grouped done
    keys used to check resets, and the cached ``_StepMDP`` helper.

    Returns:
        the environment itself, to allow chaining.
    """
    for key in self.done_keys:
        # Clone so the new entry does not alias the spec object it was copied from.
        self.full_done_spec[_replace_last(key, "truncated")] = self.full_done_spec[
            key
        ].clone()
    # Only clearing ``_done_keys`` left ``done_keys_groups`` and the step-MDP
    # helper computed from the old key set, missing the new truncated keys.
    self.__dict__["_done_keys"] = None
    self.__dict__["_done_keys_group"] = None
    self.__dict__["_step_mdp_value"] = None
    return self
@property
def _step_mdp(self):
step_func = self.__dict__.get("_step_mdp_value")
if step_func is None:
step_func = _StepMDP(self, exclude_action=False)
self.__dict__["_step_mdp_value"] = step_func
return step_func
def _rollout_stop_early(
    self,
    *,
    break_when_any_done,
    break_when_all_done,
    tensordict,
    auto_cast_to_device,
    max_steps,
    policy,
    policy_device,
    env_device,
    callback,
):
    """Rollout loop used by :meth:`rollout` when it may stop before ``max_steps``.

    Each iteration optionally moves the data to the policy device, runs
    ``policy``, optionally moves it back to the env device, calls
    :meth:`step`, and appends ``tensordict.copy()`` of the step output.

    * With ``break_when_any_done`` the loop stops right after the first step in
      which any done entry is set.
    * Otherwise (``break_when_all_done``) the environments that are done are
      tracked in a ``"_step"`` mask. For finished environments the previously
      collected entry is repeated, and the loop stops once no environment is
      still running.

    The loop never goes past ``max_steps`` iterations, and ``callback(self,
    tensordict)`` is not called after the step that ends the loop.

    Returns:
        list of per-step tensordicts, to be stacked by the caller.
    """
    # Get the sync func
    if auto_cast_to_device:
        sync_func = _get_sync_func(policy_device, env_device)
    tensordicts = []
    # ``True`` until the first partial step; afterwards a boolean mask that is
    # True for environments that are still running.
    partial_steps = True
    for i in range(max_steps):
        if auto_cast_to_device:
            if policy_device is not None:
                tensordict = tensordict.to(policy_device, non_blocking=True)
                sync_func()
            else:
                tensordict.clear_device_()
        tensordict = policy(tensordict)
        if auto_cast_to_device:
            if env_device is not None:
                tensordict = tensordict.to(env_device, non_blocking=True)
                sync_func()
            else:
                tensordict.clear_device_()
        tensordict = self.step(tensordict)
        td_append = tensordict.copy()
        if break_when_all_done:
            if partial_steps is not True:
                # At least one partial step has been done
                # For finished environments, repeat the last collected entry
                # instead of the new (stale) step output.
                del td_append["_step"]
                td_append = torch.where(
                    partial_steps.view(td_append.shape), td_append, tensordicts[-1]
                )
        tensordicts.append(td_append)
        if i == max_steps - 1:
            # we don't truncate as one could potentially continue the run
            break
        tensordict = self._step_mdp(tensordict)
        if break_when_any_done:
            # done and truncated are in done_keys
            # We read if any key is done.
            any_done = _terminated_or_truncated(
                tensordict,
                full_done_spec=self.output_spec["full_done_spec"],
                key=None,
            )
            if any_done:
                break
        else:
            # NOTE(review): presumably writes a "_step" entry flagging done
            # environments (only when something is done, given
            # write_full_false=False) — confirm against _terminated_or_truncated.
            _terminated_or_truncated(
                tensordict,
                full_done_spec=self.output_spec["full_done_spec"],
                key="_step",
                write_full_false=False,
            )
            partial_step_curr = tensordict.get("_step", None)
            if partial_step_curr is not None:
                # Invert "is done" into "still running" and accumulate.
                partial_step_curr = ~partial_step_curr
                partial_steps = partial_steps & partial_step_curr
            if partial_steps is not True:
                if not partial_steps.any():
                    # Every environment has finished.
                    break
                tensordict.set("_step", partial_steps)
        if callback is not None:
            callback(self, tensordict)
    return tensordicts
def _rollout_nonstop(
self,
*,
tensordict,
auto_cast_to_device,
max_steps,
policy,
policy_device,
env_device,
callback,
):
if auto_cast_to_device:
sync_func = _get_sync_func(policy_device, env_device)
tensordicts = []
tensordict_ = tensordict
for i in range(max_steps):
if auto_cast_to_device:
if policy_device is not None:
tensordict_ = tensordict_.to(policy_device, non_blocking=True)
sync_func()
else:
tensordict_.clear_device_()
tensordict_ = policy(tensordict_)
if auto_cast_to_device:
if env_device is not None:
tensordict_ = tensordict_.to(env_device, non_blocking=True)
sync_func()
else:
tensordict_.clear_device_()
if i == max_steps - 1:
tensordict = self.step(tensordict_)
else:
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
tensordicts.append(tensordict)
if i == max_steps - 1:
# we don't truncate as one could potentially continue the run
break
if callback is not None:
callback(self, tensordict)
return tensordicts
def step_and_maybe_reset(
    self, tensordict: TensorDictBase
) -> Tuple[TensorDictBase, TensorDictBase]:
    """Runs a step in the environment and (partially) resets it if needed.

    Args:
        tensordict (TensorDictBase): an input data structure for the :meth:`~.step`
            method. It is moved to ``self.device`` first if its device differs.

    Returns:
        a ``(stepped, next_input)`` pair. ``stepped`` is the output of
        :meth:`step`. ``next_input`` is obtained by applying the step-MDP
        transition to ``stepped`` and then :meth:`maybe_reset`, so environments
        that are done have been reset.

    This method allows to easily code non-stopping rollout functions.

    Examples:
        >>> from torchrl.envs import ParallelEnv, GymEnv
        >>> def rollout(env, n):
        ...     data_ = env.reset()
        ...     result = []
        ...     for i in range(n):
        ...         data, data_ = env.step_and_maybe_reset(data_)
        ...         result.append(data)
        ...     return torch.stack(result)
        >>> env = ParallelEnv(2, lambda: GymEnv("CartPole-v1"))
        >>> print(rollout(env, 2))
        TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([2, 2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([2, 2]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([2, 2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2, 2]),
            device=cpu,
            is_shared=False)
    """
    if tensordict.device != self.device:
        tensordict = tensordict.to(self.device)
    stepped = self.step(tensordict)
    # Done/truncated entries are part of done_keys; maybe_reset reads them.
    next_input = self.maybe_reset(self._step_mdp(stepped))
    return stepped, next_input
@property
def _simple_done(self):
_simple_done = self.__dict__.get("_simple_done_value")
if _simple_done is None:
key_set = set(self.full_done_spec.keys())
_simple_done = key_set == {
"done",
"truncated",
"terminated",
} or key_set == {"done", "terminated"}
self.__dict__["_simple_done_value"] = _simple_done
return _simple_done
def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
    """Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.

    Args:
        tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`.

    Returns:
        A tensordict that is identical to the input where the environment was
        not reset and contains the new reset data where the environment was reset.

    For environments with a simple done spec (see ``_simple_done``), a
    ``"_reset"`` entry equal to the ``"done"`` mask is written before calling
    :meth:`reset`. If the input tensordict carries no ``"done"`` entry at all,
    nothing is considered done and the input is returned unchanged.
    """
    if self._simple_done:
        done = tensordict._get_str("done", default=None)
        # A tensordict without "done" (e.g. user-provided input) has nothing to
        # reset; previously this crashed on ``None.any()``.
        any_done = done is not None and done.any()
        if any_done:
            tensordict._set_str(
                "_reset",
                done.clone(),
                validated=True,
                inplace=False,
                non_blocking=False,
            )
    else:
        any_done = _terminated_or_truncated(
            tensordict,
            full_done_spec=self.output_spec["full_done_spec"],
            key="_reset",
        )
    if any_done:
        tensordict = self.reset(tensordict)
    return tensordict
def empty_cache(self):
    """Erases all the cached values.

    For regular envs, the key lists (reward, done etc) are cached, but in some cases
    they may change during the execution of the code (eg, when adding a transform).
    Each cache slot below is reset to ``None`` in ``self.__dict__``, which makes
    the corresponding lazy property recompute it on next access.
    """
    for cache_name in (
        "_step_mdp_value",
        "_reward_keys",
        "_done_keys",
        "_action_keys",
        "_state_keys",
        "_done_keys_group",
        # Both are lazily cached with the same None-sentinel pattern and derive
        # from the done spec, so they must be invalidated alongside the others.
        "_reset_keys",
        "_simple_done_value",
    ):
        self.__dict__[cache_name] = None
@property
def reset_keys(self) -> List[NestedKey]:
    """Returns a list of reset keys.

    Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
    settings. They are structured as ``(*prefix, "_reset")`` where ``prefix`` is
    a (possibly empty) tuple of strings pointing to a tensordict location
    where a done state can be found.

    One reset key is derived from the first key of each entry of
    ``done_keys_groups``. Keys are sorted by depth in the data tree, and the
    result is cached in ``self.__dict__["_reset_keys"]``.
    """
    cached = self.__dict__.get("_reset_keys")
    if cached is not None:
        return cached
    candidates = [
        _replace_last(leading_key, "_reset")
        for leading_key, *_ in self.done_keys_groups
    ]
    cached = sorted(candidates, key=_repr_by_depth)
    self.__dict__["_reset_keys"] = cached
    return cached
@property
def _filtered_reset_keys(self):
"""Returns only the effective reset keys, discarding nested resets if they're not being used."""
reset_keys = self.reset_keys
result = []
def _root(key):
if isinstance(key, str):
return ()
return key[:-1]
roots = []
for reset_key in reset_keys:
cur_root = _root(reset_key)
for root in roots:
if cur_root[: len(root)] == root:
break
else:
roots.append(cur_root)
result.append(reset_key)
return result
@property
def done_keys_groups(self):
    """Done keys grouped by their parent location, in the same layout as reset keys.

    Returns a list of lists. Each outer entry corresponds to one distinct parent
    path (root) of a done key, in order of first appearance in ``self.done_keys``.
    The inner list holds every non-nested leaf key of ``full_done_spec`` at that
    root (eg, done and truncated). These can be read to determine a reset when
    no reset entry is present. The result is cached in
    ``self.__dict__["_done_keys_group"]``.
    """
    cached = self.__dict__.get("_done_keys_group")
    if cached is not None:
        return cached
    full_done_spec = self.full_done_spec
    groups = []
    seen_roots = set()
    for done_key in self.done_keys:
        prefix = done_key[:-1] if isinstance(done_key, tuple) else ()
        if prefix in seen_roots:
            continue
        seen_roots.add(prefix)
        sub_spec = full_done_spec[prefix] if prefix else full_done_spec
        groups.append(
            [
                unravel_key(prefix + (leaf_name,))
                for leaf_name in sub_spec.keys(include_nested=False, leaves_only=True)
            ]
        )
    self.__dict__["_done_keys_group"] = groups
    return groups
def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]:
    """Lazily yields the top-level keys of ``tensordict`` whose name contains ``"observation"``."""
    yield from (
        name for name in tensordict.keys() if name.find("observation") != -1
    )
def close(self):
    """Marks the environment as closed by setting ``is_closed`` to ``True``."""
    setattr(self, "is_closed", True)
def __del__(self):
    # Best-effort cleanup when the object is garbage collected.
    # ``is_closed`` is looked up in ``__dict__`` directly: if deletion happens
    # before the env has been set up, regular attribute access could trigger a
    # recursion error.
    if "is_closed" not in self.__dict__ or self.is_closed:
        return
    try:
        self.close()
    except Exception:
        # A TypeError is typically raised when the env is deleted as the program
        # ends, and future changes to ``close`` may change the error type. We
        # explicitly assume that any error raised while closing inside
        # ``__del__`` does not affect the program.
        pass
def to(self, device: DEVICE_TYPING) -> EnvBase:
    """Moves the environment, including its input and output specs, to ``device``.

    The target is normalized with ``_make_ordinal_device``. If it equals the
    current device, ``self`` is returned untouched. Otherwise both specs are
    moved and locked, ``_device`` is updated, and the parent class's ``to`` is
    called with the new device.
    """
    target = _make_ordinal_device(torch.device(device))
    if target != self.device:
        env_dict = self.__dict__
        env_dict["_input_spec"] = self.input_spec.to(target).lock_()
        env_dict["_output_spec"] = self.output_spec.to(target).lock_()
        self._device = target
        return super().to(target)
    return self
def fake_tensordict(self) -> TensorDictBase:
    """Builds a zero-filled tensordict shaped like one step of an environment rollout.

    The root holds zeroed state, action, observation and done entries. The
    ``"next"`` entry holds zeroed observation, reward and done entries. Where
    input and output keys collide, the observation values are written last and
    therefore prevail. The result gets the environment's batch size and is
    moved to its device.
    """
    state_spec = self.state_spec
    observation_spec = self.observation_spec
    full_action_spec = self.input_spec["full_action_spec"]
    # reading ``reward_spec`` instantiates it if needed
    _ = self.reward_spec
    full_reward_spec = self.output_spec["full_reward_spec"]
    full_done_spec = self.output_spec["full_done_spec"]

    zero_obs = observation_spec.zero()
    zero_reward = full_reward_spec.zero()
    zero_done = full_done_spec.zero()
    zero_state = state_spec.zero()
    zero_action = full_action_spec.zero()

    action_is_lazy = any(
        isinstance(value, LazyStackedTensorDict)
        for value in zero_action.values(True)
    )
    # When the action contains lazy stacks, the action tensordict is used as the
    # container (presumably to keep that lazy structure); otherwise the state is.
    if action_is_lazy:
        root = zero_action.update(zero_state)
    else:
        root = zero_state.update(zero_action)
    # Inputs first, then observations, so outputs win on key collisions.
    root = root.update(zero_obs)

    next_td = zero_obs.clone()
    next_td.update(zero_reward)
    next_td.update(zero_done)
    root.update(zero_done.clone())
    if "next" in root.keys():
        root.get("next").update(next_td)
    else:
        root.set("next", next_td)
    root.batch_size = self.batch_size
    return root.to(self.device)
class _EnvWrapper(EnvBase):
    """Abstract environment wrapper class.

    Unlike EnvBase, _EnvWrapper comes with a :obj:`_build_env` private method that will be called upon instantiation.
    Interfaces with other libraries should be coded using _EnvWrapper.

    It is possible to directly query attributes from the nested environment if their name does not conflict with
    an attribute of the wrapper:
        >>> env = SomeWrapper(...)
        >>> custom_attribute0 = env._env.custom_attribute
        >>> custom_attribute1 = env.custom_attribute
        >>> assert custom_attribute0 is custom_attribute1  # should return True

    """

    git_url: str = ""
    available_envs: Dict[str, Any] = {}
    libname: str = ""

    def __init__(
        self,
        *args,
        device: DEVICE_TYPING = None,
        batch_size: Optional[torch.Size] = None,
        allow_done_after_reset: bool = False,
        **kwargs,
    ):
        """Builds the wrapper, then the wrapped env, its specs, and finally initializes it.

        Args:
            device: forwarded to ``EnvBase.__init__``.
            batch_size: forwarded to ``EnvBase.__init__``.
            allow_done_after_reset: forwarded to ``EnvBase.__init__``.
            **kwargs: ``frame_skip`` (default 1) and ``convert_actions_to_numpy``
                (default True) are consumed here. The remaining keywords are
                validated by ``_check_kwargs`` and passed to ``_build_env``.

        Raises:
            ValueError: if positional arguments are passed.
        """
        super().__init__(
            device=device,
            batch_size=batch_size,
            allow_done_after_reset=allow_done_after_reset,
        )
        if len(args):
            raise ValueError(
                "`_EnvWrapper.__init__` received a non-empty args list of arguments. "
                "Make sure only keywords arguments are used when calling `super().__init__`."
            )
        frame_skip = kwargs.pop("frame_skip", 1)
        self.frame_skip = frame_skip
        # this value can be changed if frame_skip is passed during env construction
        self.wrapper_frame_skip = frame_skip
        # NOTE: ``_constructor_kwargs`` is the same dict object as ``kwargs``, so
        # the "convert_actions_to_numpy" entry popped below disappears from it too.
        self._constructor_kwargs = kwargs
        self._check_kwargs(kwargs)
        self._convert_actions_to_numpy = kwargs.pop("convert_actions_to_numpy", True)
        self._env = self._build_env(**kwargs)  # writes the self._env attribute
        # builds the wrapper's specs from the wrapped env (implemented by subclasses)
        self._make_specs(self._env)
        self.is_closed = False
        self._init_env()  # runs all the steps to have a ready-to-use env

    def _sync_device(self):
        """Returns (and caches in ``_sync_device_val``) the device synchronization callable.

        - env on CUDA: no synchronization (``_do_nothing``).
        - otherwise, if CUDA is available: ``torch.cuda.synchronize``.
        - otherwise, if MPS is available: ``torch.mps.synchronize``.
        - otherwise: ``_do_nothing``.
        """
        sync_func = self.__dict__.get("_sync_device_val")
        if sync_func is not None:
            return sync_func
        device = self.device
        if device.type == "cuda":
            sync_func = _do_nothing
        elif torch.cuda.is_available():
            sync_func = torch.cuda.synchronize
        elif torch.backends.mps.is_available():
            sync_func = torch.mps.synchronize
        else:
            sync_func = _do_nothing
        self._sync_device_val = sync_func
        # Return the callable itself (not the bound ``_sync_device`` method) so the
        # first call behaves like every later, cached call.
        return sync_func

    @abc.abstractmethod
    def _check_kwargs(self, kwargs: Dict):
        """Validates the constructor keyword arguments (implemented by subclasses)."""
        raise NotImplementedError

    def __getattr__(self, attr: str) -> Any:
        """Falls back to the wrapped env for attributes the wrapper does not define.

        Raises:
            AttributeError: for dunder names, or when no env is set and the parent
                lookup fails.
        """
        if attr in self.__dir__():
            # make sure that appropriate exceptions are raised (e.g. an
            # AttributeError raised inside a property getter surfaces as-is)
            return self.__getattribute__(attr)
        if attr.startswith("__"):
            raise AttributeError(
                "passing built-in private methods is "
                f"not permitted with type {type(self)}. "
                f"Got attribute {attr}."
            )
        if "_env" in self.__dir__():
            env = self.__getattribute__("_env")
            return getattr(env, attr)
        try:
            # The parent lookup may find the attribute (e.g. registered
            # submodules); its value must be returned, not discarded.
            return super().__getattr__(attr)
        except AttributeError as err:
            raise AttributeError(
                f"env not set in {self.__class__.__name__}, cannot access {attr}"
            ) from err

    @abc.abstractmethod
    def _init_env(self) -> Optional[int]:
        """Runs all the necessary steps such that the environment is ready to use.

        This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
        is reset (if needed). For instance, DMControl envs require the env to be reset before being used, but Gym envs
        don't.

        Returns:
            the resulting seed

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
        """Creates an environment from the target library and stores it with the `_env` attribute.

        When overwritten, this function should pass all the required kwargs to the env instantiation method.

        """
        raise NotImplementedError

    @abc.abstractmethod
    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
        """Builds the wrapper's specs from the wrapped ``env`` (implemented by subclasses)."""
        raise NotImplementedError

    def close(self) -> None:
        """Closes the contained environment if possible."""
        self.is_closed = True
        try:
            self._env.close()
        except AttributeError:
            # no wrapped env, or the wrapped env has no ``close`` method
            pass
def make_tensordict(
    env: _EnvWrapper,
    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
) -> TensorDictBase:
    """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.

    Under ``torch.no_grad()`` the env is reset. An action is then obtained from
    ``policy``, or, when no policy is given, a random sample of
    ``env.action_spec`` is written under ``"action"``. The env is stepped, and
    the resulting tensordict is zeroed in place and returned.

    Args:
        env (_EnvWrapper): environment defining the observation, action and reward space;
        policy (Callable, optional): policy corresponding to the environment.
    """
    with torch.no_grad():
        td = env.reset()
        if policy is None:
            td.set("action", env.action_spec.rand(), inplace=False)
        else:
            td = policy(td)
        td = env.step(td)
        return td.zero_()
def _get_sync_func(policy_device, env_device):
    """Chooses the callable used to synchronize devices between policy and env.

    Without CUDA, returns ``torch.mps.synchronize`` if MPS is available, else
    ``_do_nothing``. With CUDA:

    - policy on CUDA: plain ``torch.cuda.synchronize`` when the env device is
      unset or also CUDA, otherwise a version bound to the policy device;
    - env on CUDA and a (non-CUDA) policy device given: bound to the env device;
    - any other case: plain ``torch.cuda.synchronize``.
    """
    if not torch.cuda.is_available():
        if torch.backends.mps.is_available():
            return torch.mps.synchronize
        return _do_nothing
    policy_on_cuda = policy_device is not None and policy_device.type == "cuda"
    if policy_on_cuda:
        if env_device is None or env_device.type == "cuda":
            return torch.cuda.synchronize
        return functools.partial(torch.cuda.synchronize, device=policy_device)
    env_on_cuda = env_device is not None and env_device.type == "cuda"
    if env_on_cuda and policy_device is not None:
        return functools.partial(torch.cuda.synchronize, device=env_device)
    return torch.cuda.synchronize
def _do_nothing():
    """No-op stand-in for a device synchronization function; returns ``None``."""
    return None
def _has_dynamic_specs(spec: Composite):
    """Returns ``True`` if any spec yielded by ``spec.values(True, True, ...)`` has a ``-1`` dimension.

    The traversal uses ``is_leaf=_NESTED_TENSORS_AS_LISTS`` from ``tensordict.base``.
    Presumably this collects the nested leaf specs; a ``-1`` entry in a shape
    marks a dynamic dimension.
    """
    from tensordict.base import _NESTED_TENSORS_AS_LISTS

    # ``leaf`` (rather than reusing ``spec``) avoids shadowing the argument.
    leaf_specs = spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
    return any(-1 in leaf.shape for leaf in leaf_specs)