BraxEnv¶
- torchrl.envs.BraxEnv(*args, **kwargs)[source]¶
Google Brax environment wrapper built with the environment name.
Brax offers a vectorized and differentiable simulation framework based on Jax. TorchRL’s wrapper incurs some overhead for the jax-to-torch conversion, but computational graphs can still be built on top of the simulated trajectories, allowing for backpropagation through the rollout.
GitHub: https://github.com/google/brax
Paper: https://arxiv.org/abs/2106.13281
- Parameters:
env_name (str) – the environment name of the env to wrap. Must be part of available_envs.
categorical_action_encoding (bool, optional) – if True, categorical specs will be converted to the TorchRL equivalent (torchrl.data.DiscreteTensorSpec), otherwise a one-hot encoding will be used (torchrl.data.OneHotTensorSpec). Defaults to False.
- Keyword Arguments:
from_pixels (bool, optional) – Not yet supported.
frame_skip (int, optional) – if provided, indicates for how many steps the same action is to be repeated. The observation returned will be the last observation of the sequence, whereas the reward will be the sum of rewards across steps.
device (torch.device, optional) – if provided, the device on which the data is to be cast. Defaults to torch.device("cpu").
batch_size (torch.Size, optional) – the batch size of the environment. In brax, this indicates the number of vectorized environments. Defaults to torch.Size([]). See the construction sketch below.
allow_done_after_reset (bool, optional) – if True, it is tolerated for envs to be done just after reset() is called. Defaults to False.
- Variables:
available_envs – environments available to build
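The keyword arguments above can be combined at construction time. The following is a minimal, illustrative sketch; the environment name and the specific values chosen for batch_size, device and frame_skip are arbitrary examples, not recommendations:

>>> import torch
>>> from torchrl.envs import BraxEnv
>>> # Illustrative values only: any name from available_envs works here.
>>> env = BraxEnv(
...     "ant",
...     batch_size=torch.Size([16]),   # 16 vectorized Brax environments
...     device=torch.device("cpu"),    # device the converted tensors are cast to
...     frame_skip=2,                  # repeat each action for 2 simulation steps
... )
>>> env.set_seed(0)
>>> td = env.reset()
>>> print(td.batch_size)               # expected to match the env batch size
torch.Size([16])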
Examples
>>> from torchrl.envs import BraxEnv >>> env = BraxEnv("ant") >>> env.set_seed(0) >>> td = env.reset() >>> td["action"] = env.action_spec.rand() >>> td = env.step(td) >>> print(td) TensorDict( fields={ action: Tensor(torch.Size([8]), dtype=torch.float32), done: Tensor(torch.Size([1]), dtype=torch.bool), next: TensorDict( fields={ observation: Tensor(torch.Size([87]), dtype=torch.float32)}, batch_size=torch.Size([]), device=cpu, is_shared=False), observation: Tensor(torch.Size([87]), dtype=torch.float32), reward: Tensor(torch.Size([1]), dtype=torch.float32), state: TensorDict(...)}, batch_size=torch.Size([]), device=cpu, is_shared=False) >>> print(env.available_envs) ['acrobot', 'ant', 'fast', 'fetch', ...]
To take advantage of Brax, one usually executes multiple environments at the same time. In the following example, we iteratively test different batch sizes and report the execution time for a short rollout:
Examples
>>> from torch.utils.benchmark import Timer
>>> for batch_size in [4, 16, 128]:
...     timer = Timer('''
... env.rollout(100)
... ''',
...     setup=f'''
... from torchrl.envs import BraxEnv
... env = BraxEnv("ant", batch_size=[{batch_size}])
... env.set_seed(0)
... env.rollout(2)
... ''')
...     print(batch_size, timer.timeit(10))
4
env.rollout(100)
setup: [...]
310.00 ms
1 measurement, 10 runs , 1 thread
16
env.rollout(100)
setup: [...]
268.46 ms
1 measurement, 10 runs , 1 thread
128
env.rollout(100)
setup: [...]
433.80 ms
1 measurement, 10 runs , 1 thread
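Beyond timings, a quick way to see what batching changes is to inspect the rollout's batch dimensions. As a rough sketch (assuming the same "ant" environment and a small, arbitrary batch size), the leading dimensions of the returned tensordict are expected to be [num_envs, num_steps]:

>>> from torchrl.envs import BraxEnv
>>> env = BraxEnv("ant", batch_size=[4])
>>> env.set_seed(0)
>>> td = env.rollout(10)
>>> print(td.batch_size)   # [num_envs, num_steps]
torch.Size([4, 10])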
One can backpropagate through the rollout and optimize the policy directly:
>>> from torchrl.envs import BraxEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> import torch
>>>
>>> env = BraxEnv("ant", batch_size=[10], requires_grad=True)
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
>>>
>>> td = env.rollout(10, policy)
>>>
>>> td["next", "reward"].mean().backward(retain_graph=True)
>>> print(policy.module.weight.grad.norm())
tensor(213.8605)
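Building on the snippet above, a minimal sketch of a full optimization loop could pair the differentiable rollout with a standard optimizer. The loss (the negated mean reward), the learning rate and the number of iterations below are arbitrary illustrative choices:

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.envs import BraxEnv
>>>
>>> env = BraxEnv("ant", batch_size=[10], requires_grad=True)
>>> env.set_seed(0)
>>> torch.manual_seed(0)
>>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
>>> optim = torch.optim.Adam(policy.parameters(), lr=1e-3)
>>> for _ in range(5):
...     td = env.rollout(10, policy)          # differentiable rollout through Brax
...     loss = -td["next", "reward"].mean()   # maximize reward by minimizing its negation
...     optim.zero_grad()
...     loss.backward()
...     optim.step()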