PixelRenderTransform

torchrl.record.PixelRenderTransform(out_keys: List[NestedKey] = None, preproc: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor] = None, as_non_tensor: bool = None, render_method: str = 'render', **kwargs) -> None

A transform to call render on the parent environment and register the pixel observation in the tensordict.

This transform offers an alternative to the from_pixels syntactic sugar for cases where rendering is expensive for the environment being instantiated, or where from_pixels is not implemented. It can be used within a single environment or over batched environments alike.

Parameters:
  • out_keys (List[NestedKey] or NestedKey) – List of keys where to register the pixel observations.

  • preproc (Callable, optional) – a preprocessing function applied to the rendered output. Can be used to reshape the observation, or apply any other transformation that makes it possible to register it in the output data.

  • as_non_tensor (bool, optional) – if True, the data will be written as a NonTensorData thereby relaxing the shape requirements. If not provided, it will be inferred automatically from the input data type and shape.

  • render_method (str, optional) – the name of the render method. Defaults to "render".

  • **kwargs – additional keyword arguments to pass to the render function (e.g. mode="rgb_array").
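
As an illustration of the preproc parameter, here is a minimal sketch of a preprocessing callable. It assumes that render() returns an H x W x 3 or H x W x 4 NumPy array; the function name to_rgb_uint8 and the [0, 1] scaling of float frames are assumptions made for this example, not part of TorchRL.

```python
import numpy as np


def to_rgb_uint8(frame: np.ndarray) -> np.ndarray:
    """Hypothetical preproc: return a contiguous H x W x 3 uint8 copy of a frame."""
    frame = np.asarray(frame)
    if frame.ndim != 3 or frame.shape[-1] not in (3, 4):
        raise ValueError(
            f"expected an H x W x 3 or H x W x 4 frame, got shape {frame.shape}"
        )
    # Drop the alpha channel if the renderer returned RGBA.
    rgb = frame[..., :3]
    if np.issubdtype(rgb.dtype, np.floating):
        # Assumption: float frames are scaled to [0, 1].
        rgb = np.clip(rgb, 0.0, 1.0) * 255.0
    # Copy into a fresh C-contiguous buffer so the result does not share
    # memory (or negative strides) with the renderer's internal array.
    return np.ascontiguousarray(rgb).astype(np.uint8)
```

Such a function would then be passed as PixelRenderTransform(preproc=to_rgb_uint8). Copying the frame is the same idea as the preproc=lambda x: x.copy() used in the VMAS example further down.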

Examples

>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> def make_env():
...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
...     env = env.append_transform(PixelRenderTransform())
...     return env
>>>
>>> if __name__ == "__main__":
...     logger = CSVLogger("dummy", video_format="mp4")
...
...     env = ParallelEnv(4, EnvCreator(make_env))
...
...     env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
...     env.rollout(3)
...
...     check_env_specs(env)
...
...     r = env.rollout(30)
...     print(env)
...     env.transform.dump()
...     env.close()

This transform can also be used when the render() method of a batched environment returns a single image:

Examples

>>> from torchrl.envs import check_env_specs
>>> from torchrl.envs.libs.vmas import VmasEnv
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> env = VmasEnv(
...     scenario="flocking",
...     num_envs=32,
...     continuous_actions=True,
...     max_steps=200,
...     device="cpu",
...     seed=None,
...     # Scenario kwargs
...     n_agents=5,
... )
>>>
>>> logger = CSVLogger("dummy", video_format="mp4")
>>>
>>> env = env.append_transform(PixelRenderTransform(mode="rgb_array", preproc=lambda x: x.copy()))
>>> env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
>>>
>>> check_env_specs(env)
>>>
>>> r = env.rollout(30)
>>> env.transform[-1].dump()

The transform can be disabled using the switch() method, which turns rendering on if it is off and off if it is on (an argument can also be passed to set the state explicitly). Since transforms are Module instances, apply() can be used to toggle every PixelRenderTransform in an environment:

>>> def switch(module):
...     if isinstance(module, PixelRenderTransform):
...         module.switch()
>>> env.apply(switch)
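
The toggle semantics described above can be sketched with a small stand-in class. This is not TorchRL's implementation: the class name RenderToggle, the enabled attribute, the initial state, and the boolean argument are all assumptions made for illustration.

```python
class RenderToggle:
    """Hypothetical stand-in mimicking the on/off behaviour of switch()."""

    def __init__(self) -> None:
        # Assumption: rendering starts enabled.
        self.enabled = True

    def switch(self, mode=None) -> None:
        # No argument: flip the current state.
        # Boolean argument: set the state explicitly.
        self.enabled = (not self.enabled) if mode is None else bool(mode)


toggle = RenderToggle()
toggle.switch()       # on -> off
toggle.switch(False)  # explicit: stays off
toggle.switch(True)   # explicit: back on
```

Passing an explicit value makes the call idempotent, which is convenient when the same function is applied across many sub-modules whose current state is not known in advance.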
