Shortcuts

Source code for torchrl.envs.libs.habitat

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import importlib.util

import torch
from torchrl._utils import _make_ordinal_device
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import _classproperty

# Whether a top-level ``habitat`` module can be located on the import path.
# ``find_spec`` only checks discoverability; it does not import the package,
# so a broken install is only surfaced when ``habitat`` is actually imported.
_has_habitat = bool(importlib.util.find_spec("habitat"))


def _wrap_import_error(fun):
    """Guard ``fun`` so that calling it fails fast when habitat is unavailable.

    The returned wrapper consults the module-level ``_has_habitat`` flag on
    every call. When the flag is falsy it raises :class:`ImportError` without
    running ``fun``. Otherwise it forwards all positional and keyword arguments
    to ``fun`` and returns its result. ``functools.wraps`` copies ``fun``'s
    metadata (name, docstring, ``__wrapped__``) onto the wrapper.
    """

    @functools.wraps(fun)
    def guarded(*args, **kwargs):
        if _has_habitat:
            return fun(*args, **kwargs)
        raise ImportError(
            "Habitat could not be loaded. Consider installing "
            "it or solving the import bugs (see attached error message). "
            "Refer to TorchRL's knowledge base in the documentation to "
            "debug habitat installation."
        )

    return guarded


@_wrap_import_error
def _get_available_envs():
    """Lazily yield the registered gym env names that belong to Habitat.

    Every name in ``GymEnv.available_envs`` starting with ``"Habitat"`` is
    yielded, in the order ``GymEnv.available_envs`` produces them. The
    ``_wrap_import_error`` decorator checks for habitat when this function is
    called, so a missing install raises ``ImportError`` before any iteration.
    """
    yield from (name for name in GymEnv.available_envs if name.startswith("Habitat"))


class HabitatEnv(GymEnv):
    """A wrapper for habitat envs.

    This class currently serves as a placeholder and a compatibility layer.
    It behaves exactly like the GymEnv wrapper, except that it selects the
    simulator GPU through habitat's ``override_options``.

    Doc: https://aihabitat.org/docs/

    GitHub: https://github.com/facebookresearch/habitat-lab

    URL: https://aihabitat.org/habitat3/

    Paper: https://ai.meta.com/static-resource/habitat3

    Args:
        env_name (str): The environment to execute.
        categorical_action_encoding (bool, optional): if ``True``, categorical
            specs will be converted to the TorchRL equivalent
            (:class:`torchrl.data.Categorical`), otherwise a one-hot encoding
            will be used (:class:`torchrl.data.OneHot`). Defaults to ``False``.

    Keyword Args:
        from_pixels (bool, optional): if ``True``, an attempt to return the pixel
            observations from the env will be performed. By default, these
            observations will be written under the ``"pixels"`` entry.
            The method being used varies depending on the gym version and may
            involve a ``wrappers.pixel_observation.PixelObservationWrapper``.
            Defaults to ``False``.
        pixels_only (bool, optional): if ``True``, only the pixel observations
            will be returned (by default under the ``"pixels"`` entry in the
            output tensordict). If ``False``, observations (eg, states) and
            pixels will be returned whenever ``from_pixels=True``.
            Defaults to ``True``.
        frame_skip (int, optional): if provided, indicates for how many steps
            the same action is to be repeated. The observation returned will be
            the last observation of the sequence, whereas the reward will be the
            sum of rewards across steps.
        device (torch.device, optional): if provided, the device on which the
            simulation will occur. Its index is passed to habitat as the
            ``habitat.simulator.habitat_sim_v0.gpu_device_id`` override option.
            Defaults to ``torch.device("cuda:0")``.
        batch_size (torch.Size, optional): the batch size of the environment.
            Should match the leading dimensions of all observations, done states,
            rewards, actions and infos. Defaults to ``torch.Size([])``.
        allow_done_after_reset (bool, optional): if ``True``, it is tolerated for
            envs to be ``done`` just after :meth:`~.reset` is called.
            Defaults to ``False``.

    .. note:: Any ``override_options`` passed by the caller are replaced by the
        options this wrapper sets (GPU id and ``concur_render=False``).

    Attributes:
        available_envs (List[str]): a list of environments to build.

    Examples:
        >>> from torchrl.envs import HabitatEnv
        >>> env = HabitatEnv("HabitatRenderPick-v0", from_pixels=True)
        >>> env.rollout(3)

    """

    # Key of the habitat override option that selects the simulator GPU.
    # Shared by ``__init__`` (which sets it) and ``to`` (which rewrites it).
    _GPU_DEVICE_ID_OPTION = "habitat.simulator.habitat_sim_v0.gpu_device_id"

    @_wrap_import_error
    @set_gym_backend("gym")
    def __init__(self, env_name, **kwargs):
        import habitat  # noqa
        import habitat.gym  # noqa

        # Normalize the device the same way ``to`` does before reading its
        # index: a bare ``torch.device(...).index`` is ``None`` for index-less
        # devices such as "cuda", which used to produce "gpu_device_id=None".
        device = _make_ordinal_device(torch.device(kwargs.pop("device", 0)))
        # NOTE: this replaces any caller-supplied override_options.
        kwargs["override_options"] = [
            f"{self._GPU_DEVICE_ID_OPTION}={device.index}",
            "habitat.simulator.concur_render=False",
        ]
        super().__init__(env_name=env_name, **kwargs)

    @_classproperty
    def available_envs(cls):
        # Without habitat there is nothing to list; avoid the ImportError that
        # _get_available_envs would raise.
        if not _has_habitat:
            return []
        return list(_get_available_envs())

    def _build_gym_env(self, env, pixels_only):
        if self.from_pixels:
            # NOTE(review): presumably habitat must be reset once before pixel
            # rendering can be set up by the parent builder; confirm.
            env.reset()
        return super()._build_gym_env(env, pixels_only)

    def to(self, device: DEVICE_TYPING) -> EnvBase:
        """Rebuild the habitat env on the GPU given by ``device``, then delegate.

        The underlying env is closed and rebuilt with its ``gpu_device_id``
        override option pointing at ``device``'s index. Every other stored
        override option is kept unchanged. The parent ``to`` is then called.

        Args:
            device: the target device. Must resolve to a CUDA device.

        Returns:
            The result of the parent class's ``to(device)``.

        Raises:
            ValueError: if ``device`` is not a CUDA device.
        """
        device = _make_ordinal_device(torch.device(device))
        if device.type != "cuda":
            raise ValueError("The device must be of type cuda for Habitat.")
        gpu_option = f"{self._GPU_DEVICE_ID_OPTION}={device.index}"
        override_options = [
            gpu_option if option.startswith(self._GPU_DEVICE_ID_OPTION) else option
            for option in self._constructor_kwargs.get("override_options", [])
        ]
        self._env.close()
        del self._env
        self.rebuild_with_kwargs(override_options=override_options)
        return super().to(device)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources