Shortcuts

Source code for torchrl.record.recorder

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from copy import copy
from typing import Optional, Sequence

import torch

from tensordict.tensordict import TensorDictBase

from tensordict.utils import NestedKey

from torchrl.envs.transforms import ObservationTransform, Transform
from torchrl.record.loggers import Logger

try:
    from torchvision.transforms.functional import center_crop as center_crop_fn
    from torchvision.utils import make_grid
except ImportError:
    # torchvision is optional. Bind BOTH names to None so availability can be
    # checked with ``is None`` (as VideoRecorder does for center_crop_fn and
    # make_grid) instead of failing later with a NameError.
    center_crop_fn = None
    make_grid = None


class VideoRecorder(ObservationTransform):
    """Video Recorder transform.

    Will record a series of observations from an environment and write them
    to a Logger object when needed.

    Args:
        logger (Logger): a Logger instance where the video should be written.
        tag (str): the video tag in the logger.
        in_keys (Sequence of NestedKey, optional): keys to be read to produce
            the video. Default is :obj:`"pixels"`.
        skip (int): frame interval in the output video. Default is 2.
        center_crop (int, optional): value of square center crop.
        make_grid (bool, optional): if ``True``, a grid is created assuming
            that a tensor of shape [B x W x H x 3] is provided, with B being
            the batch size. Default is True.
        out_keys (sequence of NestedKey, optional): destination keys.
            Defaults to ``in_keys`` if not provided.
        **kwargs: extra keyword arguments forwarded to ``logger.log_video``
            (``fps`` defaults to 6 unless overridden here).
    """

    def __init__(
        self,
        logger: Logger,
        tag: str,
        in_keys: Optional[Sequence[NestedKey]] = None,
        skip: int = 2,
        center_crop: Optional[int] = None,
        make_grid: bool = True,
        out_keys: Optional[Sequence[NestedKey]] = None,
        **kwargs,
    ) -> None:
        if in_keys is None:
            in_keys = ["pixels"]
        if out_keys is None:
            out_keys = copy(in_keys)
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        video_kwargs = {"fps": 6}
        video_kwargs.update(kwargs)
        self.video_kwargs = video_kwargs
        self.iter = 0  # number of dumps so far; used as the logging step
        self.skip = skip
        self.logger = logger
        self.tag = tag
        self.count = 0  # observations seen since the last dump
        self.center_crop = center_crop
        self.make_grid = make_grid
        if center_crop and not center_crop_fn:
            raise ImportError(
                "Could not load center_crop from torchvision. Make sure torchvision is installed."
            )
        self.obs = []  # buffered frames, written out by ``dump``

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
        """Buffers every ``skip``-th observation as a frame and returns the input unchanged.

        Raises:
            RuntimeError: if the observation is neither 2-D nor has a last
                dimension of size 3.
            ImportError: if a torchvision helper is needed but unavailable.
        """
        # Accept a channel-last 3-channel image or a 2-D (single-channel) image.
        if not (observation.shape[-1] == 3 or observation.ndimension() == 2):
            raise RuntimeError(f"Invalid observation shape, got: {observation.shape}")
        self.count += 1
        if self.count % self.skip == 0:
            # Clone only the frames that are kept, and use the clone on every
            # path, so a stored frame never aliases the caller's tensor
            # (``.to(torch.uint8)`` below is a no-op for uint8 inputs).
            frame = observation.clone()
            if frame.ndimension() == 2:
                # (H, W) -> (1, H, W)
                frame = frame.unsqueeze(-3)
            else:
                # Move the channel dim in front of the two spatial dims:
                # [..., H, W, C] -> [..., C, H, W]
                leading_dims = range(frame.ndimension() - 3)
                frame = frame.permute(*leading_dims, -1, -3, -2)
            if self.center_crop:
                if center_crop_fn is None:
                    raise ImportError(
                        "Could not import torchvision, `center_crop` not available. "
                        "Make sure torchvision is installed in your environment."
                    )
                frame = center_crop_fn(frame, [self.center_crop, self.center_crop])
            if self.make_grid and frame.ndimension() == 4:
                if make_grid is None:
                    raise ImportError(
                        "Could not import torchvision, `make_grid` not available. "
                        "Make sure torchvision is installed in your environment."
                    )
                # Tile the leading (batch) dim into a single image.
                frame = make_grid(frame)
            # NOTE(review): plain dtype cast, no rescaling -- assumes pixel
            # values are already in the 0-255 range.
            self.obs.append(frame.to(torch.uint8))
        return observation

    def dump(self, suffix: Optional[str] = None) -> None:
        """Writes the video to the self.logger attribute.

        The frame buffer is cleared, ``iter`` is incremented and ``count`` is
        reset. Nothing is logged if no frame was buffered or if ``logger`` is
        None.

        Args:
            suffix (str, optional): a suffix for the video to be recorded
        """
        if suffix is None:
            tag = self.tag
        else:
            tag = "_".join([self.tag, suffix])
        # Swap the buffer out first so the recorder stays usable (with an
        # empty buffer) even if stacking or logging raises.
        frames, self.obs = self.obs, []
        if frames and self.logger is not None:
            # Stack along a new time dim, then add a leading batch dim of 1.
            video = torch.stack(frames, 0).unsqueeze(0).cpu()
            self.logger.log_video(
                name=tag,
                video=video,
                step=self.iter,
                **self.video_kwargs,
            )
        self.iter += 1
        self.count = 0

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # Reset observations go through the same recording path as step ones.
        self._call(tensordict_reset)
        return tensordict_reset
class TensorDictRecorder(Transform):
    """TensorDict recorder.

    When the 'dump' method is called, this class will save a stack of the
    tensordict resulting from :obj:`env.step(td)` in a file with a prefix
    defined by the out_file_base argument.

    Args:
        out_file_base (str): a string defining the prefix of the file where
            the tensordict will be written.
        skip_reset (bool): if ``True``, the first TensorDict of the list will
            be discarded (usually the tensordict resulting from the call to
            :obj:`env.reset()`; note that with ``skip > 1`` the first recorded
            entry may instead be a step tensordict). default: True
        skip (int): frame interval for the saved tensordict. default: 4
        in_keys (sequence of str, optional): if provided, only these keys are
            recorded; otherwise the whole tensordict is recorded.
    """

    def __init__(
        self,
        out_file_base: str,
        skip_reset: bool = True,
        skip: int = 4,
        in_keys: Optional[Sequence[str]] = None,
    ) -> None:
        if in_keys is None:
            in_keys = []
        super().__init__(in_keys=in_keys)
        self.iter = 0  # number of dumps so far
        self.out_file_base = out_file_base
        self.td = []  # recorded tensordict snapshots
        self.skip_reset = skip_reset
        self.skip = skip
        self.count = 0  # calls seen since the last dump

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Records every ``skip``-th tensordict and returns the input unchanged."""
        self.count += 1
        if self.count % self.skip == 0:
            recorded = tensordict
            if self.in_keys:
                recorded = recorded.select(*self.in_keys)
            # Always store a snapshot (as the ``in_keys`` path already did) so
            # later in-place updates to ``tensordict`` cannot rewrite history.
            self.td.append(recorded.to_tensordict())
        return tensordict

    def dump(self, suffix: Optional[str] = None) -> None:
        """Saves the recorded tensordicts, stacked along dim 0.

        The file is ``f"{out_file_base}_tensordict.t"``, or
        ``f"{out_file_base}_{suffix}_tensordict.t"`` when ``suffix`` is given.
        Nothing is written if no tensordict remains after ``skip_reset``.

        Args:
            suffix (str, optional): a suffix appended to the file prefix.
        """
        # The prefix comes from ``out_file_base``; the previous code read a
        # ``self.tag`` attribute that was never set.
        if suffix is None:
            tag = self.out_file_base
        else:
            tag = "_".join([self.out_file_base, suffix])
        td = self.td
        if self.skip_reset:
            td = td[1:]
        if td:
            torch.save(
                torch.stack(td, 0).contiguous(),
                f"{tag}_tensordict.t",
            )
        self.iter += 1
        self.count = 0
        self.td = []

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # Reset tensordicts go through the same recording path as step ones.
        self._call(tensordict_reset)
        return tensordict_reset

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources