RemoveEmptySpecs
- class torchrl.envs.transforms.RemoveEmptySpecs(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, in_keys_inv: Optional[Sequence[NestedKey]] = None, out_keys_inv: Optional[Sequence[NestedKey]] = None)
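Removes empty specs and their content from an environment. A spec counts as empty when it is a Composite that holds no leaf specs, either directly or only through nested empty Composites. Such entries are deleted from the environment specs, and the matching empty sub-tensordicts are dropped from the data the environment returns.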
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded, Composite, Categorical
>>> from torchrl.envs import EnvBase, TransformedEnv, RemoveEmptySpecs
>>> from torchrl.envs.utils import check_env_specs
>>>
>>>
>>> class DummyEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.observation_spec = Composite(
...             observation=Unbounded((*self.batch_size, 3)),
...             other=Composite(
...                 another_other=Composite(shape=self.batch_size),
...                 shape=self.batch_size,
...             ),
...             shape=self.batch_size,
...         )
...         self.action_spec = Unbounded((*self.batch_size, 3))
...         self.done_spec = Categorical(
...             2, (*self.batch_size, 1), dtype=torch.bool
...         )
...         self.full_done_spec["truncated"] = self.full_done_spec[
...             "terminated"].clone()
...         self.reward_spec = Composite(
...             reward=Unbounded((*self.batch_size, 1)),
...             other_reward=Composite(shape=self.batch_size),
...             shape=self.batch_size
...         )
...
...     def _reset(self, tensordict):
...         return self.observation_spec.rand().update(self.full_done_spec.zero())
...
...     def _step(self, tensordict):
...         return TensorDict(
...             {},
...             batch_size=[]
...         ).update(self.observation_spec.rand()).update(
...             self.full_done_spec.zero()
...         ).update(self.full_reward_spec.rand())
...
...     def _set_seed(self, seed):
...         return seed + 1
>>>
>>>
>>> base_env = DummyEnv()
>>> print(base_env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                other: TensorDict(
                    fields={
                        another_other: TensorDict(
                            fields={
                            },
                            batch_size=torch.Size([2]),
                            device=cpu,
                            is_shared=False)},
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                other_reward: TensorDict(
                    fields={
                    },
                    batch_size=torch.Size([2]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(base_env)
>>> env = TransformedEnv(base_env, RemoveEmptySpecs())
>>> print(env.rollout(2))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=cpu,
    is_shared=False)
>>> check_env_specs(env)
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and applies the transform to the selected keys.
- transform_input_spec(input_spec: TensorSpec) → TensorSpec
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using
transform_observation_spec(), transform_reward_spec() and transform_done_spec().
- Parameters:
output_spec (Composite) – spec before the transform
- Returns:
expected spec after the transform