Shortcuts

Source code for torchrl.envs.libs.dm_control

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import collections
import collections.abc
import importlib
import importlib.util
import os
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch

from torchrl._utils import logger as torchrl_logger, VERBOSE

from torchrl.data.tensor_specs import (
    BoundedTensorSpec,
    CompositeSpec,
    DiscreteTensorSpec,
    OneHotDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
from torchrl.envs.gym_like import GymLikeEnv
from torchrl.envs.utils import _classproperty

# When more than one CUDA device is visible, select an EGL rendering device for
# this process from the range 1..N-1 (device 0 is never picked), derived from the
# process id -- presumably so that concurrent processes spread over the GPUs.
if torch.cuda.device_count() > 1:
    n = torch.cuda.device_count() - 1
    os.environ["EGL_DEVICE_ID"] = f"{os.getpid() % n + 1}"
    if VERBOSE:
        torchrl_logger.info(f"EGL_DEVICE_ID: {os.environ['EGL_DEVICE_ID']}")

# Whether the optional ``dm_control`` package is installed. ``find_spec`` locates
# the package without importing it. ``_has_dmc`` and ``_has_dm_control`` are two
# names for the same flag.
_has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None

__all__ = ["DMControlEnv", "DMControlWrapper"]


def _dmcontrol_to_torchrl_spec_transform(
    spec,
    dtype: Optional[torch.dtype] = None,
    device: DEVICE_TYPING = None,
    categorical_discrete_encoding: bool = False,
) -> TensorSpec:
    """Convert a ``dm_env`` spec (or a mapping of specs) into a TorchRL spec.

    Args:
        spec: a ``dm_env.specs.DiscreteArray``, ``BoundedArray`` or ``Array``,
            or a mapping (e.g. ``dict`` / ``OrderedDict``) from keys to such specs.
        dtype (torch.dtype, optional): dtype of the resulting leaf spec. If
            ``None``, it is looked up from ``spec.dtype``; one-hot discrete specs
            default to ``torch.long`` instead. It is not forwarded to the entries
            of a mapping, so each nested spec infers its own dtype.
        device: device of the resulting spec.
        categorical_discrete_encoding (bool, optional): if ``True``, a
            ``DiscreteArray`` becomes a :class:`DiscreteTensorSpec`; otherwise a
            :class:`OneHotDiscreteTensorSpec`. Defaults to ``False``.

    Returns:
        A :class:`CompositeSpec` for mapping inputs, otherwise a leaf spec.
        Scalar (0-d) array specs are given shape ``[1]``.

    Raises:
        NotImplementedError: if ``spec`` is of an unsupported type.
    """
    import dm_env

    if isinstance(spec, collections.abc.Mapping):
        # ``dtype`` is intentionally not propagated to nested specs.
        return CompositeSpec(
            **{
                key: _dmcontrol_to_torchrl_spec_transform(
                    item,
                    device=device,
                    categorical_discrete_encoding=categorical_discrete_encoding,
                )
                for key, item in spec.items()
            }
        )

    if isinstance(spec, dm_env.specs.DiscreteArray):
        # DiscreteArray is a type of BoundedArray, so this branch must come first.
        if categorical_discrete_encoding:
            if dtype is None:
                dtype = numpy_to_torch_dtype_dict[spec.dtype]
            return DiscreteTensorSpec(spec.num_values, device=device, dtype=dtype)
        if dtype is None:
            dtype = torch.long
        return OneHotDiscreteTensorSpec(spec.num_values, device=device, dtype=dtype)

    if isinstance(spec, dm_env.specs.BoundedArray):
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        shape = spec.shape if len(spec.shape) else torch.Size([1])
        return BoundedTensorSpec(
            shape=shape,
            low=spec.minimum,
            high=spec.maximum,
            dtype=dtype,
            device=device,
        )

    if isinstance(spec, dm_env.specs.Array):
        shape = spec.shape if len(spec.shape) else torch.Size([1])
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        # Any floating dtype (float16/bfloat16/float32/float64, ...) is continuous.
        if dtype.is_floating_point:
            return UnboundedContinuousTensorSpec(
                shape=shape, dtype=dtype, device=device
            )
        return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device)

    raise NotImplementedError(type(spec))


def _get_envs(
    to_dict: bool = True,
) -> Union[Tuple[Tuple[str, str], ...], Iterable[Tuple[str, List[str]]]]:
    """List the environment / task pairs shipped with ``dm_control.suite``.

    Entries from ``suite.BENCHMARKING`` come first, followed by ``suite.EXTRA``.

    Args:
        to_dict (bool, optional): if ``True`` (default), tasks are grouped per
            environment and the ``(env_name, [task_name, ...])`` items of the
            grouping dict are returned. If ``False``, the flat tuple of
            ``(env_name, task_name)`` entries is returned.

    Raises:
        ImportError: if ``dm_control`` is not installed.
    """
    if not _has_dm_control:
        raise ImportError("Cannot find dm_control in virtual environment.")
    from dm_control import suite

    pairs = tuple(suite.BENCHMARKING) + tuple(suite.EXTRA)
    if not to_dict:
        return pairs
    grouped: Dict[str, List[str]] = {}
    for entry in pairs:
        grouped.setdefault(entry[0], []).append(entry[1])
    return grouped.items()


def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor:
    """Turn a scalar or numpy array into a :class:`torch.Tensor`.

    Numpy arrays are copied before conversion, so the returned tensor does not
    share memory with the input array. Any other value is passed straight to
    :func:`torch.as_tensor`.
    """
    if not isinstance(array, np.ndarray):
        return torch.as_tensor(array)
    return torch.as_tensor(array.copy())


class DMControlWrapper(GymLikeEnv):
    """DeepMind Control lab environment wrapper.

    The DeepMind control library can be found here: https://github.com/deepmind/dm_control.

    Paper: https://arxiv.org/abs/2006.12983

    Args:
        env (dm_control.suite env): :class:`~dm_control.suite.base.Task`
            environment instance.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these
            observations will be written under the ``"pixels"`` entry.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output
            tensordict). If ``False``, observations (eg, states) and pixels will be
            returned whenever ``from_pixels=True``. Defaults to ``False``.
        camera_id (int or str, optional): camera used for pixel rendering when
            ``from_pixels=True``. Defaults to ``0``.
        render_kwargs (dict, optional): extra rendering arguments, merged over
            ``{"camera_id": camera_id}`` when ``from_pixels=True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (list): a list of ``Tuple[str, List[str]]`` representing the
            environment / task pairs available.

    Examples:
        >>> from dm_control import suite
        >>> from torchrl.envs import DMControlWrapper
        >>> env = suite.load("cheetah", "run")
        >>> env = DMControlWrapper(env,
        ...    from_pixels=True, frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([240, 320, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        position: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float64, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float64, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        velocity: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.float64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        [('acrobot', ['swingup', 'swingup_sparse']), ('ball_in_cup', ['catch']), ('cartpole', ['balance', 'balance_sparse', 'swingup', 'swingup_sparse', 'three_poles', 'two_poles']), ('cheetah', ['run']), ('finger', ['spin', 'turn_easy', 'turn_hard']), ('fish', ['upright', 'swim']), ('hopper', ['stand', 'hop']), ('humanoid', ['stand', 'walk', 'run', 'run_pure_state']), ('manipulator', ['bring_ball', 'bring_peg', 'insert_ball', 'insert_peg']), ('pendulum', ['swingup']), ('point_mass', ['easy', 'hard']), ('reacher', ['easy', 'hard']), ('swimmer', ['swimmer6', 'swimmer15']), ('walker', ['stand', 'walk', 'run']), ('dog', ['fetch', 'run', 'stand', 'trot', 'walk']), ('humanoid_CMU', ['run', 'stand', 'walk']), ('lqr', ['lqr_2_1', 'lqr_6_2']), ('quadruped', ['escape', 'fetch', 'run', 'walk']), ('stacker', ['stack_2', 'stack_4'])]

    """

    git_url = "https://github.com/deepmind/dm_control"
    libname = "dm_control"

    @_classproperty
    def available_envs(cls):
        # Empty when dm_control is not installed, instead of raising.
        if not _has_dm_control:
            return []
        return list(_get_envs())

    @property
    def lib(self):
        import dm_control

        return dm_control

    def __init__(self, env=None, **kwargs):
        if env is not None:
            kwargs["env"] = env
        super().__init__(**kwargs)

    def _build_env(
        self,
        env,
        _seed: Optional[int] = None,
        from_pixels: bool = False,
        render_kwargs: Optional[dict] = None,
        pixels_only: bool = False,
        camera_id: Union[int, str] = 0,
        **kwargs,
    ):
        """Optionally wrap ``env`` with dm_control's pixel wrapper and return it.

        Extra ``kwargs`` (and ``_seed``) are accepted and ignored here.
        """
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only

        if from_pixels:
            from dm_control.suite.wrappers import pixels

            self._set_egl_device(self.device)
            self.render_kwargs = {"camera_id": camera_id}
            if render_kwargs is not None:
                self.render_kwargs.update(render_kwargs)
            env = pixels.Wrapper(
                env,
                pixels_only=self.pixels_only,
                render_kwargs=self.render_kwargs,
            )
        return env

    def _make_specs(self, env) -> None:
        """Build observation, reward, done and action specs from ``self._env``.

        The specs are read from ``self._env``; the ``env`` argument is accepted
        for interface compatibility only.
        """
        self.observation_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.observation_spec(), device=self.device
        )
        reward_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.reward_spec(), device=self.device
        )
        if len(reward_spec.shape) == 0:
            reward_spec.shape = torch.Size([1])
        self.reward_spec = reward_spec
        # dm_control has no done spec: populate a default one with the
        # done / truncated / terminated entries.
        done_spec = DiscreteTensorSpec(
            n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
        )
        self.done_spec = CompositeSpec(
            done=done_spec.clone(),
            truncated=done_spec.clone(),
            terminated=done_spec.clone(),
            device=self.device,
        )
        self.action_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.action_spec(), device=self.device
        )

    def _check_kwargs(self, kwargs: Dict):
        """Validate that ``kwargs["env"]`` is a dm_control environment.

        Raises:
            TypeError: if ``"env"`` is missing or has an unsupported type.
        """
        # Import the submodule explicitly rather than relying on ``import dm_control``
        # having loaded ``dm_control.rl`` as a side effect.
        from dm_control.rl import control
        from dm_control.suite.wrappers import pixels

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, (control.Environment, pixels.Wrapper)):
            raise TypeError(
                "env is not of type 'dm_control.rl.control.Environment' or `dm_control.suite.wrappers.pixels.Wrapper`."
            )

    def _set_egl_device(self, device: DEVICE_TYPING):
        # Deprecated: setting the EGL device here led to unreliable rendering.
        # The EGL device needs to be set before importing the mujoco bindings. In
        # distributed settings, it is easy to tell which cuda device to use.
        # In mp settings, we'd need to use mp.Pool with a specific init function
        # that defines the EGL device before importing libraries. For now, we
        # just use a common EGL_DEVICE_ID environment variable for all processes.
        return

    def to(self, device: DEVICE_TYPING) -> DMControlWrapper:
        super().to(device)
        self._set_egl_device(self.device)
        return self

    def _init_env(self, seed: Optional[int] = None) -> Optional[int]:
        return self.set_seed(seed)

    def _set_seed(self, _seed: Optional[int]) -> Optional[int]:
        """Seed the task's random state with ``_seed`` and reset the env.

        Returns ``None`` without doing anything when ``_seed`` is ``None``.

        Raises:
            RuntimeError: if the task has no ``_random`` attribute.
        """
        from dm_control.suite.wrappers import pixels

        if _seed is None:
            return None
        random_state = np.random.RandomState(_seed)
        # The pixel wrapper holds the underlying control env in ``_env``.
        if isinstance(self._env, pixels.Wrapper):
            task, task_path = self._env._env.task, "self._env._env.task"
        else:
            task, task_path = self._env.task, "self._env.task"
        if not hasattr(task, "_random"):
            raise RuntimeError(f"{task_path}._random does not exist")
        task._random = random_state
        self.reset()
        return _seed

    def _output_transform(
        self, timestep_tuple: Tuple["TimeStep"]  # noqa: F821
    ) -> Tuple[np.ndarray, float, bool, bool, bool, dict]:
        """Unpack a dm_env timestep.

        Returns:
            ``(observation, reward, terminated, truncated, done, info)``.
        """
        from dm_env import StepType

        # Exact type check on purpose: a dm_env ``TimeStep`` is a NamedTuple (a
        # tuple subclass), and a bare TimeStep must still be wrapped.
        if type(timestep_tuple) is not tuple:
            timestep_tuple = (timestep_tuple,)
        timestep = timestep_tuple[0]
        reward = timestep.reward
        truncated = terminated = False
        if timestep.step_type == StepType.LAST:
            # A last step whose discount is ~1 is reported as truncation
            # (presumably a time-limit cut); otherwise it is a termination.
            if np.isclose(timestep.discount, 1):
                truncated = True
            else:
                terminated = True
        done = truncated or terminated
        observation = timestep.observation
        info = {}
        return observation, reward, terminated, truncated, done, info

    def _reset_output_transform(self, reset_data):
        observation, *_, info = self._output_transform(reset_data)
        return observation, info

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
        )
class DMControlEnv(DMControlWrapper):
    """DeepMind Control lab environment wrapper.

    The DeepMind control library can be found here: https://github.com/deepmind/dm_control.

    Paper: https://arxiv.org/abs/2006.12983

    Args:
        env_name (str): name of the environment.
        task_name (str): name of the task.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these
            observations will be written under the ``"pixels"`` entry.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations will
            be returned (by default under the ``"pixels"`` entry in the output
            tensordict). If ``False``, observations (eg, states) and pixels will be
            returned whenever ``from_pixels=True``. Defaults to ``False``.
        camera_id (int or str, optional): camera used for pixel rendering when
            ``from_pixels=True``. Defaults to ``0``.
        render_kwargs (dict, optional): extra rendering arguments forwarded to the
            pixel wrapper when ``from_pixels=True``.
        frame_skip (int, optional): if provided, indicates for how many steps the
            same action is to be repeated. The observation returned will be the
            last observation of the sequence, whereas the reward will be the sum
            of rewards across steps.
        device (torch.device, optional): if provided, the device on which the data
            is to be cast. Defaults to ``torch.device("cpu")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos.
            Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated
            for envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.

    Attributes:
        available_envs (list): a list of ``Tuple[str, List[str]]`` representing the
            environment / task pairs available.

    Examples:
        >>> from torchrl.envs import DMControlEnv
        >>> env = DMControlEnv(env_name="cheetah", task_name="run",
        ...    from_pixels=True, frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([240, 320, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        position: Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float64, is_shared=False),
                        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float64, is_shared=False),
                        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                        velocity: Tensor(shape=torch.Size([9]), device=cpu, dtype=torch.float64, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)
        >>> print(env.available_envs)
        [('acrobot', ['swingup', 'swingup_sparse']), ('ball_in_cup', ['catch']), ('cartpole', ['balance', 'balance_sparse', 'swingup', 'swingup_sparse', 'three_poles', 'two_poles']), ('cheetah', ['run']), ('finger', ['spin', 'turn_easy', 'turn_hard']), ('fish', ['upright', 'swim']), ('hopper', ['stand', 'hop']), ('humanoid', ['stand', 'walk', 'run', 'run_pure_state']), ('manipulator', ['bring_ball', 'bring_peg', 'insert_ball', 'insert_peg']), ('pendulum', ['swingup']), ('point_mass', ['easy', 'hard']), ('reacher', ['easy', 'hard']), ('swimmer', ['swimmer6', 'swimmer15']), ('walker', ['stand', 'walk', 'run']), ('dog', ['fetch', 'run', 'stand', 'trot', 'walk']), ('humanoid_CMU', ['run', 'stand', 'walk']), ('lqr', ['lqr_2_1', 'lqr_6_2']), ('quadruped', ['escape', 'fetch', 'run', 'walk']), ('stacker', ['stack_2', 'stack_4'])]

    """

    def __init__(self, env_name, task_name, **kwargs):
        if not _has_dmc:
            raise ImportError(
                "dm_control python package was not found. Please install this dependency."
            )
        kwargs["env_name"] = env_name
        kwargs["task_name"] = task_name
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        task_name: str,
        _seed: Optional[int] = None,
        **kwargs,
    ):
        """Load ``env_name``/``task_name`` from ``dm_control.suite`` and wrap it.

        The pixel-rendering options (``from_pixels``, ``pixels_only``,
        ``camera_id``, ``render_kwargs``) go to the parent wrapper. All other
        ``kwargs`` are forwarded to ``suite.load`` as task kwargs. If ``_seed`` is
        given, a ``np.random.RandomState(_seed)`` is added to them as ``random``.
        """
        # Check before importing so that the informative error can actually be raised.
        if not _has_dmc:
            raise ImportError(
                f"dm_control not found, unable to create {env_name}:"
                f" {task_name}. Consider downloading and installing "
                f"dm_control from {self.git_url}"
            )
        from dm_control import suite

        self.env_name = env_name
        self.task_name = task_name

        # Rendering options belong to the pixel wrapper, never to the task constructor.
        from_pixels = kwargs.pop("from_pixels", False)
        pixels_only = kwargs.pop("pixels_only", False)
        camera_id = kwargs.pop("camera_id", 0)
        render_kwargs = kwargs.pop("render_kwargs", None)

        # Keep the user's task kwargs when seeding instead of discarding them.
        task_kwargs = dict(kwargs)
        if _seed is not None:
            task_kwargs["random"] = np.random.RandomState(_seed)
        env = suite.load(env_name, task_name, task_kwargs=task_kwargs)
        return super()._build_env(
            env,
            from_pixels=from_pixels,
            pixels_only=pixels_only,
            camera_id=camera_id,
            render_kwargs=render_kwargs,
            **kwargs,
        )

    def rebuild_with_kwargs(self, **new_kwargs):
        """Rebuild the underlying env with updated constructor kwargs.

        Assumes ``self._constructor_kwargs`` holds the keyword arguments that
        ``_build_env`` was originally called with, including ``env_name`` and
        ``task_name``; it is set outside this class -- TODO confirm.
        """
        self._constructor_kwargs.update(new_kwargs)
        self._env = self._build_env(**self._constructor_kwargs)
        self._make_specs(self._env)

    def _check_kwargs(self, kwargs: Dict):
        """Validate ``env_name`` / ``task_name`` against :attr:`available_envs`.

        Raises:
            TypeError: if ``env_name`` or ``task_name`` is missing.
            RuntimeError: if the pair is not a known dm_control environment/task.
        """
        if "env_name" not in kwargs:
            raise TypeError("dm_control requires env_name to be specified")
        if "task_name" not in kwargs:
            raise TypeError("dm_control requires task_name to be specified")
        env_name = kwargs["env_name"]
        task_name = kwargs["task_name"]
        available_envs = dict(self.available_envs)
        if (
            env_name not in available_envs
            or task_name not in available_envs[env_name]
        ):
            raise RuntimeError(
                f"{env_name} with task {task_name} is unknown in {self.libname}"
            )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, task={self.task_name}, batch_size={self.batch_size})"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources