Shortcuts

Source code for torchrl.envs.libs.brax

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib.util
import warnings

from typing import Dict, Optional, Union

import torch
from packaging import version
from tensordict import TensorDict, TensorDictBase

from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.libs.jax_utils import (
    _extract_spec,
    _ndarray_to_tensor,
    _object_to_tensordict,
    _tensor_to_ndarray,
    _tensordict_to_object,
    _tree_flatten,
    _tree_reshape,
)
from torchrl.envs.utils import _classproperty

# True when the ``brax`` package is importable in the current environment.
_has_brax = importlib.util.find_spec("brax") is not None


def _get_envs():
    """Return the names of all environments registered in ``brax.envs``."""
    if not _has_brax:
        raise ImportError("BRAX is not installed in your virtual environment.")

    import brax.envs

    # brax keeps its registry in the private ``_envs`` dict; its keys are the
    # environment names.
    return [*brax.envs._envs]


class BraxWrapper(_EnvWrapper):
    """Google Brax environment wrapper.

    Brax offers a vectorized and differentiable simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion, but
    computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/google/brax

    Paper: https://arxiv.org/abs/2106.13281

    Args:
        env (brax.envs.base.PipelineEnv): the environment to wrap.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): Not yet supported.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            In ``brax``, this indicates the number of vectorized environments.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated for
            envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> import brax.envs
        >>> from torchrl.envs import BraxWrapper
        >>> import torch
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> base_env = brax.envs.get_environment("ant")
        >>> env = BraxWrapper(base_env, device=device)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(torch.Size([8]), dtype=torch.float32),
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                next: TensorDict(
                    fields={
                        observation: Tensor(torch.Size([87]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(torch.Size([87]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                state: TensorDict(...)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['acrobot', 'ant', 'fast', 'fetch', ...]

    To take advantage of Brax, one usually executes multiple environments at the
    same time. In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> import torch
        >>> from torch.utils.benchmark import Timer
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer('''
        ... env.rollout(100)
        ... ''',
        ...     setup=f'''
        ... import brax.envs
        ... from torchrl.envs import BraxWrapper
        ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}")
        ... env.set_seed(0)
        ... env.rollout(2)
        ... ''')
        ...     print(batch_size, timer.timeit(10))
        4 env.rollout(100)
        setup: [...]
        310.00 ms
        1 measurement, 10 runs , 1 thread
        16 env.rollout(100)
        setup: [...]
        268.46 ms
        1 measurement, 10 runs , 1 thread
        128 env.rollout(100)
        setup: [...]
        433.80 ms
        1 measurement, 10 runs , 1 thread

    One can backpropagate through the rollout and optimize the policy directly:

        >>> import brax.envs
        >>> from torchrl.envs import BraxWrapper
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> import torch
        >>>
        >>> env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[10], requires_grad=True)
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
        >>>
        >>> td = env.rollout(10, policy)
        >>>
        >>> td["next", "reward"].mean().backward(retain_graph=True)
        >>> print(policy.module.weight.grad.norm())
        tensor(213.8605)
    """

    git_url = "https://github.com/google/brax"

    @_classproperty
    def available_envs(cls):
        # List of registered brax environment names; empty when brax is missing.
        if not _has_brax:
            return []
        return list(_get_envs())

    libname = "brax"

    # Lazily-imported module handles, cached at class level so the (expensive)
    # brax/jax imports happen at most once per process.
    _lib = None
    _jax = None

    @_classproperty
    def lib(cls):
        """Lazily import and cache the ``brax`` module."""
        if cls._lib is not None:
            return cls._lib

        import brax
        import brax.envs

        cls._lib = brax
        return brax

    @_classproperty
    def jax(cls):
        """Lazily import and cache the ``jax`` module."""
        if cls._jax is not None:
            return cls._jax

        import jax

        cls._jax = jax
        return jax

    def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
        if env is not None:
            kwargs["env"] = env
        self._seed_calls_reset = None
        self._categorical_action_encoding = categorical_action_encoding
        super().__init__(**kwargs)
        if not self.device:
            warnings.warn(
                f"No device is set for env {self}. "
                f"Setting a device in Brax wrapped environments is strongly recommended."
            )

    def _check_kwargs(self, kwargs: Dict):
        """Validate constructor kwargs: brax version and the ``env`` entry."""
        brax = self.lib
        if version.parse(brax.__version__) < version.parse("0.10.4"):
            raise ImportError("Brax v0.10.4 or greater is required.")
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, brax.envs.Env):
            raise TypeError("env is not of type 'brax.envs.Env'.")

    def _build_env(
        self,
        env,
        _seed: Optional[int] = None,
        from_pixels: bool = False,
        render_kwargs: Optional[dict] = None,
        pixels_only: bool = False,
        requires_grad: bool = False,
        camera_id: Union[int, str] = 0,
        **kwargs,
    ):
        """Store rendering/grad options and return the (already built) env."""
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only
        self.requires_grad = requires_grad
        if from_pixels:
            raise NotImplementedError(
                # fixed typo: "yest" -> "yet"
                "from_pixels=True is not yet supported within BraxWrapper"
            )
        return env

    def _make_state_spec(self, env: "brax.envs.env.Env"):  # noqa: F821
        """Infer the spec of the brax state pytree from a sample reset."""
        jax = self.jax
        key = jax.random.PRNGKey(0)
        state = env.reset(key)
        state_dict = _object_to_tensordict(state, self.device, batch_size=())
        state_spec = _extract_spec(state_dict).expand(self.batch_size)
        return state_spec

    def _make_specs(self, env: "brax.envs.env.Env") -> None:  # noqa: F821
        """Build action/reward/observation/state specs for the wrapped env."""
        # Brax actions are normalized to [-1, 1].
        self.action_spec = Bounded(
            low=-1,
            high=1,
            shape=(
                *self.batch_size,
                env.action_size,
            ),
            device=self.device,
        )
        self.reward_spec = Unbounded(
            shape=[
                *self.batch_size,
                1,
            ],
            device=self.device,
        )
        self.observation_spec = Composite(
            observation=Unbounded(
                shape=(
                    *self.batch_size,
                    env.observation_size,
                ),
                device=self.device,
            ),
            shape=self.batch_size,
        )
        # extract state spec from instance
        state_spec = self._make_state_spec(env)
        self.state_spec["state"] = state_spec
        self.observation_spec["state"] = state_spec.clone()

    def _make_state_example(self):
        """Produce a batched example brax state, used as a conversion template."""
        jax = self.jax
        key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, self.batch_size.numel())
        state = self._vmap_jit_env_reset(jax.numpy.stack(keys))
        state = _tree_reshape(state, self.batch_size)
        return state

    def _init_env(self) -> Optional[int]:
        """Compile vmapped/jitted reset and step, and cache a state example."""
        jax = self.jax
        self._key = None
        self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset))
        self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step))
        self._state_example = self._make_state_example()

    def _set_seed(self, seed: int):
        """Seed the jax PRNG key used to generate reset keys."""
        jax = self.jax
        if seed is None:
            raise Exception("Brax requires an integer seed.")
        self._key = jax.random.PRNGKey(seed)

    def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
        jax = self.jax

        # generate random keys
        self._key, *keys = jax.random.split(self._key, 1 + self.numel())

        # call env reset with jit and vmap
        state = self._vmap_jit_env_reset(jax.numpy.stack(keys))

        # reshape batch size
        state = _tree_reshape(state, self.batch_size)
        state = _object_to_tensordict(state, self.device, self.batch_size)

        # build result
        state["reward"] = state.get("reward").view(*self.reward_spec.shape)
        state["done"] = state.get("done").view(*self.reward_spec.shape)
        done = state["done"].bool()
        tensordict_out = TensorDict._new_unsafe(
            source={
                "observation": state.get("obs"),
                # "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "state": state,
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step_without_grad(self, tensordict: TensorDictBase):
        """Step the env without building a torch autograd graph."""
        # convert tensors to ndarrays
        state = _tensordict_to_object(tensordict.get("state"), self._state_example)
        action = _tensor_to_ndarray(tensordict.get("action"))

        # flatten batch size
        state = _tree_flatten(state, self.batch_size)
        action = _tree_flatten(action, self.batch_size)

        # call env step with jit and vmap
        next_state = self._vmap_jit_env_step(state, action)

        # reshape batch size and convert ndarrays to tensors
        next_state = _tree_reshape(next_state, self.batch_size)
        next_state = _object_to_tensordict(next_state, self.device, self.batch_size)

        # build result
        next_state.set("reward", next_state.get("reward").view(self.reward_spec.shape))
        next_state.set("done", next_state.get("done").view(self.reward_spec.shape))
        done = next_state["done"].bool()
        reward = next_state["reward"]
        tensordict_out = TensorDict._new_unsafe(
            source={
                "observation": next_state.get("obs"),
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "state": next_state,
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step_with_grad(self, tensordict: TensorDictBase):
        """Step the env through :class:`_BraxEnvStep` so gradients can flow."""
        # convert tensors to ndarrays
        action = tensordict.get("action")
        state = tensordict.get("state")
        qp_keys, qp_values = zip(*state.get("pipeline_state").items())

        # call env step with autograd function
        next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply(
            self, state, action, *qp_values
        )

        # extract done values: we assume a shape identical to reward
        next_done = next_state_nograd.get("done").view(*self.reward_spec.shape)
        next_reward = next_reward.view(*self.reward_spec.shape)

        # merge with tensors with grad function
        next_state = next_state_nograd
        next_state["obs"] = next_obs
        next_state.set("reward", next_reward)
        next_state.set("done", next_done)
        next_done = next_done.bool()
        next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values)))

        # build result
        tensordict_out = TensorDict._new_unsafe(
            source={
                "observation": next_obs,
                "reward": next_reward,
                "done": next_done,
                "terminated": next_done,
                "state": next_state,
            },
            batch_size=self.batch_size,
            device=self.device,
        )
        return tensordict_out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        # Dispatch on requires_grad: the differentiable path goes through a
        # custom autograd.Function, the other one is plain jit/vmap stepping.
        if self.requires_grad:
            out = self._step_with_grad(tensordict)
        else:
            out = self._step_without_grad(tensordict)
        return out
class BraxEnv(BraxWrapper):
    """Google Brax environment wrapper built with the environment name.

    Brax offers a vectorized and differentiable simulation framework based on Jax.
    TorchRL's wrapper incurs some overhead for the jax-to-torch conversion, but
    computational graphs can still be built on top of the simulated trajectories,
    allowing for backpropagation through the rollout.

    GitHub: https://github.com/google/brax

    Paper: https://arxiv.org/abs/2106.13281

    Args:
        env_name (str): the environment name of the env to wrap. Must be part of
            :attr:`~.available_envs`.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`).
            Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): Not yet supported.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            In ``brax``, this indicates the number of vectorized environments.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated for
            envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs: environments available to build

    Examples:
        >>> from torchrl.envs import BraxEnv
        >>> import torch
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> env = BraxEnv("ant", device=device)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(torch.Size([8]), dtype=torch.float32),
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                next: TensorDict(
                    fields={
                        observation: Tensor(torch.Size([87]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(torch.Size([87]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                state: TensorDict(...)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['acrobot', 'ant', 'fast', 'fetch', ...]

    To take advantage of Brax, one usually executes multiple environments at the
    same time. In the following example, we iteratively test different batch sizes
    and report the execution time for a short rollout:

    Examples:
        >>> import torch
        >>> from torch.utils.benchmark import Timer
        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> for batch_size in [4, 16, 128]:
        ...     timer = Timer('''
        ... env.rollout(100)
        ... ''',
        ...     setup=f'''
        ... from torchrl.envs import BraxEnv
        ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}")
        ... env.set_seed(0)
        ... env.rollout(2)
        ... ''')
        ...     print(batch_size, timer.timeit(10))
        4 env.rollout(100)
        setup: [...]
        310.00 ms
        1 measurement, 10 runs , 1 thread
        16 env.rollout(100)
        setup: [...]
        268.46 ms
        1 measurement, 10 runs , 1 thread
        128 env.rollout(100)
        setup: [...]
        433.80 ms
        1 measurement, 10 runs , 1 thread

    One can backpropagate through the rollout and optimize the policy directly:

        >>> from torchrl.envs import BraxEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> import torch
        >>>
        >>> env = BraxEnv("ant", batch_size=[10], requires_grad=True)
        >>> env.set_seed(0)
        >>> torch.manual_seed(0)
        >>> policy = TensorDictModule(nn.Linear(27, 8), in_keys=["observation"], out_keys=["action"])
        >>>
        >>> td = env.rollout(10, policy)
        >>>
        >>> td["next", "reward"].mean().backward(retain_graph=True)
        >>> print(policy.module.weight.grad.norm())
        tensor(213.8605)
    """

    def __init__(self, env_name, **kwargs):
        kwargs["env_name"] = env_name
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        **kwargs,
    ) -> "brax.envs.env.Env":  # noqa: F821
        """Instantiate the brax env by name and defer to the wrapper's builder."""
        if not _has_brax:
            raise ImportError(
                f"brax not found, unable to create {env_name}. "
                f"Consider downloading and installing brax from"
                f" {self.git_url}"
            )
        from_pixels = kwargs.pop("from_pixels", False)
        pixels_only = kwargs.pop("pixels_only", True)
        requires_grad = kwargs.pop("requires_grad", False)
        if kwargs:
            raise ValueError("kwargs not supported.")
        self.wrapper_frame_skip = 1
        # kwargs is guaranteed empty here (the check above raised otherwise),
        # so there is nothing extra to forward to get_environment.
        env = self.lib.envs.get_environment(env_name)
        return super()._build_env(
            env,
            pixels_only=pixels_only,
            from_pixels=from_pixels,
            requires_grad=requires_grad,
        )

    @property
    def env_name(self):
        """Name of the wrapped brax environment, as passed to the constructor."""
        return self._constructor_kwargs["env_name"]

    def _check_kwargs(self, kwargs: Dict):
        if "env_name" not in kwargs:
            raise TypeError("Expected 'env_name' to be part of kwargs")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
class _BraxEnvStep(torch.autograd.Function):
    """Autograd bridge for a brax env step.

    ``forward`` runs the (vmapped, jitted) brax step through ``jax.vjp`` and
    stashes the vjp function on ``ctx``; ``backward`` feeds the incoming torch
    gradients back through that vjp to obtain gradients w.r.t. the action and
    the pipeline-state leaves.
    """

    @staticmethod
    def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values):
        # NOTE: qp_values are the pipeline_state leaves passed as separate
        # tensor args so autograd tracks each of them individually.
        import jax

        # convert tensors to ndarrays
        state_obj = _tensordict_to_object(state_td, env._state_example)
        action_nd = _tensor_to_ndarray(action_tensor)

        # flatten batch size
        state = _tree_flatten(state_obj, env.batch_size)
        action = _tree_flatten(action_nd, env.batch_size)

        # call vjp with jit and vmap
        next_state, vjp_fn = jax.vjp(env._vmap_jit_env_step, state, action)

        # reshape batch size
        next_state_reshape = _tree_reshape(next_state, env.batch_size)

        # convert ndarrays to tensors
        next_state_tensor = _object_to_tensordict(
            next_state_reshape, device=env.device, batch_size=env.batch_size
        )

        # save context
        ctx.vjp_fn = vjp_fn
        ctx.next_state = next_state_tensor
        ctx.env = env

        return (
            next_state_tensor,  # no gradient
            next_state_tensor["obs"],
            next_state_tensor["reward"],
            *next_state_tensor["pipeline_state"].values(),
        )

    @staticmethod
    def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
        # The first incoming grad (for next_state_tensor) is discarded: that
        # output is non-differentiable by construction.
        pipeline_state = dict(
            zip(ctx.next_state.get("pipeline_state").keys(), grad_next_qp_values)
        )
        # Keys whose incoming grad was None; their grads are substituted with
        # zeros for the vjp call and restored to None before returning.
        none_keys = []

        def _make_none(key, val):
            # Replace a missing (None) gradient with zeros so the jax vjp
            # receives a fully populated pytree; remember the key.
            if val is not None:
                return val
            none_keys.append(key)
            return torch.zeros_like(ctx.next_state.get(("pipeline_state", key)))

        pipeline_state = {
            key: _make_none(key, val) for key, val in pipeline_state.items()
        }
        metrics = ctx.next_state.get("metrics", None)
        if metrics is None:
            metrics = {}
        info = ctx.next_state.get("info", None)
        if info is None:
            info = {}
        # Gradients for done/metrics/info are zero: they do not flow back.
        grad_next_state_td = TensorDict(
            source={
                "pipeline_state": pipeline_state,
                "obs": grad_next_obs,
                "reward": grad_next_reward,
                "done": torch.zeros_like(ctx.next_state.get("done")),
                "metrics": {k: torch.zeros_like(v) for k, v in metrics.items()},
                "info": {k: torch.zeros_like(v) for k, v in info.items()},
            },
            device=ctx.env.device,
            batch_size=ctx.env.batch_size,
        )

        # convert tensors to ndarrays
        grad_next_state_obj = _tensordict_to_object(
            grad_next_state_td, ctx.env._state_example
        )

        # flatten batch size
        grad_next_state_flat = _tree_flatten(grad_next_state_obj, ctx.env.batch_size)

        # call vjp to get gradients
        grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)
        # assert grad_action.device == ctx.env.device

        # reshape batch size
        grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
        grad_action = _tree_reshape(grad_action, ctx.env.batch_size)
        # assert grad_action.device == ctx.env.device

        # convert ndarrays to tensors
        grad_state_qp = _object_to_tensordict(
            grad_state.pipeline_state,
            device=ctx.env.device,
            batch_size=ctx.env.batch_size,
        )
        grad_action = _ndarray_to_tensor(grad_action).to(ctx.env.device)
        # Restore None for outputs whose incoming grad was None.
        grad_state_qp = {
            key: val if key not in none_keys else None
            for key, val in grad_state_qp.items()
        }
        grads = (grad_action, *grad_state_qp.values())
        # First two Nones: no gradient for the env object and the state TensorDict.
        return (None, None, *grads)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources