torchtune.utils.FSDPPolicyType
A type for functions that can be used as an FSDP wrapping policy. Such a function accepts an nn.Module, a boolean flag (recurse), and an integer (nonwrapped_numel, the number of parameters not yet wrapped), and it returns a boolean. When recurse is True, the return value tells FSDP whether to traverse into the module's children. When it is False, the return value tells FSDP whether the module itself should be wrapped with FSDP. Objects of this type can be passed directly to the auto_wrap_policy argument of PyTorch FSDP to specify how FSDP wraps submodules.

The function below creates and returns a function that obeys the contract of FSDPPolicyType:

```python
import functools
from typing import Set, Type

import torch.nn as nn

def get_fsdp_policy(modules_to_wrap: Set[Type], min_num_params: int):
    # FSDP invokes the policy with the keyword arguments module, recurse and nonwrapped_numel
    def my_fsdp_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int,
                       modules_to_wrap: Set[Type], min_num_params: int) -> bool:
        if recurse:
            return True  # always traverse into child modules
        # Wrap layers whose type is in modules_to_wrap, and layers with more than min_num_params parameters
        return isinstance(module, tuple(modules_to_wrap)) or sum(p.numel() for p in module.parameters()) > min_num_params

    return functools.partial(my_fsdp_policy, modules_to_wrap=modules_to_wrap, min_num_params=min_num_params)
```
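You can check the contract without starting a distributed job. The sketch below calls a policy directly, using the same keyword arguments (module, recurse, nonwrapped_numel) that PyTorch's FSDP wrapping logic passes. The alias PolicyFn = Callable[[nn.Module, bool, int], bool] is an assumption that mirrors the description of the type, and make_policy is a hypothetical helper. It returns a closure rather than a functools.partial, which is an equivalent way to bind modules_to_wrap and min_num_params.

```python
from typing import Callable, Set, Type

import torch.nn as nn

# Assumed shape of the policy type: (module, recurse, nonwrapped_numel) -> bool
PolicyFn = Callable[[nn.Module, bool, int], bool]


def make_policy(modules_to_wrap: Set[Type[nn.Module]], min_num_params: int) -> PolicyFn:
    """Hypothetical helper: wrap modules of the given types or with many parameters."""

    def policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
        if recurse:
            # Always let FSDP descend into child modules.
            return True
        num_params = sum(p.numel() for p in module.parameters())
        return isinstance(module, tuple(modules_to_wrap)) or num_params > min_num_params

    return policy


policy = make_policy({nn.LayerNorm}, min_num_params=1000)
small = nn.Linear(4, 4)   # 16 weights + 4 biases = 20 parameters
big = nn.Linear(64, 64)   # 4096 weights + 64 biases = 4160 parameters
norm = nn.LayerNorm(8)    # 8 weights + 8 biases = 16 parameters

# Called with keyword arguments, as FSDP does.
print(policy(module=small, recurse=False, nonwrapped_numel=20))   # False
print(policy(module=big, recurse=False, nonwrapped_numel=4160))   # True
print(policy(module=norm, recurse=False, nonwrapped_numel=16))    # True
```

This policy counts all parameters of the module itself. A policy could instead use nonwrapped_numel: when recurse is False, that value excludes parameters FSDP has already wrapped in child modules.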
Please see the documentation of auto_wrap_policy at https://pytorch.org/docs/stable/fsdp.html for additional details.
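PyTorch also ships ready-made policies that follow the same (module, recurse, nonwrapped_numel) calling convention. The minimal sketch below assumes a PyTorch build with torch.distributed available. It binds min_num_params of torch.distributed.fsdp.wrap.size_based_auto_wrap_policy with functools.partial and then calls the result directly. No process group is needed just to evaluate the policy. The threshold of 1000 and the helper numel are arbitrary illustrations.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Bind the size threshold; the remaining arguments match the FSDP policy contract.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1000)


def numel(module: nn.Module) -> int:
    # Hypothetical helper: total parameter count of a module.
    return sum(p.numel() for p in module.parameters())


small = nn.Linear(4, 4)    # 20 parameters
large = nn.Linear(64, 64)  # 4160 parameters

# Ask whether each module should be wrapped (recurse=False).
print(policy(module=small, recurse=False, nonwrapped_numel=numel(small)))  # False
print(policy(module=large, recurse=False, nonwrapped_numel=numel(large)))  # True
```

The resulting partial can be passed as auto_wrap_policy when constructing FullyShardedDataParallel. Constructing the FSDP wrapper itself does require an initialized process group.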