ActionDiscretizer
- class torchrl.envs.transforms.ActionDiscretizer(num_intervals: int | torch.Tensor, action_key: NestedKey = 'action', out_action_key: Optional[NestedKey] = None, sampling=None, categorical: bool = True)
A transform to discretize a continuous action space.
This transform makes it possible to use an algorithm designed for discrete action spaces, such as DQN, on environments with a continuous action space.
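The core idea can be sketched in plain Python. This is a minimal sketch, assuming each action dimension's range [low, high] is split into equal-width bins and that a discrete index selects the centre of its bin. `bin_centres` and `index_to_action` are hypothetical helpers, not part of the TorchRL API, and the library's exact grid may differ.

```python
from typing import List, Sequence


def bin_centres(low: float, high: float, n: int) -> List[float]:
    """Split [low, high] into n equal-width bins and return each bin's centre."""
    width = (high - low) / n
    return [low + (i + 0.5) * width for i in range(n)]


def index_to_action(
    indices: Sequence[int],
    lows: Sequence[float],
    highs: Sequence[float],
    num_intervals: Sequence[int],
) -> List[float]:
    """Map one discrete index per action dimension to a continuous action.

    Hypothetical helper: each dimension is treated independently, with its
    own number of intervals (like a per-dimension num_intervals tensor).
    """
    action = []
    for idx, low, high, n in zip(indices, lows, highs, num_intervals):
        if not 0 <= idx < n:
            raise ValueError(f"index {idx} out of range for {n} intervals")
        action.append(bin_centres(low, high, n)[idx])
    return action


# A 2-D action space bounded in [-1, 1], with 4 and 5 intervals per dimension.
print(index_to_action([0, 2], [-1.0, -1.0], [1.0, 1.0], [4, 5]))  # [-0.75, 0.0]
```

A discrete-action agent then only has to pick one integer per dimension, and the continuous environment still receives a valid floating point action that stays strictly inside its bounds.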
- Parameters:
num_intervals (int or torch.Tensor) – the number of discrete values for each element of the action space. If a single integer is provided, every action dimension is split into the same number of intervals. If a tensor is provided, it must have as many elements as the action space has dimensions (i.e., the length of the num_intervals tensor must match the last dimension of the action space).
action_key (NestedKey, optional) – the action key to use. Points to the action of the parent env (the floating point action). Defaults to "action".
out_action_key (NestedKey, optional) – the key where the discrete action should be written. If None is provided, it defaults to the value of action_key. If the two keys differ, the continuous action_spec is moved from the full_action_spec environment attribute to the full_state_spec container, as only the discrete action should be sampled for an action to be taken. Providing a distinct out_action_key therefore keeps the floating point action available to be recorded.
sampling (ActionDiscretizer.SamplingStrategy, optional) – a member of the ActionDiscretizer.SamplingStrategy IntEnum (MEDIAN, LOW, HIGH or RANDOM). Indicates how the continuous action is sampled within the selected interval.
categorical (bool, optional) – if True, the discrete action is stored as one integer index per action dimension; if False, one-hot encoding is used. Defaults to True.
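The sampling and categorical parameters can be illustrated with a hedged sketch for a single action dimension. It assumes equal-width bins and reads MEDIAN, LOW and HIGH as the centre, lower edge and upper edge of the chosen bin, with RANDOM drawing uniformly inside it; that reading is an assumption, not taken from the TorchRL source. `Strategy`, `sample_in_bin` and `one_hot` are hypothetical names.

```python
import random
from enum import IntEnum
from typing import List


class Strategy(IntEnum):
    # Hypothetical stand-in mirroring the member names of
    # ActionDiscretizer.SamplingStrategy.
    MEDIAN = 0
    LOW = 1
    HIGH = 2
    RANDOM = 3


def sample_in_bin(idx: int, low: float, high: float, n: int, strategy: Strategy) -> float:
    """Return a continuous value inside bin `idx` of [low, high] split into n bins."""
    width = (high - low) / n
    left = low + idx * width
    if strategy == Strategy.LOW:
        return left
    if strategy == Strategy.HIGH:
        return left + width
    if strategy == Strategy.MEDIAN:
        return left + width / 2
    # Strategy.RANDOM: uniform draw within the bin.
    return random.uniform(left, left + width)


def one_hot(idx: int, n: int) -> List[int]:
    """Encode a categorical index as a one-hot vector of length n (categorical=False)."""
    vec = [0] * n
    vec[idx] = 1
    return vec


# Bin 1 of [0, 1] split into 4 intervals covers [0.25, 0.5].
for s in (Strategy.LOW, Strategy.MEDIAN, Strategy.HIGH):
    print(s.name, sample_in_bin(1, 0.0, 1.0, 4, s))
# LOW 0.25 / MEDIAN 0.375 / HIGH 0.5
print(one_hot(1, 4))  # [0, 1, 0, 0]
```

Under this reading, MEDIAN never touches the bounds, whereas LOW and HIGH can return exactly low or high for the first or last bin.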
Examples
>>> from torchrl.envs import GymEnv, check_env_specs
>>> from torchrl.envs.transforms import ActionDiscretizer
>>> import torch
>>> base_env = GymEnv("HalfCheetah-v4")
>>> num_intervals = torch.arange(5, 11)
>>> categorical = True
>>> sampling = ActionDiscretizer.SamplingStrategy.MEDIAN
>>> t = ActionDiscretizer(
...     num_intervals=num_intervals,
...     categorical=categorical,
...     sampling=sampling,
...     out_action_key="action_disc",
... )
>>> env = base_env.append_transform(t)
>>> env
TransformedEnv(
    env=GymEnv(env=HalfCheetah-v4, batch_size=torch.Size([]), device=cpu),
    transform=ActionDiscretizer(
        num_intervals=tensor([ 5,  6,  7,  8,  9, 10]),
        action_key=action,
        out_action_key=action_disc,,
        sampling=0,
        categorical=True))
>>> check_env_specs(env)
>>> # Produce a rollout
>>> r = env.rollout(4)
>>> print(r)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        action_disc: Tensor(shape=torch.Size([4, 6]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
                reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([4]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([4, 17]), device=cpu, dtype=torch.float64, is_shared=False),
        terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)
>>> assert r["action"].dtype == torch.float
>>> assert r["action_disc"].dtype == torch.int64
>>> assert (r["action"] < base_env.action_spec.high).all()
>>> assert (r["action"] > base_env.action_spec.low).all()
- transform_input_spec(input_spec)
Transforms the input spec such that the resulting spec matches the transform's mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
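As a rough illustration of the spec mapping, the sketch below models specs as plain dictionaries. It is hypothetical: `transform_action_spec_sketch` and the dictionary layout are invented for illustration and are not TorchRL's TensorSpec classes. It only encodes the behaviour documented for out_action_key (the continuous spec moves from the action specs to the state specs when the output key differs from action_key) and assumes categorical=True.

```python
import copy
from typing import Dict, List, Optional, Tuple


def transform_action_spec_sketch(
    full_action_spec: Dict[str, dict],
    full_state_spec: Dict[str, dict],
    num_intervals: List[int],
    action_key: str = "action",
    out_action_key: Optional[str] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """Return new (action_spec, state_spec) dicts after discretizing `action_key`."""
    if out_action_key is None:
        out_action_key = action_key  # mirrors the documented default
    action_spec = copy.deepcopy(full_action_spec)
    state_spec = copy.deepcopy(full_state_spec)
    continuous = action_spec.pop(action_key)
    if len(num_intervals) != len(continuous["low"]):
        raise ValueError("num_intervals must match the last dimension of the action")
    if out_action_key != action_key:
        # Keep the float action (e.g. for recording), but not as something to sample.
        state_spec[action_key] = continuous
    # Only the discrete action remains among the action specs.
    action_spec[out_action_key] = {"dtype": "int64", "num_intervals": list(num_intervals)}
    return action_spec, state_spec


actions, states = transform_action_spec_sketch(
    {"action": {"dtype": "float32", "low": [-1.0] * 3, "high": [1.0] * 3}},
    {},
    num_intervals=[5, 6, 7],
    out_action_key="action_disc",
)
print(sorted(actions), sorted(states))  # ['action_disc'] ['action']
```

When out_action_key is left as None, the discrete spec simply replaces the continuous one under the same key and the state specs are unchanged.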