set_activation_checkpointing
- torchtune.utils.set_activation_checkpointing(model: Module, auto_wrap_policy: Union[Set[Type], Callable[[Module, bool, int], bool]], **kwargs) → None
Utility to apply activation checkpointing to the passed-in model.
- Parameters:
model (nn.Module) – Model to apply activation checkpointing to.
auto_wrap_policy (ACWrapPolicyType) – Policy that decides which modules to wrap. This can be either a set of nn.Module types, in which case each module of the specified type(s) is wrapped individually with activation checkpointing, or a callable policy describing how to wrap the model with activation checkpointing. For more information on authoring custom policies, see this tutorial: https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html#transformer-wrapping-policy.
**kwargs – Additional arguments to pass to torch.distributed activation checkpointing.
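The two forms of auto_wrap_policy can be compared with a small pure-Python model. This is a sketch under stated assumptions, not torchtune's or PyTorch's implementation: Module, TransformerLayer, and Embedding are hypothetical stand-ins for nn.Module subclasses, and policy_from_types, size_policy, and modules_to_wrap are hypothetical helpers. It assumes the PyTorch auto-wrap convention, in which a callable policy is called with the keyword arguments module, recurse, and nonwrapped_numel: with recurse=True it answers whether to descend into the module's children, and with recurse=False whether to wrap the module itself. In this model, a set of types is treated like a callable that always recurses and wraps only instances of those types.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Set, Type


# Hypothetical stand-in for nn.Module: only the tree shape and a parameter
# count matter for this sketch.
@dataclass(eq=False)
class Module:
    name: str
    own_numel: int = 0
    children: List["Module"] = field(default_factory=list)

    def numel(self) -> int:
        # Total number of parameters in this module's subtree.
        return self.own_numel + sum(c.numel() for c in self.children)


class TransformerLayer(Module):  # hypothetical block type
    pass


class Embedding(Module):  # hypothetical embedding type
    pass


Policy = Callable[..., bool]


def policy_from_types(types: Set[Type[Module]]) -> Policy:
    """Hypothetical helper: express a set of module types as a callable policy."""
    def policy(module: Module, recurse: bool, nonwrapped_numel: int) -> bool:
        if recurse:
            return True  # always descend into children
        return isinstance(module, tuple(types))  # wrap only matching types
    return policy


def size_policy(module: Module, recurse: bool, nonwrapped_numel: int) -> bool:
    """Hypothetical custom policy: wrap modules holding >= 50 unwrapped params."""
    if recurse:
        return True
    return nonwrapped_numel >= 50


def modules_to_wrap(root: Module, policy: Policy) -> List[str]:
    """Simplified model of the recursive traversal; returns the names of the
    modules the policy would wrap. The root itself is never wrapped here."""
    wrapped: List[str] = []

    def visit(module: Module) -> int:
        # Returns how many parameters in this subtree ended up wrapped.
        total = module.numel()
        if not policy(module=module, recurse=True, nonwrapped_numel=total):
            return 0  # policy declined to look inside this subtree
        wrapped_numel = sum(visit(child) for child in module.children)
        remainder = total - wrapped_numel  # params not covered by wrapped children
        if module is not root and policy(
            module=module, recurse=False, nonwrapped_numel=remainder
        ):
            wrapped.append(module.name)
            return total
        return wrapped_numel

    visit(root)
    return wrapped


model = Module("model", children=[
    Embedding("embed", 100),
    TransformerLayer("layers.0", 50, [Module("layers.0.attn", 10)]),
    TransformerLayer("layers.1", 50),
    Module("output", 20),
])

print(modules_to_wrap(model, policy_from_types({TransformerLayer})))
# prints ['layers.0', 'layers.1']
print(modules_to_wrap(model, size_policy))
# prints ['embed', 'layers.0', 'layers.1']
```

Under these assumptions, the type-based policy wraps only the two layer blocks, while the size-based callable also wraps the large embedding. For real models, torch.distributed.fsdp.wrap provides ready-made callables such as transformer_auto_wrap_policy and size_based_auto_wrap_policy that follow the same (module, recurse, nonwrapped_numel) convention once their extra arguments are bound with functools.partial.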