Source code for torchrl.envs.libs.brax

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib.util

from typing import Dict, Optional, Union

import torch
from tensordict.tensordict import TensorDict, TensorDictBase

# NOTE(review): this import was truncated in the extracted source; the module
# path is reconstructed from the spec classes used below -- confirm.
from torchrl.data.tensor_specs import (
    BoundedTensorSpec,
    CompositeSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import _classproperty

_has_brax = importlib.util.find_spec("brax") is not None
from torchrl.envs.libs.jax_utils import (
    _extract_spec,
    _ndarray_to_tensor,
    _object_to_tensordict,
    _tensor_to_ndarray,
    _tensordict_to_object,
    _tree_flatten,
    _tree_reshape,
)

def _get_envs():
    """Return the names of the environments registered in ``brax.envs``.

    The names are the keys of ``brax.envs._envs``.

    Raises:
        ImportError: if the ``brax`` package could not be located.
    """
    if _has_brax:
        import brax.envs

        registered_names = brax.envs._envs.keys()
        return list(registered_names)

    raise ImportError("BRAX is not installed in your virtual environment.")

class BraxWrapper(_EnvWrapper):
    """Google Brax environment wrapper.

    Wraps an already-built ``brax.envs.env.Env``. Resets and steps go through
    ``jax.vmap(jax.jit(...))`` versions of the environment's ``reset`` and
    ``step`` (built in ``_init_env``). The brax state is carried in the
    returned tensordicts under the ``"state"`` entry.

    A PRNG key must be set (see ``_set_seed``) before the environment is
    reset: ``_init_env`` leaves the key unset.

    Args:
        env: the brax environment to wrap. ``_check_kwargs`` rejects anything
            that is not a ``brax.envs.env.Env``.
        categorical_action_encoding (bool, optional): stored on the instance
            as ``_categorical_action_encoding``; it is not read anywhere else
            in this class. Defaults to ``False``.
        **kwargs: forwarded to the parent class constructor.

    Examples:
        >>> env = brax.envs.get_environment("ant")
        >>> env = BraxWrapper(env)
        >>> env.set_seed(0)
        >>> td = env.reset()
        >>> td["action"] = env.action_spec.rand()
        >>> td = env.step(td)
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(torch.Size([8]), dtype=torch.float32),
                done: Tensor(torch.Size([1]), dtype=torch.bool),
                next: TensorDict(
                    fields={
                        observation: Tensor(torch.Size([87]), dtype=torch.float32)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(torch.Size([87]), dtype=torch.float32),
                reward: Tensor(torch.Size([1]), dtype=torch.float32),
                state: TensorDict(...)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        ['acrobot', 'ant', 'fast', 'fetch', ...]

    """

    # Interpolated into the ImportError raised by BraxEnv._build_env.
    git_url = "https://github.com/google/brax"

    @_classproperty
    def available_envs(cls):
        """List of brax environment names (empty when brax is not installed)."""
        # A list (rather than a generator) matches the class docstring example
        # and can be iterated more than once.
        if not _has_brax:
            return []
        return list(_get_envs())

    libname = "brax"

    # Module caches filled on first access of ``lib`` / ``jax``.
    _lib = None
    _jax = None

    @_classproperty
    def lib(cls):
        """The ``brax`` module, imported on first access and cached on the class."""
        if cls._lib is not None:
            return cls._lib

        import brax
        import brax.envs

        cls._lib = brax
        return brax

    @_classproperty
    def jax(cls):
        """The ``jax`` module, imported on first access and cached on the class."""
        if cls._jax is not None:
            return cls._jax

        import jax

        cls._jax = jax
        return jax

    def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
        if env is not None:
            kwargs["env"] = env
        self._seed_calls_reset = None
        self._categorical_action_encoding = categorical_action_encoding
        super().__init__(**kwargs)

    def _check_kwargs(self, kwargs: Dict):
        """Validate the constructor keyword arguments.

        Raises:
            TypeError: if ``"env"`` is missing from ``kwargs`` or is not a
                ``brax.envs.env.Env`` instance.
        """
        brax = self.lib
        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, brax.envs.env.Env):
            raise TypeError("env is not of type 'brax.envs.env.Env'.")

    def _build_env(
        self,
        env,
        _seed: Optional[int] = None,
        from_pixels: bool = False,
        render_kwargs: Optional[dict] = None,
        pixels_only: bool = False,
        requires_grad: bool = False,
        camera_id: Union[int, str] = 0,
        **kwargs,
    ):
        """Record the wrapper options on the instance and return ``env`` unchanged.

        Args:
            env: the brax environment; returned as is.
            from_pixels (bool): stored on the instance; ``True`` is not
                supported.
            pixels_only (bool): stored on the instance.
            requires_grad (bool): stored on the instance; ``_step`` uses it to
                choose between the differentiable and non-differentiable step.
            _seed, render_kwargs, camera_id, **kwargs: accepted but not read
                by this method.

        Raises:
            NotImplementedError: if ``from_pixels`` is true.
        """
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only
        self.requires_grad = requires_grad

        if from_pixels:
            raise NotImplementedError("TODO")
        return env

    def _make_state_spec(self, env: "brax.envs.env.Env"):  # noqa: F821
        """Derive the spec of the brax state from one un-batched reset of ``env``."""
        jax = self.jax

        # The key value is arbitrary here: the state is only used as a
        # structural example from which the spec is extracted.
        key = jax.random.PRNGKey(0)
        state = env.reset(key)
        state_dict = _object_to_tensordict(state, self.device, batch_size=())
        state_spec = _extract_spec(state_dict).expand(self.batch_size)
        return state_spec

    def _make_specs(self, env: "brax.envs.env.Env") -> None:  # noqa: F821
        """Set the action, reward, observation and state specs from ``env``."""
        self.action_spec = BoundedTensorSpec(
            low=-1,
            high=1,
            shape=(
                *self.batch_size,
                env.action_size,
            ),
            device=self.device,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(
            shape=[
                *self.batch_size,
                1,
            ],
            device=self.device,
        )
        self.observation_spec = CompositeSpec(
            observation=UnboundedContinuousTensorSpec(
                shape=(
                    *self.batch_size,
                    env.observation_size,
                ),
                device=self.device,
            ),
            shape=self.batch_size,
        )
        # extract state spec from instance
        state_spec = self._make_state_spec(env)
        self.state_spec["state"] = state_spec
        # The state is also part of the observation spec; a clone keeps the
        # two spec objects independent.
        self.observation_spec["state"] = state_spec.clone()

    def _make_state_example(self):
        """Build a batched example state by resetting with a fixed key."""
        jax = self.jax

        key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, self.batch_size.numel())
        state = self._vmap_jit_env_reset(jax.numpy.stack(keys))
        state = _tree_reshape(state, self.batch_size)
        return state

    def _init_env(self) -> Optional[int]:
        """Create the vmapped/jitted reset and step functions and an example state.

        The PRNG key is left unset (``None``) until ``_set_seed`` is called.
        """
        jax = self.jax
        self._key = None
        self._vmap_jit_env_reset = jax.vmap(jax.jit(self._env.reset))
        self._vmap_jit_env_step = jax.vmap(jax.jit(self._env.step))
        # Used as the structural template when converting tensordicts back
        # into brax state objects.
        self._state_example = self._make_state_example()

    def _set_seed(self, seed: int):
        """Set the PRNG key used by ``_reset`` from an integer seed.

        Raises:
            Exception: if ``seed`` is ``None``.
        """
        jax = self.jax
        if seed is None:
            raise Exception("Brax requires an integer seed.")
        self._key = jax.random.PRNGKey(seed)

    def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
        """Reset every environment of the batch with fresh PRNG keys.

        Args:
            tensordict: not read by this method.
            **kwargs: not read by this method.

        Returns:
            A tensordict with ``"observation"``, ``"done"``, ``"terminated"``
            and ``"state"`` entries.

        Raises:
            RuntimeError: if no PRNG key has been set yet.
        """
        jax = self.jax

        # Without a key ``jax.random.split`` cannot be called; fail with an
        # actionable message instead.
        if self._key is None:
            raise RuntimeError(
                "The Brax PRNG key is not set: call `set_seed` with an integer "
                "seed before resetting the environment."
            )

        # generate random keys
        self._key, *keys = jax.random.split(self._key, 1 + self.numel())

        # call env reset with jit and vmap
        state = self._vmap_jit_env_reset(jax.numpy.stack(keys))

        # reshape batch size
        state = _tree_reshape(state, self.batch_size)
        state = _object_to_tensordict(state, self.device, self.batch_size)

        # build result
        # done is given the same shape as the reward
        state["reward"] = state.get("reward").view(*self.reward_spec.shape)
        state["done"] = state.get("done").view(*self.reward_spec.shape)
        done = state["done"].bool()
        tensordict_out = TensorDict(
            source={
                "observation": state.get("obs"),
                "done": done,
                "terminated": done.clone(),
                "state": state,
            },
            batch_size=self.batch_size,
            device=self.device,
            _run_checks=False,
        )
        return tensordict_out

    def _step_without_grad(self, tensordict: TensorDictBase):
        """Run one environment step without gradient tracking.

        Reads ``"state"`` and ``"action"`` from ``tensordict`` and returns a
        new tensordict with ``"observation"``, ``"reward"``, ``"done"``,
        ``"terminated"`` and ``"state"`` entries.
        """
        # convert tensors to ndarrays
        state = _tensordict_to_object(tensordict.get("state"), self._state_example)
        action = _tensor_to_ndarray(tensordict.get("action"))

        # flatten batch size
        state = _tree_flatten(state, self.batch_size)
        action = _tree_flatten(action, self.batch_size)

        # call env step with jit and vmap
        next_state = self._vmap_jit_env_step(state, action)

        # reshape batch size and convert ndarrays to tensors
        next_state = _tree_reshape(next_state, self.batch_size)
        next_state = _object_to_tensordict(next_state, self.device, self.batch_size)

        # build result
        next_state.set("reward", next_state.get("reward").view(self.reward_spec.shape))
        next_state.set("done", next_state.get("done").view(self.reward_spec.shape))
        done = next_state["done"].bool()
        reward = next_state["reward"]
        tensordict_out = TensorDict(
            source={
                "observation": next_state.get("obs"),
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
                "state": next_state,
            },
            batch_size=self.batch_size,
            device=self.device,
            _run_checks=False,
        )
        return tensordict_out

    def _step_with_grad(self, tensordict: TensorDictBase):
        """Run one environment step through the ``_BraxEnvStep`` autograd function.

        Reads ``"state"`` and ``"action"`` from ``tensordict`` and returns a
        new tensordict with the same entries as ``_step_without_grad``.
        """
        action = tensordict.get("action")
        state = tensordict.get("state")
        # The pipeline-state tensors are passed to the autograd function as
        # separate arguments; its backward returns one gradient per entry.
        qp_keys, qp_values = zip(*state.get("pipeline_state").items())

        # call env step with autograd function
        next_state_nograd, next_obs, next_reward, *next_qp_values = _BraxEnvStep.apply(
            self, state, action, *qp_values
        )

        # extract done values: we assume a shape identical to reward
        next_done = next_state_nograd.get("done").view(*self.reward_spec.shape)
        next_reward = next_reward.view(*self.reward_spec.shape)

        # merge with tensors with grad function
        next_state = next_state_nograd
        next_state["obs"] = next_obs
        next_state.set("reward", next_reward)
        next_state.set("done", next_done)
        next_done = next_done.bool()
        next_state.get("pipeline_state").update(dict(zip(qp_keys, next_qp_values)))

        # build result
        tensordict_out = TensorDict(
            source={
                "observation": next_obs,
                "reward": next_reward,
                "done": next_done,
                # Cloned, as in ``_reset`` and ``_step_without_grad``, so that
                # "done" and "terminated" do not alias the same tensor.
                "terminated": next_done.clone(),
                "state": next_state,
            },
            batch_size=self.batch_size,
            device=self.device,
            _run_checks=False,
        )
        return tensordict_out

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Dispatch to the differentiable or plain step based on ``requires_grad``."""
        if self.requires_grad:
            out = self._step_with_grad(tensordict)
        else:
            out = self._step_without_grad(tensordict)

        return out
class BraxEnv(BraxWrapper):
    """Google Brax environment built from its name.

    The environment is created with ``brax.envs.get_environment(env_name)``
    and then handled by :class:`BraxWrapper`.

    Args:
        env_name (str): name of the brax environment to create.
        **kwargs: forwarded to the parent constructor. ``_build_env``
            recognises ``from_pixels``, ``pixels_only`` and ``requires_grad``
            and rejects any other key it receives.

    Examples:
        >>> env = BraxEnv(env_name="ant")
        >>> td = env.rand_step()
        >>> print(td)
        >>> print(env.available_envs)

    """

    def __init__(self, env_name, **kwargs):
        super().__init__(env_name=env_name, **kwargs)

    def _build_env(
        self,
        env_name: str,
        **kwargs,
    ) -> "brax.envs.env.Env":  # noqa: F821
        """Instantiate the named brax environment and hand it to the parent builder.

        Raises:
            ImportError: if brax is not installed.
            ValueError: if ``kwargs`` holds anything other than
                ``from_pixels``, ``pixels_only`` or ``requires_grad``.
        """
        if not _has_brax:
            message = (
                f"brax not found, unable to create {env_name}. "
                f"Consider downloading and installing brax from"
                f" {self.git_url}"
            )
            raise ImportError(message)

        # Pull out the options understood by the parent builder, with their
        # defaults; whatever remains is unsupported.
        wrapper_options = {
            "pixels_only": kwargs.pop("pixels_only", True),
            "from_pixels": kwargs.pop("from_pixels", False),
            "requires_grad": kwargs.pop("requires_grad", False),
        }
        if len(kwargs) > 0:
            raise ValueError("kwargs not supported.")

        self.wrapper_frame_skip = 1
        brax_env = self.lib.envs.get_environment(env_name, **kwargs)
        return super()._build_env(brax_env, **wrapper_options)

    @property
    def env_name(self):
        """Name of the environment, as recorded in the constructor kwargs."""
        return self._constructor_kwargs["env_name"]

    def _check_kwargs(self, kwargs: Dict):
        """Require ``"env_name"`` among the constructor keyword arguments.

        Raises:
            TypeError: if ``"env_name"`` is missing.
        """
        if "env_name" in kwargs:
            return
        raise TypeError("Expected 'env_name' to be part of kwargs")

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"
class _BraxEnvStep(torch.autograd.Function):
    """Autograd function that steps a brax environment and back-propagates via ``jax.vjp``.

    ``forward`` runs the environment's vmapped/jitted step under ``jax.vjp``
    and keeps the resulting VJP function on ``ctx``; ``backward`` feeds the
    incoming gradients to that function and returns gradients for the action
    and for each pipeline-state tensor.
    """

    @staticmethod
    def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values):
        """Step the environment.

        Args:
            ctx: autograd context; receives ``vjp_fn``, ``next_state`` and ``env``.
            env: the wrapper owning the jitted step and the state example.
            state_td: tensordict holding the current brax state.
            action_tensor: action to apply.
            *qp_values: pipeline-state tensors. They are not read here;
                ``backward`` returns one gradient per entry.

        Returns:
            A tuple ``(next_state, obs, reward, *pipeline_state_values)``
            where ``next_state`` is the full next-state tensordict.
        """
        import jax

        batch_size = env.batch_size

        # torch -> jax conversion, then collapse the batch dimensions
        flat_state = _tree_flatten(
            _tensordict_to_object(state_td, env._state_example), batch_size
        )
        flat_action = _tree_flatten(_tensor_to_ndarray(action_tensor), batch_size)

        # step under vjp so the pullback is available in backward
        flat_next_state, pullback = jax.vjp(
            env._vmap_jit_env_step, flat_state, flat_action
        )

        # restore the batch dimensions and convert back to tensors
        next_state_td = _object_to_tensordict(
            _tree_reshape(flat_next_state, batch_size),
            device=env.device,
            batch_size=batch_size,
        )

        # save context
        ctx.vjp_fn = pullback
        ctx.next_state = next_state_td
        ctx.env = env

        pipeline_outputs = next_state_td["pipeline_state"].values()
        return (
            next_state_td,  # no gradient
            next_state_td["obs"],
            next_state_td["reward"],
            *pipeline_outputs,
        )

    @staticmethod
    def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
        """Propagate gradients through the saved VJP function.

        Returns:
            ``(None, None, grad_action, *grad_pipeline_state)``: no gradient
            for ``env`` or ``state_td``, the action gradient, and one entry
            per pipeline-state tensor. Entries whose incoming gradient was
            ``None`` are returned as ``None``.
        """
        env = ctx.env
        saved_next_state = ctx.next_state
        batch_size = env.batch_size

        # Pair each incoming gradient with its pipeline-state key; missing
        # gradients are replaced by zeros and remembered so they can be
        # turned back into None on the way out.
        none_keys = []
        pipeline_grads = {}
        qp_names = saved_next_state.get("pipeline_state").keys()
        for name, grad in zip(qp_names, grad_next_qp_values):
            if grad is None:
                none_keys.append(name)
                grad = torch.zeros_like(
                    saved_next_state.get(("pipeline_state", name))
                )
            pipeline_grads[name] = grad

        def _zero_grads(entry):
            # Zero gradient for every tensor of an optional sub-tensordict.
            group = saved_next_state.get(entry, None)
            if group is None:
                return {}
            return {k: torch.zeros_like(v) for k, v in group.items()}

        grad_next_state_td = TensorDict(
            source={
                "pipeline_state": pipeline_grads,
                "obs": grad_next_obs,
                "reward": grad_next_reward,
                "done": torch.zeros_like(saved_next_state.get("done")),
                "metrics": _zero_grads("metrics"),
                "info": _zero_grads("info"),
            },
            device=env.device,
            batch_size=batch_size,
        )

        # tensors -> brax state object, with the batch dimensions flattened
        cotangent = _tree_flatten(
            _tensordict_to_object(grad_next_state_td, env._state_example),
            batch_size,
        )

        # call vjp to get gradients
        state_grad, action_grad = ctx.vjp_fn(cotangent)

        # restore the batch dimensions
        state_grad = _tree_reshape(state_grad, batch_size)
        action_grad = _tree_reshape(action_grad, batch_size)

        # convert ndarrays to tensors
        state_grad_qp = _object_to_tensordict(
            state_grad.pipeline_state,
            device=env.device,
            batch_size=batch_size,
        )
        action_grad = _ndarray_to_tensor(action_grad)

        qp_grad_outputs = [
            None if name in none_keys else value
            for name, value in state_grad_qp.items()
        ]
        return (None, None, action_grad, *qp_grad_outputs)


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources