get_full_finetune_fsdp_wrap_policy
- torchtune.utils.get_full_finetune_fsdp_wrap_policy(memory_efficient_fsdp_wrap: bool, modules_to_wrap: Set[Type]) → Callable[[Module, bool, int], bool]
Retrieves an FSDP wrapping policy based on the specified flags memory_efficient_fsdp_wrap and modules_to_wrap. Specifically, if memory_efficient_fsdp_wrap is set to True, the returned policy wraps the model's token embedding and output projection in addition to the specified modules, to maximize memory savings.
- Parameters:
memory_efficient_fsdp_wrap (bool) – If True, also wraps the embedding and output projection layers with FSDP.
modules_to_wrap (Set[Type]) – Set of module types to wrap.
Note
The memory improvements from memory_efficient_fsdp_wrap have so far only been verified on Llama 3 workloads, where they provide roughly a 15% memory reduction when used alongside memory-efficient activation checkpointing (AC) wrapping. Other workloads have not been verified and may not see the same improvements.
- Returns:
Wrapping policy that can be passed into FullyShardedDataParallel as the auto_wrap_policy argument. See the documentation for FSDPPolicyType for additional details.
- Return type:
FSDPPolicyType
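To illustrate the Callable[[Module, bool, int], bool] shape of the returned policy, here is a minimal, framework-free sketch of a type-based wrap policy. It assumes the FSDP convention that the policy is called with recurse=True to ask whether to keep descending into a module's children, and with recurse=False to ask whether to wrap that module itself. The third argument is the number of not-yet-wrapped parameters, which this sketch ignores. The helper make_wrap_policy and its extra_types / extra_instances arguments are hypothetical and are not torchtune's implementation. Here extra_types stands in for the token-embedding type, and extra_instances for the single output-projection module, which is matched by identity so that other layers of the same type are not wrapped.

```python
from typing import Callable, Iterable, Set, Type

# Same shape as FSDPPolicyType: (module, recurse, nonwrapped_numel) -> bool
WrapPolicy = Callable[[object, bool, int], bool]


def make_wrap_policy(
    modules_to_wrap: Set[Type],
    extra_types: Iterable[Type] = (),
    extra_instances: Iterable[object] = (),
) -> WrapPolicy:
    """Hypothetical sketch of a type-based wrap policy (not torchtune's code).

    extra_types mimics also wrapping the token embedding type; extra_instances
    mimics wrapping one specific module (e.g. the output projection) by identity.
    """
    wrap_types = tuple(modules_to_wrap) + tuple(extra_types)
    targets = list(extra_instances)  # keep references so identity checks stay valid

    def policy(module: object, recurse: bool, nonwrapped_numel: int) -> bool:
        if recurse:
            # Always descend into children so nested matches are found.
            return True
        # Leaf decision: wrap this module in its own unit if its type matches
        # or if it is one of the explicitly targeted instances.
        return isinstance(module, wrap_types) or any(module is t for t in targets)

    return policy


# Usage with stand-in classes (placeholders for real nn.Module subclasses).
class DecoderLayer: ...
class Embedding: ...
class Linear: ...

output_proj = Linear()

base_policy = make_wrap_policy({DecoderLayer})
memory_efficient_policy = make_wrap_policy(
    {DecoderLayer}, extra_types={Embedding}, extra_instances=[output_proj]
)

print(base_policy(Embedding(), False, 0))              # False
print(memory_efficient_policy(Embedding(), False, 0))  # True
print(memory_efficient_policy(output_proj, False, 0))  # True
print(memory_efficient_policy(Linear(), False, 0))     # False
```

With a real model, the types would be torch.nn.Module subclasses (for example, your transformer layer class and torch.nn.Embedding). In torchtune itself you would instead call get_full_finetune_fsdp_wrap_policy and pass its result to FullyShardedDataParallel as auto_wrap_policy. Matching the output projection by identity rather than by type is the design choice that keeps every other linear layer from becoming its own FSDP unit.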