EnvCreator
- class torchrl.envs.EnvCreator(create_env_fn: Callable[[...], EnvBase], create_env_kwargs: Optional[Dict] = None, share_memory: bool = True)
Environment creator class.
EnvCreator is a generic environment creator class that can be used in place of lambda functions when creating environments in multiprocessing contexts. If an environment created in a subprocess must share information with the main process (e.g. the running statistics of the VecNorm transform), EnvCreator passes references to the tensordicts held in shared memory to each process, so that all of them stay synchronised.
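The shared-memory idea above can be sketched in plain Python, without torchrl. Everything here is hypothetical: `CountingEnv`, `make_counting_env` and `SimpleEnvCreator` are illustrative names, a single float counter in a `multiprocessing.shared_memory` block stands in for VecNorm's statistics, and this is not how EnvCreator is implemented internally. The creator allocates the block once and gives its name to every environment it builds; pickling the creator (as happens when it is sent to a worker) ships only that name, so a copy rebuilt elsewhere attaches to the same memory.

```python
import pickle
import struct
from multiprocessing import shared_memory


class CountingEnv:
    """Hypothetical toy env whose observation count lives in shared memory,
    standing in for VecNorm's running statistics."""

    def __init__(self, shm_name: str, scale: float = 1.0):
        # Attach to an existing block by name; no new memory is allocated.
        self._shm = shared_memory.SharedMemory(name=shm_name)
        self.scale = scale

    @property
    def count(self) -> float:
        return struct.unpack_from("d", self._shm.buf, 0)[0]

    def step(self) -> None:
        # Read-modify-write is NOT atomic across processes; real code needs a lock.
        struct.pack_into("d", self._shm.buf, 0, self.count + 1.0)

    def close(self) -> None:
        self._shm.close()


def make_counting_env(shm_name: str, scale: float = 1.0) -> CountingEnv:
    # Module-level factory, so pickle can store it by reference.
    return CountingEnv(shm_name, scale=scale)


class SimpleEnvCreator:
    """Hypothetical minimal analogue of EnvCreator (not torchrl's code)."""

    def __init__(self, create_env_fn, create_env_kwargs=None):
        self.create_env_fn = create_env_fn
        self.create_env_kwargs = dict(create_env_kwargs or {})
        # Allocate the shared state once, in the parent process.
        self._shm = shared_memory.SharedMemory(create=True, size=8)
        struct.pack_into("d", self._shm.buf, 0, 0.0)

    def __call__(self):
        # Every env receives the name of the same block, so all copies
        # read and write one counter. SharedMemory pickles as its name,
        # so an unpickled creator re-attaches to the same block.
        return self.create_env_fn(shm_name=self._shm.name, **self.create_env_kwargs)

    def close(self, unlink: bool = False) -> None:
        self._shm.close()
        if unlink:
            # Only the owner should free the block.
            self._shm.unlink()


creator = SimpleEnvCreator(make_counting_env, {"scale": 2.0})
env_a, env_b = creator(), creator()
for _ in range(3):
    env_a.step()
seen_by_b = env_b.count  # env_b never stepped, yet sees the updates

# Simulate shipping the creator to a worker via pickling.
clone = pickle.loads(pickle.dumps(creator))
env_c = clone()
env_c.step()
seen_by_a = env_a.count  # env_a sees env_c's step

for env in (env_a, env_b, env_c):
    env.close()
clone.close()
creator.close(unlink=True)
print(seen_by_b, seen_by_a)  # 3.0 4.0
```

The demo runs in a single process and uses a pickle round trip to mimic handing the creator to a worker; with real concurrent workers, the unsynchronised update in `step` would need a lock.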
- Parameters:
create_env_fn (callable) – a callable that returns an EnvBase instance.
create_env_kwargs (dict, optional) – the keyword arguments passed to create_env_fn when the environment is created.
share_memory (bool, optional) – if False, the tensordict resulting from the environment won't be placed in shared memory. Defaults to True.
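Passing a callable plus a dict of keyword arguments, rather than a closure, can matter when workers are started with the spawn method, which pickles the arguments sent to each process: the standard pickle module stores a module-level function or class by reference but refuses a lambda. The sketch below checks both cases in plain Python; `ToyEnv`, `KwargsEnvCreator` and `is_picklable` are hypothetical names, and this is not torchrl code.

```python
import pickle


class ToyEnv:
    """Hypothetical stand-in for an EnvBase subclass."""

    def __init__(self, env_name: str, frame_skip: int = 1):
        self.env_name = env_name
        self.frame_skip = frame_skip


class KwargsEnvCreator:
    """Hypothetical sketch mirroring the create_env_fn / create_env_kwargs split."""

    def __init__(self, create_env_fn, create_env_kwargs=None):
        self.create_env_fn = create_env_fn
        # Copy so later edits to the caller's dict do not leak in.
        self.create_env_kwargs = dict(create_env_kwargs or {})

    def __call__(self):
        return self.create_env_fn(**self.create_env_kwargs)


def is_picklable(obj) -> bool:
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


closure = lambda: ToyEnv("Pendulum-v1", frame_skip=2)
creator = KwargsEnvCreator(ToyEnv, {"env_name": "Pendulum-v1", "frame_skip": 2})

# Round-trip the creator as a spawned worker would receive it.
env = pickle.loads(pickle.dumps(creator))()
print(is_picklable(closure), is_picklable(creator), env.frame_skip)  # False True 2
```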
Examples
>>> # We create the same environment on 2 processes using VecNorm
>>> # and check that the discounted count of observations matches on
>>> # both workers, even if one has not executed any step
>>> import time
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm, TransformedEnv
>>> from torchrl.envs import EnvCreator
>>> from torch import multiprocessing as mp
>>> env_fn = lambda: TransformedEnv(GymEnv("Pendulum-v1"), VecNorm())
>>> env_creator = EnvCreator(env_fn)
>>>
>>> def test_env1(env_creator):
...     env = env_creator()
...     tensordict = env.reset()
...     for _ in range(10):
...         env.rand_step(tensordict)
...         if tensordict.get(("next", "done")):
...             tensordict = env.reset(tensordict)
...     print("env 1: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> def test_env2(env_creator):
...     env = env_creator()
...     time.sleep(5)
...     print("env 2: ", env.transform._td.get(("next", "observation_count")))
>>>
>>> if __name__ == "__main__":
...     ps = []
...     p1 = mp.Process(target=test_env1, args=(env_creator,))
...     p1.start()
...     ps.append(p1)
...     p2 = mp.Process(target=test_env2, args=(env_creator,))
...     p2.start()
...     ps.append(p2)
...     for p in ps:
...         p.join()
env 1: tensor([11.9934])
env 2: tensor([11.9934])