Source code for torchrl.trainers.helpers.envs
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from copy import copy
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import torch
from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.envs import ParallelEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import env_creator, EnvCreator
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
CatFrames,
CatTensors,
CenterCrop,
Compose,
DoubleToFloat,
GrayScale,
NoopResetEnv,
ObservationNorm,
Resize,
RewardScaling,
ToTensorImage,
TransformedEnv,
VecNorm,
)
from torchrl.envs.transforms.transforms import (
FlattenObservation,
gSDENoise,
InitTracker,
StepCounter,
)
from torchrl.record.loggers import Logger
from torchrl.record.recorder import VideoRecorder
LIBS = {
"gym": GymEnv,
"dm_control": DMControlEnv,
}
def correct_for_frame_skip(cfg: "DictConfig") -> "DictConfig":  # noqa: F821
"""Correct the arguments for the input frame_skip, by dividing all the arguments that reflect a count of frames by the frame_skip.
This is aimed at avoiding unknowingly over-sampling from the environment, i.e. targetting a total number of frames
of 1M but actually collecting frame_skip * 1M frames.
Args:
        cfg (DictConfig): DictConfig containing frame-counting arguments, including:
"max_frames_per_traj", "total_frames", "frames_per_batch", "record_frames", "annealing_frames",
"init_random_frames", "init_env_steps"
Returns:
the input DictConfig, modified in-place.
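
    A minimal usage sketch (assuming an OmegaConf ``DictConfig`` in struct mode, so
    that absent frame-count fields are skipped rather than read back as ``None``):

    Examples:
        >>> from omegaconf import OmegaConf
        >>> cfg = OmegaConf.create(
        ...     {"frame_skip": 4, "total_frames": 1_000_000, "frames_per_batch": 200}
        ... )
        >>> OmegaConf.set_struct(cfg, True)
        >>> cfg = correct_for_frame_skip(cfg)
        >>> cfg.total_frames, cfg.frames_per_batch
        (250000, 50)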
"""
# Adapt all frame counts wrt frame_skip
if cfg.frame_skip != 1:
fields = [
"max_frames_per_traj",
"total_frames",
"frames_per_batch",
"record_frames",
"annealing_frames",
"init_random_frames",
"init_env_steps",
"noops",
]
for field in fields:
if hasattr(cfg, field):
setattr(cfg, field, getattr(cfg, field) // cfg.frame_skip)
return cfg
def make_env_transforms(
env,
cfg,
video_tag,
logger,
env_name,
stats,
norm_obs_only,
env_library,
action_dim_gsde,
state_dim_gsde,
batch_dims=0,
obs_norm_state_dict=None,
):
"""Creates the typical transforms for and env."""
env = TransformedEnv(env)
from_pixels = cfg.from_pixels
vecnorm = cfg.vecnorm
norm_rewards = vecnorm and cfg.norm_rewards
_norm_obs_only = norm_obs_only or not norm_rewards
reward_scaling = cfg.reward_scaling
reward_loc = cfg.reward_loc
if len(video_tag):
center_crop = cfg.center_crop
if center_crop:
center_crop = center_crop[0]
env.append_transform(
VideoRecorder(
logger=logger,
tag=f"{video_tag}_{env_name}_video",
center_crop=center_crop,
),
)
if from_pixels:
        if not cfg.catframes:
            raise RuntimeError(
                "this env builder currently only accepts positive catframes values "
                "when pixels are being used."
            )
env.append_transform(ToTensorImage())
if cfg.center_crop:
env.append_transform(CenterCrop(*cfg.center_crop))
env.append_transform(Resize(cfg.image_size, cfg.image_size))
if cfg.grayscale:
env.append_transform(GrayScale())
env.append_transform(FlattenObservation(0, -3, allow_positive_dim=True))
env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"], dim=-3))
if stats is None and obs_norm_state_dict is None:
obs_stats = {}
elif stats is None:
obs_stats = copy(obs_norm_state_dict)
else:
obs_stats = copy(stats)
obs_stats["standard_normal"] = True
obs_norm = ObservationNorm(**obs_stats, in_keys=["pixels"])
env.append_transform(obs_norm)
if norm_rewards:
reward_scaling = 1.0
reward_loc = 0.0
if norm_obs_only:
reward_scaling = 1.0
reward_loc = 0.0
if reward_scaling is not None:
env.append_transform(RewardScaling(reward_loc, reward_scaling))
if not from_pixels:
selected_keys = [
key
for key in env.observation_spec.keys(True, True)
if ("pixels" not in key) and (key not in env.state_spec.keys(True, True))
]
# even if there is a single tensor, it'll be renamed in "observation_vector"
out_key = "observation_vector"
env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))
if not vecnorm:
if stats is None and obs_norm_state_dict is None:
_stats = {}
elif stats is None:
_stats = copy(obs_norm_state_dict)
else:
_stats = copy(stats)
_stats.update({"standard_normal": True})
obs_norm = ObservationNorm(
**_stats,
in_keys=[out_key],
)
env.append_transform(obs_norm)
else:
env.append_transform(
VecNorm(
in_keys=[out_key, "reward"] if not _norm_obs_only else [out_key],
decay=0.9999,
)
)
env.append_transform(DoubleToFloat())
if hasattr(cfg, "catframes") and cfg.catframes:
env.append_transform(CatFrames(N=cfg.catframes, in_keys=[out_key], dim=-1))
else:
env.append_transform(DoubleToFloat())
if hasattr(cfg, "gSDE") and cfg.gSDE:
env.append_transform(
gSDENoise(action_dim=action_dim_gsde, state_dim=state_dim_gsde)
)
env.append_transform(StepCounter())
env.append_transform(InitTracker())
return env
def get_norm_state_dict(env):
"""Gets the normalization loc and scale from the env state_dict."""
sd = env.state_dict()
sd = {
key: val
for key, val in sd.items()
if key.endswith("loc") or key.endswith("scale")
}
return sd
def transformed_env_constructor(
cfg: "DictConfig", # noqa: F821
video_tag: str = "",
logger: Optional[Logger] = None,
stats: Optional[dict] = None,
norm_obs_only: bool = False,
use_env_creator: bool = False,
custom_env_maker: Optional[Callable] = None,
custom_env: Optional[EnvBase] = None,
return_transformed_envs: bool = True,
action_dim_gsde: Optional[int] = None,
state_dim_gsde: Optional[int] = None,
batch_dims: Optional[int] = 0,
obs_norm_state_dict: Optional[dict] = None,
) -> Union[Callable, EnvCreator]:
"""Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
Args:
cfg (DictConfig): a DictConfig containing the arguments of the script.
video_tag (str, optional): video tag to be passed to the Logger object
logger (Logger, optional): logger associated with the script
stats (dict, optional): a dictionary containing the :obj:`loc` and :obj:`scale` for the `ObservationNorm` transform
norm_obs_only (bool, optional): If `True` and `VecNorm` is used, the reward won't be normalized online.
Default is `False`.
        use_env_creator (bool, optional): whether the `EnvCreator` class should be used. By using `EnvCreator`,
            one can make sure that running statistics will be put in shared memory and accessible for all workers
            when using a `VecNorm` transform. Default is `False`.
        custom_env_maker (callable, optional): if your env maker is not part
            of torchrl env wrappers, a custom callable
            can be passed instead. In this case it will override the
            constructor retrieved from `cfg`.
        custom_env (EnvBase, optional): if an existing environment needs to be
            transformed, it can be passed directly to this helper. `custom_env_maker`
            and `custom_env` are exclusive features.
        return_transformed_envs (bool, optional): if ``True``, a transformed
            environment is returned.
        action_dim_gsde (int, optional): if gSDE is used, this can provide the action dim used to initialize the noise.
            Make sure this is set for environments executed in parallel.
        state_dim_gsde (int, optional): if gSDE is used, this can provide the state dim used to initialize the noise.
            Make sure this is set for environments executed in parallel.
batch_dims (int, optional): number of dimensions of a batch of data. If a single env is
used, it should be 0 (default). If multiple envs are being transformed in parallel,
it should be set to 1 (or the number of dims of the batch).
obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform to be loaded into the
environment
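
    A hedged usage sketch (assuming gym is installed and that ``cfg`` carries the
    fields read by this helper and by :func:`make_env_transforms`):

    Examples:
        >>> from omegaconf import OmegaConf
        >>> cfg = OmegaConf.create(
        ...     {"env_library": "gym", "env_name": "Pendulum-v1", "env_task": "",
        ...      "frame_skip": 1, "from_pixels": False, "categorical_action_encoding": False,
        ...      "collector_device": "cpu", "noops": 0, "vecnorm": False, "norm_rewards": False,
        ...      "reward_scaling": 1.0, "reward_loc": 0.0, "center_crop": [],
        ...      "catframes": 0, "gSDE": False}
        ... )
        >>> make_env = transformed_env_constructor(cfg)
        >>> env = make_env()  # a TransformedEnv, ready for collection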
"""
def make_transformed_env(**kwargs) -> TransformedEnv:
env_name = cfg.env_name
env_task = cfg.env_task
env_library = LIBS[cfg.env_library]
frame_skip = cfg.frame_skip
from_pixels = cfg.from_pixels
categorical_action_encoding = cfg.categorical_action_encoding
if custom_env is None and custom_env_maker is None:
if isinstance(cfg.collector_device, str):
device = cfg.collector_device
elif isinstance(cfg.collector_device, Sequence):
device = cfg.collector_device[0]
else:
raise ValueError(
"collector_device must be either a string or a sequence of strings"
)
env_kwargs = {
"env_name": env_name,
"device": device,
"frame_skip": frame_skip,
"from_pixels": from_pixels or len(video_tag),
"pixels_only": from_pixels,
}
if env_library is GymEnv:
env_kwargs.update(
{"categorical_action_encoding": categorical_action_encoding}
)
elif categorical_action_encoding:
raise NotImplementedError(
"categorical_action_encoding=True is currently only compatible with GymEnvs."
)
if env_library is DMControlEnv:
env_kwargs.update({"task_name": env_task})
env_kwargs.update(kwargs)
env = env_library(**env_kwargs)
elif custom_env is None and custom_env_maker is not None:
env = custom_env_maker(**kwargs)
elif custom_env_maker is None and custom_env is not None:
env = custom_env
else:
raise RuntimeError("cannot provive both custom_env and custom_env_maker")
if cfg.noops and custom_env is None:
            # this is a bit hacky: if custom_env is not None, it is probably a ParallelEnv
            # that already has its NoopResetEnv set for the contained envs.
            # There is a risk, however, that we're just skipping the NoopResetEnv instantiation.
env = TransformedEnv(env, NoopResetEnv(cfg.noops))
if not return_transformed_envs:
return env
return make_env_transforms(
env,
cfg,
video_tag,
logger,
env_name,
stats,
norm_obs_only,
env_library,
action_dim_gsde,
state_dim_gsde,
batch_dims=batch_dims,
obs_norm_state_dict=obs_norm_state_dict,
)
if use_env_creator:
return env_creator(make_transformed_env)
return make_transformed_env
def parallel_env_constructor(
cfg: "DictConfig", **kwargs # noqa: F821
) -> Union[ParallelEnv, EnvCreator]:
"""Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
Args:
cfg (DictConfig): config containing user-defined arguments
kwargs: keyword arguments for the `transformed_env_constructor` method.
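
    A hedged sketch (continuing the ``cfg`` sketch from :func:`transformed_env_constructor`,
    with the extra fields this helper reads set on top):

    Examples:
        >>> cfg.env_per_collector = 4
        >>> cfg.pin_memory = False
        >>> cfg.batch_transform = True
        >>> env = parallel_env_constructor(cfg)
        >>> tensordict = env.rollout(max_steps=3)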
"""
batch_transform = cfg.batch_transform
if not batch_transform:
raise NotImplementedError(
"batch_transform must be set to True for the recorder to be synced "
"with the collection envs."
)
if cfg.env_per_collector == 1:
kwargs.update({"cfg": cfg, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(**kwargs)
return make_transformed_env
kwargs.update({"cfg": cfg, "use_env_creator": True})
make_transformed_env = transformed_env_constructor(
return_transformed_envs=not batch_transform, **kwargs
)
parallel_env = ParallelEnv(
num_workers=cfg.env_per_collector,
create_env_fn=make_transformed_env,
create_env_kwargs=None,
pin_memory=cfg.pin_memory,
)
if batch_transform:
kwargs.update(
{
"cfg": cfg,
"use_env_creator": False,
"custom_env": parallel_env,
"batch_dims": 1,
}
)
env = transformed_env_constructor(**kwargs)()
return env
return parallel_env
@torch.no_grad()
def get_stats_random_rollout(
cfg: "DictConfig", # noqa: F821
proof_environment: EnvBase = None,
key: Optional[str] = None,
):
"""Gathers stas (loc and scale) from an environment using random rollouts.
Args:
cfg (DictConfig): a config object with `init_env_steps` field, indicating
the total number of frames to be collected to compute the stats.
        proof_environment (EnvBase instance, optional): if provided, this env will
            be used to execute the rollouts. If not, it will be created using
            the cfg object.
key (str, optional): if provided, the stats of this key will be gathered.
If not, it is expected that only one key exists in `env.observation_spec`.
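
    A hedged sketch (assuming gym is installed; the loc/scale shapes follow the
    observation spec of the chosen env):

    Examples:
        >>> from omegaconf import OmegaConf
        >>> cfg = OmegaConf.create({"init_env_steps": 1000})
        >>> stats = get_stats_random_rollout(
        ...     cfg, proof_environment=GymEnv("Pendulum-v1"), key="observation"
        ... )
        >>> stats["loc"].shape, stats["scale"].shape
        (torch.Size([3]), torch.Size([3]))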
"""
proof_env_is_none = proof_environment is None
if proof_env_is_none:
proof_environment = transformed_env_constructor(
cfg=cfg, use_env_creator=False, stats={"loc": 0.0, "scale": 1.0}
)()
if VERBOSE:
torchrl_logger.info("computing state stats")
if not hasattr(cfg, "init_env_steps"):
raise AttributeError("init_env_steps missing from arguments.")
    n = 0
    val_stats = []
    # the key must be resolved before it is used to index the rollout data
    if key is None:
        keys = list(proof_environment.observation_spec.keys(True, True))
        key = keys.pop()
        if len(keys):
            raise RuntimeError(
                f"More than one key exists in the observation_specs: {[key] + keys} were found, "
                "thus get_stats_random_rollout cannot infer which to compute the stats of."
            )
    while n < cfg.init_env_steps:
        _td_stats = proof_environment.rollout(max_steps=cfg.init_env_steps)
        n += _td_stats.numel()
        val = _td_stats.get(key).cpu()
        val_stats.append(val)
        del _td_stats, val
    val_stats = torch.cat(val_stats, 0)
if key == "pixels":
m = val_stats.mean()
s = val_stats.std()
else:
m = val_stats.mean(dim=0)
s = val_stats.std(dim=0)
m[s == 0] = 0.0
s[s == 0] = 1.0
if VERBOSE:
torchrl_logger.info(
f"stats computed for {val_stats.numel()} steps. Got: \n"
f"loc = {m}, \n"
f"scale = {s}"
)
if not torch.isfinite(m).all():
raise RuntimeError("non-finite values found in mean")
if not torch.isfinite(s).all():
raise RuntimeError("non-finite values found in sd")
stats = {"loc": m, "scale": s}
if proof_env_is_none:
proof_environment.close()
if (
proof_environment.device != torch.device("cpu")
and torch.cuda.device_count() > 0
):
torch.cuda.empty_cache()
del proof_environment
return stats
def initialize_observation_norm_transforms(
proof_environment: EnvBase,
num_iter: int = 1000,
key: Union[str, Tuple[str, ...]] = None,
):
"""Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
If an :obj:`ObservationNorm` already has non-null :obj:`loc` or :obj:`scale`, a call to :obj:`initialize_observation_norm_transforms` will be a no-op.
Similarly, if the transformed environment does not contain any :obj:`ObservationNorm`, a call to this function will have no effect.
If no key is provided but the observations of the :obj:`EnvBase` contains more than one key, an exception will
be raised.
Args:
        proof_environment (EnvBase instance): the environment whose
            :obj:`ObservationNorm` transforms are to be initialized.
num_iter (int): Number of iterations used for initializing the :obj:`ObservationNorms`
key (str, optional): if provided, the stats of this key will be gathered.
If not, it is expected that only one key exists in `env.observation_spec`.
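
    A minimal sketch (assuming gym is installed; the :obj:`ObservationNorm` is
    appended uninitialized and then filled from random rollouts):

    Examples:
        >>> env = TransformedEnv(GymEnv("Pendulum-v1"))
        >>> env.append_transform(ObservationNorm(in_keys=["observation"], standard_normal=True))
        >>> initialize_observation_norm_transforms(env, num_iter=100)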
"""
if not isinstance(proof_environment.transform, Compose) and not isinstance(
proof_environment.transform, ObservationNorm
):
return
if key is None:
keys = list(proof_environment.base_env.observation_spec.keys(True, True))
key = keys.pop()
if len(keys):
raise RuntimeError(
f"More than one key exists in the observation_specs: {[key] + keys} were found, "
"thus initialize_observation_norm_transforms cannot infer which to compute the stats of."
)
if isinstance(proof_environment.transform, Compose):
for transform in proof_environment.transform:
if isinstance(transform, ObservationNorm) and not transform.initialized:
transform.init_stats(num_iter=num_iter, key=key)
elif not proof_environment.transform.initialized:
proof_environment.transform.init_stats(num_iter=num_iter, key=key)
def retrieve_observation_norms_state_dict(proof_environment: TransformedEnv):
"""Traverses the transforms of the environment and retrieves the :obj:`ObservationNorm` state dicts.
Returns a list of tuple (idx, state_dict) for each :obj:`ObservationNorm` transform in proof_environment
If the environment transforms do not contain any :obj:`ObservationNorm`, returns an empty list
Args:
proof_environment (EnvBase instance, optional): the :obj:``TransformedEnv` to retrieve the :obj:`ObservationNorm`
state dict from
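
    A minimal sketch (continuing the example of :func:`initialize_observation_norm_transforms`,
    once the norms have been initialized):

    Examples:
        >>> state_dicts = retrieve_observation_norms_state_dict(env)
        >>> idx, state_dict = state_dicts[0]  # index within the Compose, and loc/scale tensors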
"""
obs_norm_state_dicts = []
if isinstance(proof_environment.transform, Compose):
for idx, transform in enumerate(proof_environment.transform):
if isinstance(transform, ObservationNorm):
obs_norm_state_dicts.append((idx, transform.state_dict()))
if isinstance(proof_environment.transform, ObservationNorm):
obs_norm_state_dicts.append((0, proof_environment.transform.state_dict()))
return obs_norm_state_dicts
@dataclass
class EnvConfig:
"""Environment config struct."""
env_library: str = "gym"
# env_library used for the simulated environment. Default=gym
env_name: str = "Humanoid-v2"
# name of the environment to be created. Default=Humanoid-v2
env_task: str = ""
# task (if any) for the environment. Default=run
from_pixels: bool = False
# whether the environment output should be state vector(s) (default) or the pixels.
frame_skip: int = 1
    # frame_skip for the environment. Note that this value does NOT impact the buffer size,
    # maximum steps per trajectory, frames per batch or any other factor in the algorithm;
    # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4,
    # the actual number of frames retrieved will be 200e6. Default=1.
reward_scaling: Optional[float] = None
# scale of the reward.
reward_loc: float = 0.0
# location of the reward.
init_env_steps: int = 1000
# number of random steps to compute normalizing constants
vecnorm: bool = False
# Normalizes the environment observation and reward outputs with the running statistics obtained across processes.
norm_rewards: bool = False
# If True, rewards will be normalized on the fly. This may interfere with SAC update rule and should be used cautiously.
norm_stats: bool = True
    # if True (default), normalization statistics are computed from a random collection of data;
    # set to False to deactivate this normalization.
noops: int = 0
# number of random steps to do after reset. Default is 0
catframes: int = 0
# Number of frames to concatenate through time. Default is 0 (do not use CatFrames).
center_crop: Any = dataclass_field(default_factory=lambda: [])
# center crop size.
grayscale: bool = True
    # if True (default), a grayscale transform is applied to the pixel observations.
max_frames_per_traj: int = 1000
# Number of steps before a reset of the environment is called (if it has not been flagged as done before).
batch_transform: bool = False
    # if ``True``, the transforms will be applied to the parallel env, and not to each individual env.
    image_size: int = 84
    # size (height and width) to which pixel observations are resized.
    categorical_action_encoding: bool = False
    # if True and the environment has a discrete action space, then it is encoded as categorical values rather than one-hot.
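
# A hedged usage sketch (assuming Hydra is installed): this dataclass is meant to be
# registered in Hydra's ConfigStore and composed into a training script config, e.g.:
#
#     from hydra.core.config_store import ConfigStore
#     cs = ConfigStore.instance()
#     cs.store(name="env_cfg", group="env", node=EnvConfig)
#
# The resulting DictConfig can then be passed to ``correct_for_frame_skip``,
# ``transformed_env_constructor`` and ``parallel_env_constructor`` above.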