
RemoveEmptySpecs

class torchrl.envs.transforms.RemoveEmptySpecs(in_keys: Sequence[NestedKey] = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None)[source]

Removes empty specs and content from an environment.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, \
...     DiscreteTensorSpec
>>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs
>>> from torchrl.envs.utils import check_env_specs
>>>
>>>
>>> class DummyEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.observation_spec = CompositeSpec(
...             observation=UnboundedContinuousTensorSpec((*self.batch_size, 3)),
...             other=CompositeSpec(
...                 another_other=CompositeSpec(shape=self.batch_size),
...                 shape=self.batch_size,
...             ),
...             shape=self.batch_size,
...         )
...         self.action_spec = UnboundedContinuousTensorSpec((*self.batch_size, 3))
...         self.done_spec = DiscreteTensorSpec(
...             2, (*self.batch_size, 1), dtype=torch.bool
...         )
...         self.full_done_spec["truncated"] = self.full_done_spec[
...             "terminated"].clone()
...         self.reward_spec = CompositeSpec(
...             reward=UnboundedContinuousTensorSpec((*self.batch_size, 1)),
...             other_reward=CompositeSpec(shape=self.batch_size),
...             shape=self.batch_size
...             )
...
...     def _reset(self, tensordict):
...         return self.observation_spec.rand().update(self.full_done_spec.zero())
...
...     def _step(self, tensordict):
...         return TensorDict(
...             {},
...             batch_size=[]
...         ).update(self.observation_spec.rand()).update(
...             self.full_done_spec.zero()
...             ).update(self.full_reward_spec.rand())
...
...     def _set_seed(self, seed):
...         return seed + 1
>>>
>>>
>>> base_env = DummyEnv()
>>> print(base_env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                other: TensorDict(
                    fields={
                        another_other: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([2]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                other_reward: TensorDict(
                    fields={
                    },
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(base_env)
>>> env = TransformedEnv(base_env, RemoveEmptySpecs())
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(env)
forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict, and for the selected keys, applies the transform.
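
For RemoveEmptySpecs, applying the transform prunes empty nested tensordicts from the input itself, so it can also be used on a standalone tensordict. A minimal sketch (not part of the original example, assuming the transform is used outside a TransformedEnv; the shown result is the expected behaviour):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs import RemoveEmptySpecs
>>> t = RemoveEmptySpecs()
>>> td = TensorDict(
...     {"observation": torch.randn(3), "other": TensorDict({}, batch_size=[])},
...     batch_size=[],
... )
>>> td = t(td)  # empty nested tensordicts are expected to be pruned
>>> "other" in td.keys()
False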

transform_input_spec(input_spec: TensorSpec) → TensorSpec[source]

Transforms the input spec such that the resulting spec matches transform mapping.

Parameters:

input_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
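
This method is called by TransformedEnv when it builds its specs; it rarely needs to be called directly. As a rough sketch of the effect on the input side (not from the original docs, assuming the DummyEnv above is given an extra empty entry inside its action spec and that the action_spec setter accepts such a composite):

>>> base_env2 = DummyEnv()
>>> base_env2.action_spec = CompositeSpec(
...     action=UnboundedContinuousTensorSpec((*base_env2.batch_size, 3)),
...     other_action=CompositeSpec(shape=base_env2.batch_size),  # empty sub-spec
...     shape=base_env2.batch_size,
... )
>>> env2 = TransformedEnv(base_env2, RemoveEmptySpecs())
>>> "other_action" in env2.full_action_spec.keys()  # expected to be pruned
False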

transform_output_spec(output_spec: CompositeSpec) → CompositeSpec[source]

Transforms the output spec such that the resulting spec matches transform mapping.

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().

Parameters:

output_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
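
The effect on the output side can be seen on the environments built in the Examples above; a brief sketch, assuming base_env and env are still in scope (the outputs shown are what the rollout printouts above imply):

>>> "other" in base_env.observation_spec.keys()
True
>>> "other" in env.observation_spec.keys()  # empty nested specs pruned from the output spec
False
>>> "other_reward" in env.full_reward_spec.keys()
False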
