Source code for torchrl.trainers.helpers.collectors

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

from tensordict import TensorDictBase

from tensordict.nn import ProbabilisticTensorDictSequential, TensorDictModuleWrapper

from torchrl.collectors.collectors import (
    DataCollectorBase,
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    SyncDataCollector,
)
from torchrl.data.postprocs import MultiStep
from torchrl.envs.batched_envs import ParallelEnv
from torchrl.envs.common import EnvBase


def sync_async_collector(
    env_fns: Callable | list[Callable],
    env_kwargs: dict | list[dict] | None,
    num_env_per_collector: int | None = None,
    num_collectors: int | None = None,
    **kwargs,
) -> MultiaSyncDataCollector:
    """Runs asynchronous collectors, each running synchronous environments.

    Each collector resets and steps its own environments synchronously
    (reset, actor, step), but the collectors deliver their batches to the
    main process asynchronously: the main process collects and trains on
    batch 1, 2, 3, ... as soon as each one becomes available, without
    waiting for the remaining collectors to finish their rollouts.

    Environment types can be identical or different. In the latter case,
    env_fns should be a list with all the creator fns for the various envs,
    and the policy should handle those envs in batch.

    Args:
        env_fns: Callable (or list of Callables) returning an instance of the
            EnvBase class.
        env_kwargs: Optional. Dictionary (or list of dictionaries) containing
            the kwargs for the environment being created.
        num_env_per_collector: Number of environments per data collector. The
            product num_env_per_collector * num_collectors should be less than
            or equal to the number of workers available.
        num_collectors: Number of data collectors to be run in parallel.
        **kwargs: Other kwargs passed to the data collectors.

    """
    return _make_collector(
        MultiaSyncDataCollector,
        env_fns=env_fns,
        env_kwargs=env_kwargs,
        num_env_per_collector=num_env_per_collector,
        num_collectors=num_collectors,
        **kwargs,
    )
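# Illustrative usage sketch: the environment maker and ``policy`` below are
# hypothetical placeholders (any EnvBase factory and TensorDict-compatible policy
# would do), and the example assumes gym is installed.
#
#     >>> from torchrl.envs.libs.gym import GymEnv
#     >>> make_env = lambda: GymEnv("Pendulum-v1")
#     >>> collector = sync_async_collector(
#     ...     env_fns=make_env,
#     ...     env_kwargs=None,
#     ...     num_env_per_collector=2,
#     ...     num_collectors=3,
#     ...     policy=policy,           # e.g. a TensorDictModule mapping "observation" -> "action"
#     ...     frames_per_batch=200,
#     ...     total_frames=2_000,
#     ... )
#     >>> for batch in collector:      # batches arrive as soon as any collector is done
#     ...     ...                      # train on `batch` here
#     >>> collector.shutdown()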
def sync_sync_collector(
    env_fns: Callable | list[Callable],
    env_kwargs: dict | list[dict] | None,
    num_env_per_collector: int | None = None,
    num_collectors: int | None = None,
    **kwargs,
) -> SyncDataCollector | MultiSyncDataCollector:
    """Runs synchronous collectors, each running synchronous environments.

    Each collector resets and steps its own environments synchronously
    (reset, actor, step), and the collectors also deliver their data to the
    main process synchronously: a batch of trajectories is yielded only once
    every collector has finished its rollout, after which the main process
    collects and trains on the whole batch before the next one is produced.

    Envs can be identical or different. In the latter case, env_fns should be
    a list with all the creator fns for the various envs, and the policy
    should handle those envs in batch.

    Args:
        env_fns: Callable (or list of Callables) returning an instance of the
            EnvBase class.
        env_kwargs: Optional. Dictionary (or list of dictionaries) containing
            the kwargs for the environment being created.
        num_env_per_collector: Number of environments per data collector. The
            product num_env_per_collector * num_collectors should be less than
            or equal to the number of workers available.
        num_collectors: Number of data collectors to be run in parallel.
        **kwargs: Other kwargs passed to the data collectors.

    """
    if num_collectors == 1:
        if "device" in kwargs:
            kwargs["device"] = kwargs.pop("device")
        if "storing_device" in kwargs:
            kwargs["storing_device"] = kwargs.pop("storing_device")
        return _make_collector(
            SyncDataCollector,
            env_fns=env_fns,
            env_kwargs=env_kwargs,
            num_env_per_collector=num_env_per_collector,
            num_collectors=num_collectors,
            **kwargs,
        )
    return _make_collector(
        MultiSyncDataCollector,
        env_fns=env_fns,
        env_kwargs=env_kwargs,
        num_env_per_collector=num_env_per_collector,
        num_collectors=num_collectors,
        **kwargs,
    )
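# Illustrative sketch (same hypothetical ``make_env`` and ``policy`` as above): with a
# single collector the helper returns a plain SyncDataCollector, otherwise a
# MultiSyncDataCollector that steps all collectors in lockstep.
#
#     >>> single = sync_sync_collector(
#     ...     env_fns=make_env, env_kwargs=None,
#     ...     num_env_per_collector=2, num_collectors=1,
#     ...     policy=policy, frames_per_batch=200, total_frames=2_000,
#     ... )
#     >>> isinstance(single, SyncDataCollector)
#     True
#     >>> multi = sync_sync_collector(
#     ...     env_fns=make_env, env_kwargs=None,
#     ...     num_env_per_collector=2, num_collectors=3,
#     ...     policy=policy, frames_per_batch=200, total_frames=2_000,
#     ... )
#     >>> isinstance(multi, MultiSyncDataCollector)
#     True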
def _make_collector(
    collector_class: type,
    env_fns: Callable | list[Callable],
    env_kwargs: dict | list[dict] | None,
    policy: Callable[[TensorDictBase], TensorDictBase],
    max_frames_per_traj: int = -1,
    frames_per_batch: int = 200,
    total_frames: int | None = None,
    postproc: Callable | None = None,
    num_env_per_collector: int | None = None,
    num_collectors: int | None = None,
    **kwargs,
) -> DataCollectorBase:
    if env_kwargs is None:
        env_kwargs = {}
    if isinstance(env_fns, list):
        num_env = len(env_fns)
        if num_env_per_collector is None:
            # ceiling division: spread the envs as evenly as possible across the collectors
            num_env_per_collector = -(num_env // -num_collectors)
        elif num_collectors is None:
            num_collectors = -(num_env // -num_env_per_collector)
        else:
            if num_env_per_collector * num_collectors < num_env:
                raise ValueError(
                    f"num_env_per_collector * num_collectors={num_env_per_collector * num_collectors} "
                    f"has been found to be less than num_env={num_env}"
                )
    else:
        try:
            num_env = num_env_per_collector * num_collectors
            env_fns = [env_fns for _ in range(num_env)]
        except TypeError:
            raise Exception(
                "num_env was not a list but num_env_per_collector and num_collectors were not both specified, "
                f"got num_env_per_collector={num_env_per_collector} and num_collectors={num_collectors}"
            )
    if not isinstance(env_kwargs, list):
        env_kwargs = [env_kwargs for _ in range(num_env)]
    env_fns_split = [
        env_fns[i : i + num_env_per_collector]
        for i in range(0, num_env, num_env_per_collector)
    ]
    env_kwargs_split = [
        env_kwargs[i : i + num_env_per_collector]
        for i in range(0, num_env, num_env_per_collector)
    ]
    if len(env_fns_split) != num_collectors:
        raise RuntimeError(
            f"num_collectors={num_collectors} differs from len(env_fns_split)={len(env_fns_split)}"
        )
    if num_env_per_collector == 1:
        env_fns = [_env_fn[0] for _env_fn in env_fns_split]
        env_kwargs = [_env_kwargs[0] for _env_kwargs in env_kwargs_split]
    else:
        env_fns = [
            lambda _env_fn=_env_fn, _env_kwargs=_env_kwargs: ParallelEnv(
                num_workers=len(_env_fn),
                create_env_fn=_env_fn,
                create_env_kwargs=_env_kwargs,
            )
            for _env_fn, _env_kwargs in zip(env_fns_split, env_kwargs_split)
        ]
        env_kwargs = None
    if collector_class is SyncDataCollector:
        if len(env_fns) > 1:
            raise RuntimeError(
                f"Something went wrong: expected a single env constructor but got {len(env_fns)}"
            )
        env_fns = env_fns[0]
        env_kwargs = env_kwargs[0]
    return collector_class(
        create_env_fn=env_fns,
        create_env_kwargs=env_kwargs,
        policy=policy,
        total_frames=total_frames,
        max_frames_per_traj=max_frames_per_traj,
        frames_per_batch=frames_per_batch,
        postproc=postproc,
        **kwargs,
    )
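# Note on the split above: ``-(num_env // -num_collectors)`` is integer ceiling
# division. For example, with 7 env constructors and 3 collectors:
#
#     >>> num_env, num_collectors = 7, 3
#     >>> -(num_env // -num_collectors)   # ceil(7 / 3)
#     3
#
# so the 7 constructors are split into chunks of 3, 3 and 1, one chunk per collector.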
def make_collector_offpolicy(
    make_env: Callable[[], EnvBase],
    actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential,
    cfg: DictConfig,  # noqa: F821
    make_env_kwargs: dict | None = None,
) -> DataCollectorBase:
    """Returns a data collector for off-policy sota-implementations.

    Args:
        make_env (Callable): environment creator
        actor_model_explore (SafeModule): model instance used for evaluation
            and exploration update
        cfg (DictConfig): config for creating collector object
        make_env_kwargs (dict): kwargs for the env creator

    """
    if cfg.async_collection:
        collector_helper = sync_async_collector
    else:
        collector_helper = sync_sync_collector

    if cfg.multi_step:
        ms = MultiStep(
            gamma=cfg.gamma,
            n_steps=cfg.n_steps_return,
        )
    else:
        ms = None

    env_kwargs = {}
    if make_env_kwargs is not None and isinstance(make_env_kwargs, dict):
        env_kwargs.update(make_env_kwargs)
    elif make_env_kwargs is not None:
        env_kwargs = make_env_kwargs
    cfg.collector_device = (
        cfg.collector_device
        if len(cfg.collector_device) > 1
        else cfg.collector_device[0]
    )
    collector_helper_kwargs = {
        "env_fns": make_env,
        "env_kwargs": env_kwargs,
        "policy": actor_model_explore,
        "max_frames_per_traj": cfg.max_frames_per_traj,
        "frames_per_batch": cfg.frames_per_batch,
        "total_frames": cfg.total_frames,
        "postproc": ms,
        "num_env_per_collector": 1,
        # we already took care of building the make_parallel_env function
        "num_collectors": -cfg.num_workers // -cfg.env_per_collector,
        "device": cfg.collector_device,
        "storing_device": cfg.collector_device,
        "init_random_frames": cfg.init_random_frames,
        "split_trajs": True,
        # trajectories must be separated if multi-step is used
        "exploration_type": cfg.exploration_type,
    }

    collector = collector_helper(**collector_helper_kwargs)
    collector.set_seed(cfg.seed)
    return collector
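# Illustrative sketch: ``cfg`` is normally assembled by hydra from the config structs
# defined below (plus other config dataclasses that contribute fields such as ``gamma``
# and ``max_frames_per_traj``); the flat OmegaConf dict here and the ``make_env`` /
# ``actor`` objects are hypothetical stand-ins.
#
#     >>> from omegaconf import OmegaConf
#     >>> cfg = OmegaConf.create(
#     ...     {
#     ...         "async_collection": False,
#     ...         "multi_step": True,
#     ...         "gamma": 0.99,
#     ...         "n_steps_return": 3,
#     ...         "collector_device": ["cpu"],
#     ...         "max_frames_per_traj": -1,
#     ...         "frames_per_batch": 256,
#     ...         "total_frames": 1_000_000,
#     ...         "num_workers": 4,
#     ...         "env_per_collector": 2,
#     ...         "init_random_frames": 1_000,
#     ...         "exploration_type": "random",
#     ...         "seed": 42,
#     ...     }
#     ... )
#     >>> collector = make_collector_offpolicy(make_env, actor, cfg)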
def make_collector_onpolicy(
    make_env: Callable[[], EnvBase],
    actor_model_explore: TensorDictModuleWrapper | ProbabilisticTensorDictSequential,
    cfg: DictConfig,  # noqa: F821
    make_env_kwargs: dict | None = None,
) -> DataCollectorBase:
    """Makes a collector in on-policy settings.

    Args:
        make_env (Callable): environment creator
        actor_model_explore (SafeModule): model instance used for evaluation
            and exploration update
        cfg (DictConfig): config for creating collector object
        make_env_kwargs (dict): kwargs for the env creator

    """
    collector_helper = sync_sync_collector

    ms = None

    env_kwargs = {}
    if make_env_kwargs is not None and isinstance(make_env_kwargs, dict):
        env_kwargs.update(make_env_kwargs)
    elif make_env_kwargs is not None:
        env_kwargs = make_env_kwargs
    cfg.collector_device = (
        cfg.collector_device
        if len(cfg.collector_device) > 1
        else cfg.collector_device[0]
    )
    collector_helper_kwargs = {
        "env_fns": make_env,
        "env_kwargs": env_kwargs,
        "policy": actor_model_explore,
        "max_frames_per_traj": cfg.max_frames_per_traj,
        "frames_per_batch": cfg.frames_per_batch,
        "total_frames": cfg.total_frames,
        "postproc": ms,
        "num_env_per_collector": 1,
        # we already took care of building the make_parallel_env function
        "num_collectors": -cfg.num_workers // -cfg.env_per_collector,
        "device": cfg.collector_device,
        "storing_device": cfg.collector_device,
        "split_trajs": True,
        # trajectories must be separated in online settings
        "exploration_type": cfg.exploration_type,
    }

    collector = collector_helper(**collector_helper_kwargs)
    collector.set_seed(cfg.seed)
    return collector
@dataclass
class OnPolicyCollectorConfig:
    """On-policy collector config struct."""

    collector_device: Any = field(default_factory=lambda: ["cpu"])
    # device on which the data collector should store the trajectories to be passed to this script.
    # If the collector device differs from the policy device (cuda:0 if available), then the
    # weights of the collector policy are synchronized with collector.update_policy_weights_().
    pin_memory: bool = False
    # if ``True``, the data collector will call pin_memory before dispatching tensordicts onto the
    # passing device
    frames_per_batch: int = 1000
    # number of steps executed in the environment per collection.
    # This value represents how many steps the data collector will execute and return in *each*
    # environment that has been created in between two rounds of optimization
    # (see the optim_steps_per_batch above).
    # On the one hand, a low value will enhance the data throughput between processes in async
    # settings, which can make the accessing of data a computational bottleneck.
    # High values will on the other hand lead to greater tensor sizes in memory and disk to be
    # written and read at each global iteration. One should look at the number of frames per second
    # in the log to assess the efficiency of the configuration.
    total_frames: int = 50000000
    # total number of frames collected for training. Does account for frame_skip (i.e. will be
    # divided by the frame_skip). Default=50e6.
    num_workers: int = 32
    # Number of workers used for data collection.
    env_per_collector: int = 8
    # Number of environments per collector. If env_per_collector is in the range
    # 1 < env_per_collector <= num_workers, then the collector runs
    # ceil(num_workers/env_per_collector) collectors in parallel and executes the policy steps
    # synchronously for each of these parallel wrappers.
    # If env_per_collector=num_workers, no parallel wrapper is created.
    seed: int = 42
    # seed used for the environment, pytorch and numpy.
    exploration_type: str = "random"
    # exploration mode of the data collector.
    async_collection: bool = False
    # whether data collection should be done asynchronously. Asynchronous data collection means
    # that the data collector will keep on running the environment with the previous weights
    # configuration while the optimization loop is being done. If the algorithm is trained
    # synchronously, data collection and optimization will occur iteratively, not concurrently.


@dataclass
class OffPolicyCollectorConfig(OnPolicyCollectorConfig):
    """Off-policy collector config struct."""

    multi_step: bool = False
    # whether multi-step rewards should be used.
    n_steps_return: int = 3
    # If multi_step is set to True, this value defines the number of steps to look ahead for the
    # reward computation.
    init_random_frames: int = 50000
    # number of frames collected with a random policy before the trained policy is used.
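# Illustrative sketch: these config structs are typically merged with the other config
# dataclasses into a single hydra/omegaconf config (an assumption about the surrounding
# training scripts), but they can also be instantiated directly; all fields have defaults.
#
#     >>> cfg = OffPolicyCollectorConfig(num_workers=4, env_per_collector=2, multi_step=True)
#     >>> cfg.n_steps_return, cfg.init_random_frames, cfg.collector_device
#     (3, 50000, ['cpu'])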
