get_full_finetune_fsdp_wrap_policy

torchtune.utils.get_full_finetune_fsdp_wrap_policy(memory_efficient_fsdp_wrap: bool, modules_to_wrap: Set[Type]) → Callable[[Module, bool, int], bool]

Retrieves an FSDP wrapping policy based on the values of memory_efficient_fsdp_wrap and modules_to_wrap. Specifically, if memory_efficient_fsdp_wrap is set to True, the returned policy wraps the model’s token embedding and output projection in addition to the modules in modules_to_wrap, to maximize memory savings.

Parameters:
  • memory_efficient_fsdp_wrap (bool) – If True, will also wrap embedding and output projection layers with FSDP.

  • modules_to_wrap (Set[Type]) – Set of module types to wrap.

Note

The memory improvements from memory_efficient_fsdp_wrap have currently only been verified on Llama3 workloads, where they provide a ~15% reduction in memory usage (when used alongside activation checkpointing with memory-efficient wrapping). Other workloads have not been verified and may not see the same improvements.

Returns:

A wrapping policy that can be passed into FullyShardedDataParallel as the auto_wrap_policy argument. See the documentation for FSDPPolicyType for additional details.

Return type:

FSDPPolicyType
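
A minimal usage sketch is shown below. It assumes a transformer-style model whose decoder layers are instances of torchtune.modules.TransformerDecoderLayer (substitute the layer class your model actually uses), and that the default process group has already been initialized (e.g. via torchrun); the shard_model helper is hypothetical and shown only for illustration.

from typing import Set, Type

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune import utils
from torchtune.modules import TransformerDecoderLayer


def shard_model(model: nn.Module) -> FSDP:
    # Build the wrapping policy: each TransformerDecoderLayer gets its own
    # FSDP unit, and with memory_efficient_fsdp_wrap=True the token embedding
    # and output projection are wrapped as well.
    wrap_policy = utils.get_full_finetune_fsdp_wrap_policy(
        memory_efficient_fsdp_wrap=True,
        modules_to_wrap={TransformerDecoderLayer},
    )
    # Pass the policy to FSDP via the auto_wrap_policy argument.
    return FSDP(model, auto_wrap_policy=wrap_policy)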
