
Source code for torchrl.envs.transforms.vc1

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib
import os
import subprocess
from functools import partial

import torch
from tensordict import TensorDictBase
from torch import nn
from torchrl._utils import logger as torchrl_logger

from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded
from torchrl.envs.transforms.transforms import (
    CenterCrop,
    Compose,
    ObservationNorm,
    Resize,
    ToTensorImage,
    Transform,
)
from torchrl.envs.transforms.utils import _set_missing_tolerance

_has_vc = importlib.util.find_spec("vc_models") is not None


class VC1Transform(Transform):
    """VC1 Transform class.

    VC-1 provides pre-trained Vision Transformer (ViT) weights aimed at
    facilitating visual embeddings for embodied AI and robotic tasks. The
    models are trained on Ego4D and other egocentric and image datasets.

    See the paper:
        Where are we in the search for an Artificial Visual Cortex for
        Embodied Intelligence? (Majumdar et al., 2023)
        https://arxiv.org/abs/2303.18240

    The underlying model and its preprocessing pipeline are built when the
    transform is constructed (through :obj:`_init()`), which requires the
    ``vc_models`` package to be installed. If it is missing, it can be
    installed with :meth:`VC1Transform.install_vc_models`.

    Examples:
        >>> transform = VC1Transform(
        ...     in_keys=["pixels"], out_keys=["VC1_vec"], model_name="default"
        ... )
        >>> env.append_transform(transform)
        >>> td = env.reset()

    Args:
        in_keys (list of NestedKeys): list of input keys, typically
            ``["pixels"]``.
        out_keys (list of NestedKeys): list of output keys, typically
            ``["VC1_vec"]``.
        model_name (str): One of ``"large"``, ``"base"`` or any other
            compatible model name (see the `github repo
            <https://github.com/facebookresearch/eai-vc>`_ for more info).
            ``"default"`` provides a small, untrained model for testing.
        del_keys (bool, optional): If ``True`` (default), the input keys will
            be discarded from the returned tensordict.
    """

    inplace = False
    IMPORT_ERROR = (
        "Could not load vc_models. You can install it via "
        "VC1Transform.install_vc_models()."
    )

    def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
        if model_name == "default":
            self.make_noload_model()
            model_name = "vc1_vitb_noload"
        self.model_name = model_name
        self.del_keys = del_keys
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self._init()

    def _init(self):
        try:
            from vc_models.models.vit import model_utils
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(self.IMPORT_ERROR) from err

        if self.model_name == "base":
            model_name = model_utils.VC1_BASE_NAME
        elif self.model_name == "large":
            model_name = model_utils.VC1_LARGE_NAME
        else:
            model_name = self.model_name
        model, embd_size, model_transforms, model_info = model_utils.load_model(
            model_name
        )
        self.model = model
        self.embd_size = embd_size
        self.model_transforms = self._map_tv_to_torchrl(model_transforms)

    def _map_tv_to_torchrl(
        self,
        model_transforms,
        in_keys=None,
    ):
        if in_keys is None:
            in_keys = self.in_keys
        from torchvision import transforms

        if isinstance(model_transforms, transforms.Resize):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size, size)
            return Resize(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.CenterCrop):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size,)
            return CenterCrop(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Normalize):
            return ObservationNorm(
                in_keys=in_keys,
                loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1),
                scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1),
                standard_normal=True,
            )
        elif isinstance(model_transforms, transforms.ToTensor):
            return ToTensorImage(
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Compose):
            # ToTensor must come first: the remaining transforms expect
            # tensor (not PIL) inputs.
            transform_list = []
            for t in model_transforms.transforms:
                if isinstance(t, transforms.ToTensor):
                    transform_list.insert(0, t)
                else:
                    transform_list.append(t)
            if len(transform_list) == 0:
                raise RuntimeError("Did not find any transform.")
            for i, t in enumerate(transform_list):
                transform_list[i] = self._map_tv_to_torchrl(t)
            return Compose(*transform_list)
        else:
            raise NotImplementedError(type(model_transforms))

    def _call(self, next_tensordict):
        if not self.del_keys:
            in_keys = [
                in_key
                for in_key, out_key in zip(self.in_keys, self.out_keys)
                if in_key != out_key
            ]
            saved_td = next_tensordict.select(*in_keys)
        with next_tensordict.view(-1) as tensordict_view:
            super()._call(self.model_transforms(tensordict_view))
        if self.del_keys:
            next_tensordict.exclude(*self.in_keys, inplace=True)
        else:
            # restore the in_keys
            next_tensordict.update(saved_td)
        return next_tensordict

    forward = _call

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # TODO: Check this makes sense
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @torch.no_grad()
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        # Flatten any leading batch dims so the model sees a 4D (N, C, H, W)
        # batch, then restore them on the output embeddings.
        shape = None
        if obs.ndimension() > 4:
            shape = obs.shape[:-3]
            obs = obs.flatten(0, -4)
        out = self.model(obs)
        if shape is not None:
            out = out.view(*shape, *out.shape[1:])
        return out
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        if not isinstance(observation_spec, Composite):
            raise ValueError("VC1Transform can only infer Composite")

        keys = [
            key for key in observation_spec.keys(True, True) if key in self.in_keys
        ]
        device = observation_spec[keys[0]].device
        dim = observation_spec[keys[0]].shape[:-3]

        observation_spec = observation_spec.clone()
        if self.del_keys:
            for in_key in keys:
                del observation_spec[in_key]

        for out_key in self.out_keys:
            observation_spec[out_key] = Unbounded(
                shape=torch.Size([*dim, self.embd_size]), device=device
            )

        return observation_spec
    def to(self, dest: DEVICE_TYPING | torch.dtype):
        if isinstance(dest, torch.dtype):
            self._dtype = dest
        else:
            self._device = dest
        return super().to(dest)
    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    @classmethod
    def install_vc_models(cls, auto_exit=False):
        try:
            from vc_models import models  # noqa: F401

            torchrl_logger.info("vc_models found, no need to install.")
        except ModuleNotFoundError:
            # Clone the eai-vc repo under ~/.cache/torchrl and install
            # vc_models in develop mode.
            HOME = os.environ.get("HOME")
            vcdir = HOME + "/.cache/torchrl/eai-vc"
            parentdir = os.path.dirname(os.path.abspath(vcdir))
            os.makedirs(parentdir, exist_ok=True)
            try:
                from git import Repo
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    "Could not load git. Make sure that `git` has been installed "
                    "in your virtual environment."
                ) from err
            Repo.clone_from("https://github.com/facebookresearch/eai-vc.git", vcdir)
            os.chdir(vcdir + "/vc_models")
            subprocess.call(["python", "setup.py", "develop"])
            if not auto_exit:
                input(
                    "VC1 has been successfully installed. Exit this python run and "
                    "relaunch it again. Press Enter to exit..."
                )
            exit()
    @classmethod
    def make_noload_model(cls):
        """Creates a naive (untrained) model config at a custom destination."""
        import vc_models

        models_filepath = os.path.dirname(os.path.abspath(vc_models.__file__))
        cfg_path = os.path.join(
            models_filepath, "conf", "model", "vc1_vitb_noload.yaml"
        )
        if os.path.exists(cfg_path):
            return
        config = """_target_: vc_models.models.load_model
model:
  _target_: vc_models.models.vit.vit.load_mae_encoder
  checkpoint_path:
  model:
    _target_: torchrl.envs.transforms.vc1._vit_base_patch16
    img_size: 224
    use_cls: True
    drop_path_rate: 0.0
transform:
  _target_: vc_models.transforms.vit_transforms
metadata:
  algo: mae
  model: vit_base_patch16
  data:
    - ego
    - imagenet
    - inav
  comment: 182_epochs
"""
        with open(cfg_path, "w") as file:
            file.write(config)
def _vit_base_patch16(**kwargs):
    # Builds the small, untrained ViT referenced by the "vc1_vitb_noload"
    # config written by VC1Transform.make_noload_model (for testing only).
    from vc_models.models.vit.vit import VisionTransformer

    model = VisionTransformer(
        patch_size=16,
        embed_dim=16,
        depth=4,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model
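
Usage sketch (illustrative, not part of the module source): the snippet below shows one way the transform can be attached to a pixel-based environment. It assumes ``vc_models`` is installed and that a Gym environment rendering pixels is available; ``GymEnv("CartPole-v1", from_pixels=True)`` is only a placeholder, and ``model_name="default"`` builds the small, untrained ``vc1_vitb_noload`` model, so the resulting embeddings are suitable for smoke testing rather than real feature extraction.

    >>> from torchrl.envs import TransformedEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.envs.transforms import VC1Transform
    >>> base_env = GymEnv("CartPole-v1", from_pixels=True)  # placeholder pixel env
    >>> transform = VC1Transform(
    ...     in_keys=["pixels"], out_keys=["VC1_vec"], model_name="default"
    ... )
    >>> env = TransformedEnv(base_env, transform)
    >>> td = env.reset()  # "pixels" is replaced by a flat "VC1_vec" embedding
    >>> td["VC1_vec"].shape  # torch.Size([transform.embd_size])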
