Shortcuts

Source code for torchrl.trainers.helpers.envs

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from copy import copy
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch

from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.envs import ParallelEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import env_creator, EnvCreator
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
    CatFrames,
    CatTensors,
    CenterCrop,
    Compose,
    DoubleToFloat,
    GrayScale,
    NoopResetEnv,
    ObservationNorm,
    Resize,
    RewardScaling,
    ToTensorImage,
    TransformedEnv,
    VecNorm,
)
from torchrl.envs.transforms.transforms import (
    FlattenObservation,
    gSDENoise,
    InitTracker,
    StepCounter,
)
from torchrl.record.loggers import Logger
from torchrl.record.recorder import VideoRecorder

LIBS = {
    "gym": GymEnv,
    "dm_control": DMControlEnv,
}


def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig":  # noqa: F821
    """Rescale every frame-count setting of ``cfg`` by ``cfg.frame_skip``.

    When an environment repeats each action ``frame_skip`` times, a budget
    expressed in frames would silently be multiplied by ``frame_skip``.
    For example, targeting 1M frames would actually collect
    ``frame_skip * 1M`` frames. This helper floor-divides the relevant
    counters by ``frame_skip`` to avoid that unintended over-sampling.

    Only counters that exist on ``cfg`` are modified. Nothing is done when
    ``frame_skip == 1``.

    Args:
        cfg (DictConfig): config holding ``frame_skip`` and, optionally, any of
            ``max_frames_per_traj``, ``total_frames``, ``frames_per_batch``,
            ``record_frames``, ``annealing_frames``, ``init_random_frames``,
            ``init_env_steps`` and ``noops``.

    Returns:
        the same ``cfg`` object, modified in place.
    """
    frame_skip = cfg.frame_skip
    if frame_skip == 1:
        return cfg

    frame_counters = (
        "max_frames_per_traj",
        "total_frames",
        "frames_per_batch",
        "record_frames",
        "annealing_frames",
        "init_random_frames",
        "init_env_steps",
        "noops",
    )
    for name in frame_counters:
        if not hasattr(cfg, name):
            continue
        setattr(cfg, name, getattr(cfg, name) // frame_skip)
    return cfg
def _observation_norm_kwargs(stats, obs_norm_state_dict):
    """Build the ObservationNorm kwargs: ``stats`` (preferred) or the state dict, plus ``standard_normal=True``."""
    if stats is None and obs_norm_state_dict is None:
        kwargs = {}
    elif stats is None:
        kwargs = copy(obs_norm_state_dict)
    else:
        kwargs = copy(stats)
    kwargs["standard_normal"] = True
    return kwargs


def make_env_transforms(
    env,
    cfg,
    video_tag,
    logger,
    env_name,
    stats,
    norm_obs_only,
    env_library,
    action_dim_gsde,
    state_dim_gsde,
    batch_dims=0,
    obs_norm_state_dict=None,
):
    """Creates the typical transforms for an env.

    Wraps ``env`` in a :class:`TransformedEnv` and appends, in order:

    * a :class:`VideoRecorder` if ``video_tag`` is non-empty;
    * for pixel-based envs (``cfg.from_pixels``), the pixel pipeline:
      ToTensorImage, optional CenterCrop, Resize, optional GrayScale,
      FlattenObservation, CatFrames and an ObservationNorm on ``"pixels"``;
    * a RewardScaling transform when a reward scale is set;
    * for state-based envs, a CatTensors transform that concatenates all
      non-pixel, non-state observation keys into ``"observation_vector"``;
      then either an ObservationNorm (when ``cfg.vecnorm`` is False) or a
      VecNorm; then DoubleToFloat and an optional CatFrames;
    * an optional gSDENoise (``cfg.gSDE``), then StepCounter and InitTracker.

    Args:
        env: the environment to transform.
        cfg: config object. Reads ``from_pixels``, ``vecnorm``,
            ``norm_rewards``, ``reward_scaling``, ``reward_loc``,
            ``center_crop``, ``catframes``, ``image_size``, ``grayscale`` and,
            if present, ``gSDE``.
        video_tag (str): if non-empty, a VideoRecorder with tag
            ``f"{video_tag}_{env_name}_video"`` is appended.
        logger: logger handed to the VideoRecorder.
        env_name (str): environment name, used in the video tag.
        stats (dict, optional): kwargs (e.g. ``loc``/``scale``) for the
            ObservationNorm. Takes precedence over ``obs_norm_state_dict``.
        norm_obs_only (bool): if True, the reward scaling is reset to the
            identity and VecNorm only normalizes observations.
        env_library: not used by this function; kept for signature
            compatibility.
        action_dim_gsde: action dimension forwarded to gSDENoise.
        state_dim_gsde: state dimension forwarded to gSDENoise.
        batch_dims (int): not used by this function; kept for signature
            compatibility.
        obs_norm_state_dict (dict, optional): fallback kwargs for the
            ObservationNorm, used when ``stats`` is None.

    Returns:
        the transformed environment.

    Raises:
        RuntimeError: if pixels are used but ``cfg.catframes`` is falsy.
    """
    env = TransformedEnv(env)

    from_pixels = cfg.from_pixels
    vecnorm = cfg.vecnorm
    norm_rewards = vecnorm and cfg.norm_rewards
    # VecNorm normalizes the reward only if reward normalization was requested
    # and the caller did not restrict normalization to the observations.
    _norm_obs_only = norm_obs_only or not norm_rewards
    reward_scaling = cfg.reward_scaling
    reward_loc = cfg.reward_loc

    if len(video_tag):
        center_crop = cfg.center_crop
        if center_crop:
            center_crop = center_crop[0]
        env.append_transform(
            VideoRecorder(
                logger=logger,
                tag=f"{video_tag}_{env_name}_video",
                center_crop=center_crop,
            ),
        )

    if from_pixels:
        if not cfg.catframes:
            raise RuntimeError(
                "this env builder currently only accepts positive catframes values "
                "when pixels are being used."
            )
        env.append_transform(ToTensorImage())
        if cfg.center_crop:
            env.append_transform(CenterCrop(*cfg.center_crop))
        env.append_transform(Resize(cfg.image_size, cfg.image_size))
        if cfg.grayscale:
            env.append_transform(GrayScale())
        env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
        env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
        obs_norm = ObservationNorm(
            **_observation_norm_kwargs(stats, obs_norm_state_dict),
            in_keys=["pixels"],
        )
        env.append_transform(obs_norm)

    # When the reward is normalized online by VecNorm, or when only the
    # observations must be normalized, the reward scaling becomes the identity.
    if norm_rewards or norm_obs_only:
        reward_scaling = 1.0
        reward_loc = 0.0
    if reward_scaling is not None:
        env.append_transform(RewardScaling(reward_loc, reward_scaling))

    if not from_pixels:
        selected_keys = [
            key
            for key in env.observation_spec.keys(True, True)
            if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
        ]

        # even if there is a single tensor, it'll be renamed in "observation_vector"
        out_key = "observation_vector"
        env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

        if not vecnorm:
            obs_norm = ObservationNorm(
                **_observation_norm_kwargs(stats, obs_norm_state_dict),
                in_keys=[out_key],
            )
            env.append_transform(obs_norm)
        else:
            env.append_transform(
                VecNorm(
                    in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
                    decay=0.9999,
                )
            )

        env.append_transform(DoubleToFloat())

        if hasattr(cfg, "catframes") and cfg.catframes:
            env.append_transform(CatFrames(N=cfg.catframes, in_keys=[out_key], dim=-1))
    else:
        env.append_transform(DoubleToFloat())

    if hasattr(cfg, "gSDE") and cfg.gSDE:
        env.append_transform(
            gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde)
        )

    env.append_transform(StepCounter())
    env.append_transform(InitTracker())

    return env


def get_norm_state_dict(env):
    """Gets the normalization loc and scale from the env state_dict.

    Returns a new dict holding only the ``env.state_dict()`` entries whose key
    ends with ``"loc"`` or ``"scale"``.
    """
    return {
        key: val
        for key, val in env.state_dict().items()
        if key.endswith(("loc", "scale"))
    }
def transformed_env_constructor(
    cfg: "DictConfig",  # noqa: F821
    video_tag: str = "",
    logger: Optional[Logger] = None,
    stats: Optional[dict] = None,
    norm_obs_only: bool = False,
    use_env_creator: bool = False,
    custom_env_maker: Optional[Callable] = None,
    custom_env: Optional[EnvBase] = None,
    return_transformed_envs: bool = True,
    action_dim_gsde: Optional[int] = None,
    state_dim_gsde: Optional[int] = None,
    batch_dims: Optional[int] = 0,
    obs_norm_state_dict: Optional[dict] = None,
) -> Union[Callable, EnvCreator]:
    """Returns an environment creator from a config built with the appropriate parser constructor.

    Args:
        cfg (DictConfig): a DictConfig containing the arguments of the script.
        video_tag (str, optional): video tag to be passed to the Logger object.
            When non-empty, the env is also built with pixel rendering enabled.
        logger (Logger, optional): logger associated with the script
        stats (dict, optional): a dictionary containing the :obj:`loc` and
            :obj:`scale` for the `ObservationNorm` transform
        norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the
            reward won't be normalized online. Default is `False`.
        use_env_creator (bool, optional): whether the `EnvCreator` class should
            be used. By using `EnvCreator`, one can make sure that running
            statistics will be put in shared memory and accessible for all
            workers when using a `VecNorm` transform. Default is `False`.
        custom_env_maker (callable, optional): if your env maker is not part of
            torchrl env wrappers, a custom callable can be passed instead. In
            this case it will override the constructor retrieved from `cfg`.
        custom_env (EnvBase, optional): if an existing environment needs to be
            transformed, it can be passed directly to this helper.
            `custom_env_maker` and `custom_env` are exclusive features.
        return_transformed_envs (bool, optional): if ``True``, a transformed
            environment is returned. Otherwise the env is returned without the
            transforms of :func:`make_env_transforms` (it may still be wrapped
            in a ``NoopResetEnv`` when ``cfg.noops`` is set).
        action_dim_gsde (int, Optional): if gSDE is used, this can present the
            action dim to initialize the noise. Make sure this is indicated in
            environment executed in parallel.
        state_dim_gsde (int, Optional): if gSDE is used, this can present the
            state dim to initialize the noise. Make sure this is indicated in
            environment executed in parallel.
        batch_dims (int, optional): number of dimensions of a batch of data. If
            a single env is used, it should be 0 (default). If multiple envs are
            being transformed in parallel, it should be set to 1 (or the number
            of dims of the batch).
        obs_norm_state_dict (dict, optional): the state_dict of the
            ObservationNorm transform to be loaded into the environment

    Returns:
        a callable ``make_transformed_env(**kwargs)`` that builds the
        environment (``kwargs`` are forwarded to the library constructor or to
        ``custom_env_maker``). It is wrapped with :func:`env_creator` when
        ``use_env_creator`` is ``True``.

    Raises (when the returned callable is invoked):
        ValueError: if ``cfg.collector_device`` is neither a string nor a
            sequence.
        NotImplementedError: if ``categorical_action_encoding`` is requested
            with a library other than gym.
        RuntimeError: if both ``custom_env`` and ``custom_env_maker`` are given.
    """

    def make_transformed_env(**kwargs) -> TransformedEnv:
        env_name = cfg.env_name
        env_task = cfg.env_task
        env_library = LIBS[cfg.env_library]
        frame_skip = cfg.frame_skip
        from_pixels = cfg.from_pixels
        categorical_action_encoding = cfg.categorical_action_encoding

        if custom_env is None and custom_env_maker is None:
            if isinstance(cfg.collector_device, str):
                device = cfg.collector_device
            elif isinstance(cfg.collector_device, Sequence):
                # several collector devices: the env is built on the first one
                device = cfg.collector_device[0]
            else:
                raise ValueError(
                    "collector_device must be either a string or a sequence of strings"
                )
            env_kwargs = {
                "env_name": env_name,
                "device": device,
                "frame_skip": frame_skip,
                # pixels are rendered when requested or when a video is recorded
                "from_pixels": from_pixels or bool(video_tag),
                "pixels_only": from_pixels,
            }
            if env_library is GymEnv:
                env_kwargs.update(
                    {"categorical_action_encoding": categorical_action_encoding}
                )
            elif categorical_action_encoding:
                raise NotImplementedError(
                    "categorical_action_encoding=True is currently only compatible with GymEnvs."
                )
            if env_library is DMControlEnv:
                env_kwargs.update({"task_name": env_task})
            env_kwargs.update(kwargs)
            env = env_library(**env_kwargs)
        elif custom_env is None and custom_env_maker is not None:
            env = custom_env_maker(**kwargs)
        elif custom_env_maker is None and custom_env is not None:
            env = custom_env
        else:
            raise RuntimeError("cannot provide both custom_env and custom_env_maker")

        if cfg.noops and custom_env is None:
            # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv
            # that already has its NoopResetEnv set for the contained envs.
            # There is a risk however that we're just skipping the NoopsReset instantiation
            env = TransformedEnv(env, NoopResetEnv(cfg.noops))
        if not return_transformed_envs:
            return env

        return make_env_transforms(
            env,
            cfg,
            video_tag,
            logger,
            env_name,
            stats,
            norm_obs_only,
            env_library,
            action_dim_gsde,
            state_dim_gsde,
            batch_dims=batch_dims,
            obs_norm_state_dict=obs_norm_state_dict,
        )

    if use_env_creator:
        return env_creator(make_transformed_env)
    return make_transformed_env
def parallel_env_constructor(
    cfg: "DictConfig", **kwargs  # noqa: F821
) -> Union[ParallelEnv, EnvCreator]:
    """Build the (possibly parallel) collection environment described by ``cfg``.

    With a single env per collector, the EnvCreator produced by
    :func:`transformed_env_constructor` is returned as-is. Otherwise a
    :class:`ParallelEnv` with ``cfg.env_per_collector`` workers is created. When
    ``cfg.batch_transform`` is set, the transforms are applied once, on top of
    the parallel env, rather than inside each worker.

    Args:
        cfg (DictConfig): config containing user-defined arguments
        kwargs: keyword arguments for the `transformed_env_constructor` method.

    Raises:
        NotImplementedError: if ``cfg.batch_transform`` is falsy.
    """
    batch_transform = cfg.batch_transform
    if not batch_transform:
        raise NotImplementedError(
            "batch_transform must be set to True for the recorder to be synced "
            "with the collection envs."
        )

    # Every per-worker constructor below goes through an EnvCreator.
    kwargs["cfg"] = cfg
    kwargs["use_env_creator"] = True

    if cfg.env_per_collector == 1:
        return transformed_env_constructor(**kwargs)

    worker_env_fn = transformed_env_constructor(
        return_transformed_envs=not batch_transform, **kwargs
    )
    parallel_env = ParallelEnv(
        num_workers=cfg.env_per_collector,
        create_env_fn=worker_env_fn,
        create_env_kwargs=None,
        pin_memory=cfg.pin_memory,
    )
    if not batch_transform:
        return parallel_env

    # Apply the transforms on the batched (parallel) env itself.
    kwargs.update(
        cfg=cfg,
        use_env_creator=False,
        custom_env=parallel_env,
        batch_dims=1,
    )
    return transformed_env_constructor(**kwargs)()
@torch.no_grad()
def get_stats_random_rollout(
    cfg: "DictConfig",  # noqa: F821
    proof_environment: Optional[EnvBase] = None,
    key: Optional[str] = None,
):
    """Gathers stats (loc and scale) from an environment using random rollouts.

    Rollouts are executed until at least ``cfg.init_env_steps`` steps have been
    collected. Loc and scale are computed along the first dimension of the
    concatenated values, except for ``"pixels"``, for which a single scalar loc
    and scale are computed over all values. Entries with a zero standard
    deviation get ``loc = 0`` and ``scale = 1``.

    Args:
        cfg (DictConfig): a config object with `init_env_steps` field,
            indicating the total number of frames to be collected to compute
            the stats.
        proof_environment (EnvBase instance, optional): if provided, this env
            will be used to execute the rollouts. If not, it will be created
            using the cfg object (and closed before returning).
        key (str, optional): if provided, the stats of this key will be
            gathered. If not, it is expected that only one key exists in
            `env.observation_spec`.

    Returns:
        dict: ``{"loc": loc, "scale": scale}``.

    Raises:
        AttributeError: if ``cfg`` has no ``init_env_steps`` field.
        RuntimeError: if ``key`` is not given and several observation keys
            exist, or if the computed stats are not finite.
    """
    # Validate before building an env, so that nothing is constructed (and leaked)
    # when the config is incomplete.
    if not hasattr(cfg, "init_env_steps"):
        raise AttributeError("init_env_steps missing from arguments.")

    proof_env_is_none = proof_environment is None
    if proof_env_is_none:
        proof_environment = transformed_env_constructor(
            cfg=cfg, use_env_creator=False, stats={"loc": 0.0, "scale": 1.0}
        )()

    try:
        if VERBOSE:
            torchrl_logger.info("computing state stats")

        # The key must be resolved before reading values from the rollouts.
        if key is None:
            keys = list(proof_environment.observation_spec.keys(True, True))
            key = keys.pop()
            if len(keys):
                raise RuntimeError(
                    f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                    "thus get_stats_random_rollout cannot infer which to compute the stats of."
                )

        n = 0
        val_stats = []
        while n < cfg.init_env_steps:
            _td_stats = proof_environment.rollout(max_steps=cfg.init_env_steps)
            n += _td_stats.numel()
            val_stats.append(_td_stats.get(key).cpu())
            del _td_stats
        val_stats = torch.cat(val_stats, 0)

        if key == "pixels":
            m = val_stats.mean()
            s = val_stats.std()
        else:
            m = val_stats.mean(dim=0)
            s = val_stats.std(dim=0)
        # Constant entries cannot be standardized: leave them untouched.
        m[s == 0] = 0.0
        s[s == 0] = 1.0

        if VERBOSE:
            torchrl_logger.info(
                f"stats computed for {n} steps. Got: \n"
                f"loc = {m}, \n"
                f"scale = {s}"
            )
        if not torch.isfinite(m).all():
            raise RuntimeError("non-finite values found in mean")
        if not torch.isfinite(s).all():
            raise RuntimeError("non-finite values found in sd")
        stats = {"loc": m, "scale": s}
    finally:
        # Only close an env that was created here, even if an error occurred.
        if proof_env_is_none:
            proof_environment.close()
            if (
                proof_environment.device != torch.device("cpu")
                and torch.cuda.device_count() > 0
            ):
                torch.cuda.empty_cache()
            del proof_environment
    return stats
def initialize_observation_norm_transforms(
    proof_environment: EnvBase,
    num_iter: int = 1000,
    key: Optional[Union[str, Tuple[str, ...]]] = None,
):
    """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.

    If an :obj:`ObservationNorm` is already initialized (``transform.initialized``),
    it is left untouched. Similarly, if the transformed environment does not
    contain any :obj:`ObservationNorm`, a call to this function will have no
    effect. If no key is provided but the observation spec of the base env
    contains more than one key, an exception will be raised.

    Args:
        proof_environment (TransformedEnv): the env whose ObservationNorm
            transforms are initialized; the rollouts used to compute the stats
            are run by ``ObservationNorm.init_stats``.
        num_iter (int): Number of iterations used for initializing the
            :obj:`ObservationNorms`
        key (str or tuple of str, optional): if provided, the stats of this key
            will be gathered. If not, it is expected that only one key exists
            in `proof_environment.base_env.observation_spec`.

    Raises:
        RuntimeError: if ``key`` is None and several observation keys exist.
    """
    transform = proof_environment.transform
    if not isinstance(transform, Compose) and not isinstance(transform, ObservationNorm):
        return

    if key is None:
        # the key is inferred from the untransformed (base) env
        keys = list(proof_environment.base_env.observation_spec.keys(True, True))
        key = keys.pop()
        if len(keys):
            raise RuntimeError(
                f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                "thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
            )

    if isinstance(transform, Compose):
        for sub_transform in transform:
            if isinstance(sub_transform, ObservationNorm) and not sub_transform.initialized:
                sub_transform.init_stats(num_iter=num_iter, key=key)
    elif not transform.initialized:
        transform.init_stats(num_iter=num_iter, key=key)


def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
    """Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts.

    Returns a list of tuple (idx, state_dict) for each :obj:`ObservationNorm`
    transform in proof_environment, where ``idx`` is the position of the
    transform within the :obj:`Compose` (``0`` when the env transform is a
    single ObservationNorm). If the environment transforms do not contain any
    :obj:`ObservationNorm`, returns an empty list.

    Args:
        proof_environment (TransformedEnv): the :obj:`TransformedEnv` to
            retrieve the :obj:`ObservationNorm` state dict from
    """
    obs_norm_state_dicts = []
    transform = proof_environment.transform
    if isinstance(transform, Compose):
        for idx, sub_transform in enumerate(transform):
            if isinstance(sub_transform, ObservationNorm):
                obs_norm_state_dicts.append((idx, sub_transform.state_dict()))
    elif isinstance(transform, ObservationNorm):
        obs_norm_state_dicts.append((0, transform.state_dict()))
    return obs_norm_state_dicts


@dataclass
class EnvConfig:
    """Environment config struct."""

    # library used for the simulated environment ("gym" or "dm_control"). Default=gym
    env_library: str = "gym"
    # name of the environment to be created. Default=Humanoid-v2
    env_name: str = "Humanoid-v2"
    # task (if any) for the environment; forwarded as task_name to dm_control envs. Default=""
    env_task: str = ""
    # whether the environment output should be state vector(s) (default) or the pixels.
    from_pixels: bool = False
    # frame_skip for the environment. Note that this value does NOT impact the buffer size,
    # maximum steps per trajectory, frames per batch or any other factor in the algorithm,
    # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4
    # the actual number of frames retrieved will be 200e6 (see correct_for_frame_skip). Default=1.
    frame_skip: int = 1
    # scale of the reward. If None, no reward scaling is applied unless the
    # reward is normalized through VecNorm or observations only are normalized.
    reward_scaling: Optional[float] = None
    # location of the reward.
    reward_loc: float = 0.0
    # number of random steps to compute normalizing constants
    init_env_steps: int = 1000
    # Normalizes the environment observation and reward outputs with the running statistics obtained across processes.
    vecnorm: bool = False
    # If True (and vecnorm is used), rewards will be normalized on the fly.
    # This may interfere with SAC update rule and should be used cautiously.
    norm_rewards: bool = False
    # whether to normalize with stats gathered from random data collection.
    # NOTE(review): not read by the helpers visible here; the original help
    # text ("Deactivates the normalization ...") describes a CLI switch — confirm.
    norm_stats: bool = True
    # number of random steps to do after reset (NoopResetEnv). Default is 0
    noops: int = 0
    # Number of frames to concatenate through time. Default is 0 (do not use CatFrames).
    # Must be positive when from_pixels is True.
    catframes: int = 0
    # center crop size(s), passed as CenterCrop(*center_crop); empty disables cropping.
    center_crop: Any = dataclass_field(default_factory=lambda: [])
    # whether pixel observations are converted to grayscale.
    grayscale: bool = True
    # Number of steps before a reset of the environment is called (if it has not been flagged as done before).
    max_frames_per_traj: int = 1000
    # if ``True``, the transforms will be applied to the parallel env, and not to each individual env.
    batch_transform: bool = False
    # side length used to resize pixel observations (Resize(image_size, image_size)).
    image_size: int = 84
    # if True and environment has discrete action space, then it is encoded as
    # categorical values rather than one-hot (gym envs only).
    categorical_action_encoding: bool = False

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources