# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

from tensordict import TensorDictBase
from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModuleWrapper

from torchrl.collectors.collectors import (
    DataCollectorBase,
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    SyncDataCollector,
)
from torchrl.data.postprocs import MultiStep
from torchrl.envs.batched_envs import ParallelEnv
from torchrl.envs.common import EnvBase
def sync_async_collector(
    env_fns: Callable | list[Callable],
    env_kwargs: dict | list[dict] | None,
    num_env_per_collector: int | None = None,
    num_collectors: int | None = None,
    **kwargs,
) -> MultiaSyncDataCollector:
    """Runs asynchronous collectors, each running synchronous environments.

    .. aafig::

        | "MultiaSyncDataCollector" | |
        |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| |
        | "Collector 1" | "Collector 2" | "Collector 3" | "Main" |
        | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | |
        |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | |
        | | | | | | | |
        | "actor" | | | "actor" | |
        | | | | | |
        | "step" | "step" | "actor" | | |
        | | | | | |
        | | | | "step" | "step" | |
        | | | | | | |
        | "actor" | "step" | "step" | "actor" | |
        | | | | | |
        | "yield batch 1" | "actor" | |"collect, train"|
        | | | | |
        | "step" | "step" | | "yield batch 2" |"collect, train"|
        | | | | | |
        | | | "yield batch 3" | |"collect, train"|
        | | | | | |

    Environment types can be identical or different. In the latter case, ``env_fns`` should be a
    list with all the creator fns for the various envs, and the policy should handle those envs
    in batch.

    Args:
        env_fns: Callable (or list of Callables) returning an instance of EnvBase class.
        env_kwargs: Optional. Dictionary (or list of dictionaries) containing the kwargs for the
            environment being created.
        num_env_per_collector: Number of environments per data collector. The product
            ``num_env_per_collector * num_collectors`` should be less than or equal to the number
            of workers available.
        num_collectors: Number of data collectors to be run in parallel.
        **kwargs: Other kwargs (``policy``, ``frames_per_batch``, ``total_frames``, ``device``, ...)
            forwarded to ``_make_collector`` and from there to the data collector.

    Returns:
        A :class:`MultiaSyncDataCollector` instance.
    """
    return _make_collector(
        MultiaSyncDataCollector,
        env_fns=env_fns,
        env_kwargs=env_kwargs,
        num_env_per_collector=num_env_per_collector,
        num_collectors=num_collectors,
        **kwargs,
    )
def sync_sync_collector(
    env_fns: Callable | list[Callable],
    env_kwargs: dict | list[dict] | None,
    num_env_per_collector: int | None = None,
    num_collectors: int | None = None,
    **kwargs,
) -> SyncDataCollector | MultiSyncDataCollector:
    """Runs synchronous collectors, each running synchronous environments.

    .. aafig::

        | "MultiSyncDataCollector" | |
        |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| |
        | "Collector 1" | "Collector 2" | "Collector 3" | Main |
        | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | |
        |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | |
        | | | | | | | |
        | "actor" | | | "actor" | |
        | | | | | |
        | "step" | "step" | "actor" | | |
        | | | | | |
        | | | | "step" | "step" | |
        | | | | | | |
        | "actor" | "step" | "step" | "actor" | |
        | | | | | |
        | | "actor" | | |
        | | | | |
        | "yield batch of traj 1"------->"collect, train"|
        | | |
        | "step" | "step" | "step" | "step" | "step" | "step" | |
        | | | | | | | |
        | "actor" | "actor" | | | |
        | | "step" | "step" | "actor" | |
        | | | | | |
        | "step" | "step" | "actor" | "step" | "step" | |
        | | | | | | |
        | "actor" | | "actor" | |
        | "yield batch of traj 2"------->"collect, train"|
        | | |

    Envs can be identical or different. In the latter case, ``env_fns`` should be a list with all
    the creator fns for the various envs, and the policy should handle those envs in batch.

    Args:
        env_fns: Callable (or list of Callables) returning an instance of EnvBase class.
        env_kwargs: Optional. Dictionary (or list of dictionaries) containing the kwargs for the
            environment being created.
        num_env_per_collector: Number of environments per data collector. The product
            ``num_env_per_collector * num_collectors`` should be less than or equal to the number
            of workers available.
        num_collectors: Number of data collectors to be run in parallel.
        **kwargs: Other kwargs (``policy``, ``frames_per_batch``, ``total_frames``, ``device``, ...)
            forwarded to ``_make_collector`` and from there to the data collector.

    Returns:
        A :class:`SyncDataCollector` when ``num_collectors == 1``, otherwise (including when
        ``num_collectors`` is ``None``) a :class:`MultiSyncDataCollector`.
    """
    # The former `kwargs["device"] = kwargs.pop("device")` re-insertions were no-ops
    # (same key popped and re-added) and have been dropped; kwargs are forwarded as-is.
    collector_class = (
        SyncDataCollector if num_collectors == 1 else MultiSyncDataCollector
    )
    return _make_collector(
        collector_class,
        env_fns=env_fns,
        env_kwargs=env_kwargs,
        num_env_per_collector=num_env_per_collector,
        num_collectors=num_collectors,
        **kwargs,
    )
def _make_collector(
    collector_class: type,
    env_fns: Callable | list[Callable],
    env_kwargs: dict | list[dict] | None,
    policy: Callable[[TensorDictBase], TensorDictBase],
    max_frames_per_traj: int = -1,
    frames_per_batch: int = 200,
    total_frames: int | None = None,
    postproc: Callable | None = None,
    num_env_per_collector: int | None = None,
    num_collectors: int | None = None,
    **kwargs,
) -> DataCollectorBase:
    """Splits env constructors across collectors and instantiates ``collector_class``.

    Args:
        collector_class: collector type to build (``SyncDataCollector``,
            ``MultiSyncDataCollector`` or ``MultiaSyncDataCollector``).
        env_fns: a single env constructor, replicated ``num_env_per_collector * num_collectors``
            times, or a list of constructors (one per env).
        env_kwargs: kwargs for the env constructors: ``None``, a single dict shared by every
            env, or a list of dicts (one per env).
        policy: policy handed to the collector.
        max_frames_per_traj: forwarded to the collector.
        frames_per_batch: forwarded to the collector.
        total_frames: forwarded to the collector.
        postproc: forwarded to the collector.
        num_env_per_collector: envs per collector; inferred (ceil) from ``num_collectors``
            when ``env_fns`` is a list and this is ``None``.
        num_collectors: number of collectors; inferred (ceil) from ``num_env_per_collector``
            when ``env_fns`` is a list and this is ``None``.
        **kwargs: extra keyword arguments forwarded to the collector constructor.

    Returns:
        The instantiated collector.

    Raises:
        ValueError: if the env/collector counts are missing or too small for ``env_fns``.
        RuntimeError: if the split does not produce exactly ``num_collectors`` groups, or if a
            ``SyncDataCollector`` would receive more than one env constructor.
    """
    if env_kwargs is None:
        env_kwargs = {}
    if isinstance(env_fns, list):
        num_env = len(env_fns)
        if num_env_per_collector is None and num_collectors is None:
            raise ValueError(
                "env_fns is a list but neither num_env_per_collector nor num_collectors "
                "was specified; at least one of them is required."
            )
        if num_env_per_collector is None:
            # ceil(num_env / num_collectors)
            num_env_per_collector = -(num_env // -num_collectors)
        elif num_collectors is None:
            # ceil(num_env / num_env_per_collector)
            num_collectors = -(num_env // -num_env_per_collector)
        elif num_env_per_collector * num_collectors < num_env:
            raise ValueError(
                f"num_env_per_collector * num_collectors={num_env_per_collector * num_collectors} "
                f"has been found to be less than num_env={num_env}"
            )
    else:
        try:
            num_env = num_env_per_collector * num_collectors
            env_fns = [env_fns for _ in range(num_env)]
        except TypeError as err:
            raise ValueError(
                "env_fns was not a list but num_env_per_collector and num_collectors were not both specified, "
                f"got num_env_per_collector={num_env_per_collector} and num_collectors={num_collectors}"
            ) from err
    if not isinstance(env_kwargs, list):
        # the same kwargs object is shared by every env
        env_kwargs = [env_kwargs for _ in range(num_env)]

    # consecutive chunks of `num_env_per_collector` envs, one chunk per collector
    env_fns_split = [
        env_fns[i : i + num_env_per_collector]
        for i in range(0, num_env, num_env_per_collector)
    ]
    env_kwargs_split = [
        env_kwargs[i : i + num_env_per_collector]
        for i in range(0, num_env, num_env_per_collector)
    ]
    if len(env_fns_split) != num_collectors:
        raise RuntimeError(
            f"num_collectors={num_collectors} differs from len(env_fns_split)={len(env_fns_split)}"
        )

    if num_env_per_collector == 1:
        env_fns = [_env_fn[0] for _env_fn in env_fns_split]
        env_kwargs = [_env_kwargs[0] for _env_kwargs in env_kwargs_split]
    else:
        # Default arguments bind each chunk now; a plain closure would late-bind
        # and make every collector build the last chunk.
        env_fns = [
            lambda _env_fn=_env_fn, _env_kwargs=_env_kwargs: ParallelEnv(
                num_workers=len(_env_fn),
                create_env_fn=_env_fn,
                create_env_kwargs=_env_kwargs,
            )
            for _env_fn, _env_kwargs in zip(env_fns_split, env_kwargs_split)
        ]
        # the kwargs are baked into the ParallelEnv constructors above
        env_kwargs = None
    if collector_class is SyncDataCollector:
        if len(env_fns) > 1:
            raise RuntimeError(
                f"Something went wrong: expected a single env constructor but got {len(env_fns)}"
            )
        env_fns = env_fns[0]
        # env_kwargs is None on the ParallelEnv path; indexing it would raise TypeError
        if env_kwargs is not None:
            env_kwargs = env_kwargs[0]
    return collector_class(
        create_env_fn=env_fns,
        create_env_kwargs=env_kwargs,
        policy=policy,
        total_frames=total_frames,
        max_frames_per_traj=max_frames_per_traj,
        frames_per_batch=frames_per_batch,
        postproc=postproc,
        **kwargs,
    )
def make_collector_offpolicy(
    make_env: Callable[[], EnvBase],
    actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential),
    cfg: DictConfig,  # noqa: F821
    make_env_kwargs: dict | None = None,
) -> DataCollectorBase:
    """Returns a data collector for off-policy sota-implementations.

    Args:
        make_env (Callable): environment creator
        actor_model_explore (SafeModule): Model instance used for evaluation and exploration update
        cfg (DictConfig): config for creating collector object
        make_env_kwargs (dict): kwargs for the env creator

    Returns:
        The collector built by ``sync_async_collector`` (when ``cfg.async_collection``) or
        ``sync_sync_collector``.

    Note:
        ``cfg.collector_device`` is normalized in place: a one-element list is replaced by
        its single element.
    """
    if cfg.async_collection:
        collector_helper = sync_async_collector
    else:
        collector_helper = sync_sync_collector

    if cfg.multi_step:
        # NOTE(review): cfg.gamma is expected to be provided by a config merged into cfg
        # (it is not a field of the collector configs).
        ms = MultiStep(
            gamma=cfg.gamma,
            n_steps=cfg.n_steps_return,
        )
    else:
        ms = None

    env_kwargs = {}
    if isinstance(make_env_kwargs, dict):
        # copy so that the caller's dict is not shared/mutated downstream
        env_kwargs.update(make_env_kwargs)
    elif make_env_kwargs is not None:
        env_kwargs = make_env_kwargs
    cfg.collector_device = (
        cfg.collector_device
        if len(cfg.collector_device) > 1
        else cfg.collector_device[0]
    )
    collector_helper_kwargs = {
        "env_fns": make_env,
        "env_kwargs": env_kwargs,
        "policy": actor_model_explore,
        "max_frames_per_traj": cfg.max_frames_per_traj,
        "frames_per_batch": cfg.frames_per_batch,
        "total_frames": cfg.total_frames,
        "postproc": ms,
        "num_env_per_collector": 1,
        # we already took care of building the make_parallel_env function
        # ceil(num_workers / env_per_collector). Note `-a // -b` would be *floor*
        # division because unary minus binds tighter than `//`.
        "num_collectors": -(cfg.num_workers // -cfg.env_per_collector),
        "device": cfg.collector_device,
        "storing_device": cfg.collector_device,
        "init_random_frames": cfg.init_random_frames,
        "split_trajs": True,
        # trajectories must be separated if multi-step is used
        "exploration_type": cfg.exploration_type,
    }

    collector = collector_helper(**collector_helper_kwargs)
    return collector
def make_collector_onpolicy(
    make_env: Callable[[], EnvBase],
    actor_model_explore: (TensorDictModuleWrapper | ProbabilisticTensorDictSequential),
    cfg: DictConfig,  # noqa: F821
    make_env_kwargs: dict | None = None,
) -> DataCollectorBase:
    """Makes a collector in on-policy settings.

    Args:
        make_env (Callable): environment creator
        actor_model_explore (SafeModule): Model instance used for evaluation and exploration update
        cfg (DictConfig): config for creating collector object
        make_env_kwargs (dict): kwargs for the env creator

    Returns:
        The collector built by ``sync_sync_collector`` (on-policy collection is always synchronous
        here; no post-processing is applied).

    Note:
        ``cfg.collector_device`` is normalized in place: a one-element list is replaced by
        its single element.
    """
    collector_helper = sync_sync_collector

    ms = None

    env_kwargs = {}
    if isinstance(make_env_kwargs, dict):
        # copy so that the caller's dict is not shared/mutated downstream
        env_kwargs.update(make_env_kwargs)
    elif make_env_kwargs is not None:
        env_kwargs = make_env_kwargs
    cfg.collector_device = (
        cfg.collector_device
        if len(cfg.collector_device) > 1
        else cfg.collector_device[0]
    )
    collector_helper_kwargs = {
        "env_fns": make_env,
        "env_kwargs": env_kwargs,
        "policy": actor_model_explore,
        "max_frames_per_traj": cfg.max_frames_per_traj,
        "frames_per_batch": cfg.frames_per_batch,
        "total_frames": cfg.total_frames,
        "postproc": ms,
        "num_env_per_collector": 1,
        # we already took care of building the make_parallel_env function
        # ceil(num_workers / env_per_collector). Note `-a // -b` would be *floor*
        # division because unary minus binds tighter than `//`.
        "num_collectors": -(cfg.num_workers // -cfg.env_per_collector),
        "device": cfg.collector_device,
        "storing_device": cfg.collector_device,
        "split_trajs": True,
        # trajectories must be separated in online settings
        "exploration_type": cfg.exploration_type,
    }

    collector = collector_helper(**collector_helper_kwargs)
    return collector
@dataclass
class OnPolicyCollectorConfig:
    """On-policy collector config struct.

    Must be a dataclass: ``field(default_factory=...)`` below only has an effect when the
    class is processed by ``@dataclass`` (otherwise the attribute is a raw ``Field`` object and
    the class has no keyword ``__init__``).
    """

    collector_device: Any = field(default_factory=lambda: ["cpu"])
    # device(s) used by the data collector. ``make_collector_offpolicy`` / ``make_collector_onpolicy``
    # pass it both as the collector ``device`` and ``storing_device`` (a one-element list is
    # unwrapped to its single device first).
    # If the collector device differs from the policy device (cuda:0 if available), then the
    # weights of the collector policy are synchronized with collector.update_policy_weights_().
    pin_memory: bool = False
    # if ``True``, the data collector will call pin_memory before dispatching tensordicts onto the passing device
    frames_per_batch: int = 1000
    # number of steps executed in the environment per collection.
    # This value represents how many steps the data collector will execute and return in *each*
    # environment that has been created in between two rounds of optimization
    # (presumably related to the trainer's ``optim_steps_per_batch`` setting, which is not defined here).
    # On the one hand, a low value increases how often data is exchanged between processes in async
    # settings, which can make data access a computational bottleneck.
    # High values will on the other hand lead to greater tensor sizes in memory and disk to be
    # written and read at each global iteration. One should look at the number of frames per second
    # in the log to assess the efficiency of the configuration.
    total_frames: int = 50000000
    # total number of frames collected for training. Does account for frame_skip (i.e. will be
    # divided by the frame_skip). Default=50e6.
    num_workers: int = 32
    # Number of workers used for data collection.
    env_per_collector: int = 8
    # Number of environments per collector. If the env_per_collector is in the range:
    # 1<env_per_collector<=num_workers, then the collector runs
    # ceil(num_workers/env_per_collector) in parallel and executes the policy steps synchronously
    # for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created
    seed: int = 42
    # seed used for the environment, pytorch and numpy.
    exploration_type: str = "random"
    # exploration mode of the data collector.
    async_collection: bool = False
    # whether data collection should be done asynchronously. Asynchronous data collection means
    # that the data collector will keep on running the environment with the previous weights
    # configuration while the optimization loop is being done. If the algorithm is trained
    # synchronously, data collection and optimization will occur iteratively, not concurrently.
@dataclass
class OffPolicyCollectorConfig(OnPolicyCollectorConfig):
    """Off-policy collector config struct.

    Adds the off-policy specific fields; ``multi_step`` and ``init_random_frames`` are read by
    ``make_collector_offpolicy``. Decorated with ``@dataclass`` so these become real dataclass
    fields (constructor arguments) on top of the inherited ones.
    """

    multi_step: bool = False
    # whether multi-step rewards should be used.
    n_steps_return: int = 3
    # If multi_step is set to True, this value defines the number of steps to look ahead for the reward computation.
    init_random_frames: int = 50000
    # forwarded to the collector as ``init_random_frames`` (number of initial frames collected
    # before the policy is used).