
Source code for torchrl.envs.libs.dm_control

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import collections

import importlib
import os
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch

from torchrl._utils import VERBOSE

from torchrl.data.tensor_specs import (
    BoundedTensorSpec,
    CompositeSpec,
    DiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

from torchrl.data.utils import DEVICE_TYPING, numpy_to_torch_dtype_dict
from torchrl.envs.gym_like import GymLikeEnv
from torchrl.envs.utils import _classproperty

if torch.cuda.device_count() > 1:
    n = torch.cuda.device_count() - 1
    os.environ["EGL_DEVICE_ID"] = str(1 + (os.getpid() % n))
    if VERBOSE:
        print("EGL_DEVICE_ID: ", os.environ["EGL_DEVICE_ID"])

_has_dmc = _has_dm_control = importlib.util.find_spec("dm_control") is not None

__all__ = ["DMControlEnv", "DMControlWrapper"]


def _dmcontrol_to_torchrl_spec_transform(
    spec,
    dtype: Optional[torch.dtype] = None,
    device: DEVICE_TYPING = None,
) -> TensorSpec:
    import dm_env

    if isinstance(spec, collections.OrderedDict):
        spec = {
            k: _dmcontrol_to_torchrl_spec_transform(item, device=device)
            for k, item in spec.items()
        }
        return CompositeSpec(**spec)
    elif isinstance(spec, dm_env.specs.BoundedArray):
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        shape = spec.shape
        if not len(shape):
            shape = torch.Size([1])
        return BoundedTensorSpec(
            shape=shape,
            low=spec.minimum,
            high=spec.maximum,
            dtype=dtype,
            device=device,
        )
    elif isinstance(spec, dm_env.specs.Array):
        shape = spec.shape
        if not len(shape):
            shape = torch.Size([1])
        if dtype is None:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        if dtype in (torch.float, torch.double, torch.half):
            return UnboundedContinuousTensorSpec(
                shape=shape, dtype=dtype, device=device
            )
        else:
            return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device)

    else:
        raise NotImplementedError(type(spec))

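
# Illustrative sketch (not part of the library): converting the specs of a
# dm_control suite task with the helper above. Assumes dm_control is installed;
# the function name is hypothetical.
def _example_spec_transform():
    from dm_control import suite

    env = suite.load("cheetah", "run")
    # Observation specs arrive as an OrderedDict of dm_env Arrays and are
    # mapped to a CompositeSpec.
    obs_spec = _dmcontrol_to_torchrl_spec_transform(env.observation_spec())
    # The action spec is a BoundedArray and becomes a BoundedTensorSpec.
    action_spec = _dmcontrol_to_torchrl_spec_transform(env.action_spec())
    return obs_spec, action_spec
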

def _get_envs(to_dict: bool = True) -> Dict[str, Any]:
    if not _has_dm_control:
        raise ImportError("Cannot find dm_control in virtual environment.")
    from dm_control import suite

    if not to_dict:
        return tuple(suite.BENCHMARKING) + tuple(suite.EXTRA)
    d = {}
    for tup in suite.BENCHMARKING:
        env_name = tup[0]
        d.setdefault(env_name, []).append(tup[1])
    for tup in suite.EXTRA:
        env_name = tup[0]
        d.setdefault(env_name, []).append(tup[1])
    return d.items()

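
# Illustrative sketch (not part of the library): enumerating the environment
# names and task lists returned by ``_get_envs``. Assumes dm_control is
# installed; the function name is hypothetical.
def _example_list_envs():
    for env_name, task_names in _get_envs():
        print(env_name, task_names)
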

def _robust_to_tensor(array: Union[float, np.ndarray]) -> torch.Tensor:
    if isinstance(array, np.ndarray):
        return torch.tensor(array.copy())
    else:
        return torch.tensor(array)

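
# Illustrative sketch (not part of the library): ``_robust_to_tensor`` accepts
# either a Python scalar or a numpy array and always returns a torch.Tensor.
# The function name is hypothetical.
def _example_robust_to_tensor():
    scalar = _robust_to_tensor(1.0)  # 0-dimensional tensor
    array = _robust_to_tensor(np.ones((3,)))  # 1-dimensional tensor of shape (3,)
    return scalar, array
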

class DMControlWrapper(GymLikeEnv):
    """DeepMind Control lab environment wrapper.

    Args:
        env (dm_control.suite env): environment instance
        from_pixels (bool): if ``True``, the observation will be returned
            as an image. Default is False.

    Examples:
        >>> env = dm_control.suite.load("cheetah", "run")
        >>> env = DMControlWrapper(env,
        ...     from_pixels=True, frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        >>> print(env.available_envs)
    """

    git_url = "https://github.com/deepmind/dm_control"
    libname = "dm_control"

    @_classproperty
    def available_envs(cls):
        if not _has_dm_control:
            return
        yield from _get_envs()

    @property
    def lib(self):
        import dm_control

        return dm_control

    def __init__(self, env=None, **kwargs):
        if env is not None:
            kwargs["env"] = env
        super().__init__(**kwargs)

    def _build_env(
        self,
        env,
        _seed: Optional[int] = None,
        from_pixels: bool = False,
        render_kwargs: Optional[dict] = None,
        pixels_only: bool = False,
        camera_id: Union[int, str] = 0,
        **kwargs,
    ):
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only

        if from_pixels:
            from dm_control.suite.wrappers import pixels

            self._set_egl_device(self.device)
            self.render_kwargs = {"camera_id": camera_id}
            if render_kwargs is not None:
                self.render_kwargs.update(render_kwargs)
            env = pixels.Wrapper(
                env,
                pixels_only=self.pixels_only,
                render_kwargs=self.render_kwargs,
            )
        return env

    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
        # specs are defined when first called
        self.observation_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.observation_spec(), device=self.device
        )
        reward_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.reward_spec(), device=self.device
        )
        if len(reward_spec.shape) == 0:
            reward_spec.shape = torch.Size([1])
        self.reward_spec = reward_spec
        # populate default done spec
        done_spec = DiscreteTensorSpec(
            n=2, shape=(*self.batch_size, 1), dtype=torch.bool, device=self.device
        )
        self.done_spec = CompositeSpec(
            done=done_spec.clone(),
            truncated=done_spec.clone(),
            terminated=done_spec.clone(),
            device=self.device,
        )
        self.action_spec = _dmcontrol_to_torchrl_spec_transform(
            self._env.action_spec(), device=self.device
        )

    def _check_kwargs(self, kwargs: Dict):
        dm_control = self.lib
        from dm_control.suite.wrappers import pixels

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, (dm_control.rl.control.Environment, pixels.Wrapper)):
            raise TypeError(
                "env is not of type 'dm_control.rl.control.Environment' or `dm_control.suite.wrappers.pixels.Wrapper`."
            )

    def _set_egl_device(self, device: DEVICE_TYPING):
        # Deprecated, as it led to unreliable rendering.
        # The EGL device needs to be set before importing the mujoco bindings: in
        # distributed settings, it'll be easy to tell which cuda device to use.
        # In mp settings, we'll need to use mp.Pool with a specific init function
        # that defines the EGL device before importing libraries. For now, we'll
        # just use a common EGL_DEVICE_ID environment variable for all processes.
        return

    def to(self, device: DEVICE_TYPING) -> DMControlEnv:
        super().to(device)
        self._set_egl_device(self.device)
        return self

    def _init_env(self, seed: Optional[int] = None) -> Optional[int]:
        seed = self.set_seed(seed)
        return seed

    def _set_seed(self, _seed: Optional[int]) -> Optional[int]:
        from dm_control.suite.wrappers import pixels

        if _seed is None:
            return None
        random_state = np.random.RandomState(_seed)
        if isinstance(self._env, pixels.Wrapper):
            if not hasattr(self._env._env.task, "_random"):
                raise RuntimeError("self._env._env.task._random does not exist")
            self._env._env.task._random = random_state
        else:
            if not hasattr(self._env.task, "_random"):
                raise RuntimeError("self._env.task._random does not exist")
            self._env.task._random = random_state
        self.reset()
        return _seed

    def _output_transform(
        self, timestep_tuple: Tuple["TimeStep"]  # noqa: F821
    ) -> Tuple[np.ndarray, float, bool, bool, bool, dict]:
        if type(timestep_tuple) is not tuple:
            timestep_tuple = (timestep_tuple,)
        reward = timestep_tuple[0].reward

        done = truncated = terminated = False  # dm_control envs are non-terminating
        observation = timestep_tuple[0].observation
        info = {}
        return observation, reward, terminated, truncated, done, info

    def _reset_output_transform(self, reset_data):
        (
            observation,
            reward,
            terminated,
            truncated,
            done,
            info,
        ) = self._output_transform(reset_data)
        return observation, info

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
        )

class DMControlEnv(DMControlWrapper):
    """DeepMind Control lab environment wrapper.

    Args:
        env_name (str): name of the environment
        task_name (str): name of the task
        seed (int, optional): seed to use for the environment
        from_pixels (bool, optional): if ``True``, the observation will be
            returned as an image. Default is False.

    Examples:
        >>> env = DMControlEnv(env_name="cheetah", task_name="run",
        ...     from_pixels=True, frame_skip=4)
        >>> td = env.rand_step()
        >>> print(td)
        >>> print(env.available_envs)
    """

    def __init__(self, env_name, task_name, **kwargs):
        if not _has_dmc:
            raise ImportError(
                "dm_control python package was not found. Please install this dependency."
            )
        kwargs["env_name"] = env_name
        kwargs["task_name"] = task_name
        super().__init__(**kwargs)

    def _build_env(
        self,
        env_name: str,
        task_name: str,
        _seed: Optional[int] = None,
        **kwargs,
    ):
        from dm_control import suite

        self.env_name = env_name
        self.task_name = task_name

        from_pixels = kwargs.get("from_pixels")
        if "from_pixels" in kwargs:
            del kwargs["from_pixels"]
        pixels_only = kwargs.get("pixels_only")
        if "pixels_only" in kwargs:
            del kwargs["pixels_only"]

        if not _has_dmc:
            raise ImportError(
                f"dm_control not found, unable to create {env_name}:"
                f" {task_name}. Consider downloading and installing "
                f"dm_control from {self.git_url}"
            )

        if _seed is not None:
            random_state = np.random.RandomState(_seed)
            kwargs = {"random": random_state}
        camera_id = kwargs.pop("camera_id", 0)
        env = suite.load(env_name, task_name, task_kwargs=kwargs)
        return super()._build_env(
            env,
            from_pixels=from_pixels,
            pixels_only=pixels_only,
            camera_id=camera_id,
            **kwargs,
        )

    def rebuild_with_kwargs(self, **new_kwargs):
        self._constructor_kwargs.update(new_kwargs)
        self._env = self._build_env()
        self._make_specs(self._env)

    def _check_kwargs(self, kwargs: Dict):
        if "env_name" in kwargs:
            env_name = kwargs["env_name"]
            if "task_name" in kwargs:
                task_name = kwargs["task_name"]
                available_envs = dict(self.available_envs)
                if (
                    env_name not in available_envs
                    or task_name not in available_envs[env_name]
                ):
                    raise RuntimeError(
                        f"{env_name} with task {task_name} is unknown in {self.libname}"
                    )
            else:
                raise TypeError("dm_control requires task_name to be specified")
        else:
            raise TypeError("dm_control requires env_name to be specified")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(env={self.env_name}, task={self.task_name}, batch_size={self.batch_size})"
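
# Illustrative sketch (not part of the library): building a pixel-based task,
# seeding it and collecting a short rollout. Assumes dm_control is installed;
# the function name is hypothetical.
def _example_dmc_rollout():
    env = DMControlEnv(env_name="cheetah", task_name="run", from_pixels=True)
    env.set_seed(0)
    # ``rollout`` is provided by torchrl's EnvBase and returns a TensorDict of
    # stacked steps.
    return env.rollout(max_steps=4)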
