
torchtune.utils.FSDPPolicyType


A datatype for a function that can be used as an FSDP wrapping policy. In particular, this type denotes a function that accepts an nn.Module, a boolean flag, and an integer, and returns a boolean indicating whether the module should be wrapped with FSDP. (Recent PyTorch releases pass these as the keyword arguments module, recurse, and nonwrapped_numel.) Objects of this type can be passed directly to PyTorch FSDP's auto_wrap_policy argument to specify how FSDP wraps submodules.
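A minimal function matching this signature might look like the sketch below. It assumes, as in recent PyTorch 2.x releases, that FSDP calls the policy with the keyword arguments module, recurse, and nonwrapped_numel; the name wrap_linear_policy is hypothetical and only for illustration.

```python
import torch.nn as nn


def wrap_linear_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    # Hypothetical policy: when recurse is True, FSDP is asking whether to
    # descend into this module's children; returning True inspects every submodule.
    if recurse:
        return True
    # When recurse is False, FSDP is asking whether to wrap this module itself.
    # This sketch wraps every nn.Linear and ignores nonwrapped_numel.
    return isinstance(module, nn.Linear)


# Called with keyword arguments, the way FSDP invokes callable policies.
print(wrap_linear_policy(module=nn.Linear(4, 4), recurse=False, nonwrapped_numel=20))  # True
print(wrap_linear_policy(module=nn.ReLU(), recurse=False, nonwrapped_numel=0))  # False
```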

The function below serves as an example of creating and returning a function that obeys the contract of FSDPPolicyType:

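import functools
from typing import Set, Type

import torch.nn as nn


def get_fsdp_policy(modules_to_wrap: Set[Type], min_num_params: int):

    def my_fsdp_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int,
                       modules_to_wrap: Set[Type], min_num_params: int) -> bool:
        # recurse=True asks whether to traverse into children: always do so
        if recurse:
            return True
        # Wrap layers whose type is in ``modules_to_wrap`` and layers with more than min_num_params parameters
        return isinstance(module, tuple(modules_to_wrap)) or sum(p.numel() for p in module.parameters()) > min_num_params

    # Bind the extra arguments so the returned callable matches FSDPPolicyType
    return functools.partial(my_fsdp_policy, modules_to_wrap=modules_to_wrap, min_num_params=min_num_params)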

See the documentation of auto_wrap_policy at https://pytorch.org/docs/stable/fsdp.html for additional details.

alias of Callable[[Module, bool, int], bool]
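To see how the recurse flag and the integer interact, the following hypothetical dry run walks a module tree bottom-up and records which submodules a policy would select. It is a simplified sketch that assumes FSDP's callable-policy traversal (ask with recurse=True before descending, then ask with recurse=False and the count of parameters not already wrapped); it is not FSDP's actual implementation, and modules_selected_by and size_policy are illustrative names only.

```python
from typing import Callable, List

import torch.nn as nn

# Same shape as the alias above: (module, recurse, nonwrapped_numel) -> bool.
PolicyFn = Callable[[nn.Module, bool, int], bool]


def modules_selected_by(policy: PolicyFn, root: nn.Module) -> List[str]:
    """Hypothetical dry run: names of submodules the policy would wrap."""
    selected: List[str] = []

    def visit(prefix: str, module: nn.Module) -> int:
        # Returns how many parameters inside `module` ended up wrapped.
        wrapped = 0
        for name, child in module.named_children():
            path = f"{prefix}.{name}" if prefix else name
            child_total = sum(p.numel() for p in child.parameters())
            child_wrapped = 0
            # recurse=True: should the traversal descend into this child?
            if policy(module=child, recurse=True, nonwrapped_numel=child_total):
                child_wrapped = visit(path, child)
            # recurse=False: wrap this child, given its still-unwrapped parameters?
            remainder = child_total - child_wrapped
            if policy(module=child, recurse=False, nonwrapped_numel=remainder):
                selected.append(path)
                child_wrapped = child_total
            wrapped += child_wrapped
        return wrapped

    # The root itself is wrapped by the FSDP constructor call, so only children are visited.
    visit("", root)
    return selected


def size_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    # Descend everywhere; wrap any module with more than 100 unwrapped parameters.
    return recurse or nonwrapped_numel > 100


model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))
print(modules_selected_by(size_policy, model))  # ['0']
```

Here only the first Linear layer (272 parameters) crosses the threshold; the second (34 parameters) stays inside the root FSDP unit.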
