torchtnt.utils.prepare_module.prepare_module
torchtnt.utils.prepare_module.prepare_module(module: Module, device: device, *, strategy: Optional[Union[Strategy, str]] = None, swa_params: Optional[SWAParams] = None, torch_compile_params: Optional[TorchCompileParams] = None, activation_checkpoint_params: Optional[ActivationCheckpointParams] = None) → Module

Utility to move a module to a device and set up data parallelism, activation checkpointing, and compilation.
Parameters:
- module – the module to prepare.
- device – the device to which the module will be moved.
- strategy – the data parallelization strategy to use. If a string, it must be one of "ddp" or "fsdp".
- swa_params – params for stochastic weight averaging: https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging.
- torch_compile_params – params for torch.compile: https://pytorch.org/docs/stable/generated/torch.compile.html.
- activation_checkpoint_params – params for enabling activation checkpointing.
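The rule for the strategy parameter (either a strategy object, or one of the strings "ddp" or "fsdp") can be illustrated with a small, self-contained Python sketch. It is hypothetical: resolve_strategy, DDPStrategyStub and FSDPStrategyStub are placeholder names invented for this example, not torchtnt's own classes or internals. The sketch assumes that None means "no data parallelism", that a valid string is mapped to a default-configured strategy object, and that any other value is passed through as an already-built strategy.

```python
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical stand-ins for DDP / FSDP strategy configs; NOT torchtnt's classes.
@dataclass
class DDPStrategyStub:
    name: str = "ddp"


@dataclass
class FSDPStrategyStub:
    name: str = "fsdp"


StrategyStub = Union[DDPStrategyStub, FSDPStrategyStub]

# The only string values the documentation accepts for `strategy`.
_STRING_TO_STRATEGY = {
    "ddp": DDPStrategyStub,
    "fsdp": FSDPStrategyStub,
}


def resolve_strategy(
    strategy: Optional[Union[StrategyStub, str]],
) -> Optional[StrategyStub]:
    """Hypothetical helper: normalize the `strategy` argument.

    None means "no data parallelism"; a string must be "ddp" or "fsdp";
    any other value is assumed to already be a strategy object.
    """
    if strategy is None:
        return None
    if isinstance(strategy, str):
        try:
            # Build a default-configured strategy object for the given name.
            return _STRING_TO_STRATEGY[strategy]()
        except KeyError:
            raise ValueError(
                f"strategy must be one of {sorted(_STRING_TO_STRATEGY)}, "
                f"got {strategy!r}"
            ) from None
    # Already a strategy object: pass it through unchanged.
    return strategy


if __name__ == "__main__":
    print(resolve_strategy("ddp"))  # DDPStrategyStub(name='ddp')
```

Keeping the accepted names in a single mapping means the error message always lists exactly the supported values, and adding a new strategy name only touches one place.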