lora_fsdp_wrap_policy
- torchtune.training.lora_fsdp_wrap_policy(modules_to_wrap: Set[Type]) → Callable[[Module, bool, int], bool] [source]
A default policy for wrapping models trained with LoRA using FSDP.
FSDP’s default behavior is to allocate gradients at the level of FSDP-wrapped modules. This means that if any parameter in a given FSDP-wrapped module requires gradients, then memory will be allocated for gradients for the entire module.
In the case of LoRA, where only the adapters are trainable, we therefore need to wrap the adapter submodules in their own FSDP units to maximize memory savings. After this is done, the model will also be hierarchically wrapped based on the nn.Module types specified in modules_to_wrap (see the usage sketch below).
- Parameters:
modules_to_wrap (Set[Type]) – nn.Module types to recursively wrap
- Returns:
Wrapping policy that can be passed into FullyShardedDataParallel. Please see documentation for FSDPPolicyType for additional details.
- Return type:
FSDPPolicyType
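
Below is a minimal usage sketch. The LoRA model builder (lora_llama2_7b) and the layer class passed in modules_to_wrap (TransformerDecoderLayer) are illustrative assumptions; substitute the builder and nn.Module types that match your model and torchtune version.

# Minimal sketch: shard a LoRA model with FSDP using lora_fsdp_wrap_policy.
# The model builder and wrapped layer type below are illustrative assumptions;
# adjust them to your own model and torchtune version.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune.models.llama2 import lora_llama2_7b  # assumed LoRA builder
from torchtune.modules import TransformerDecoderLayer  # assumed layer type
from torchtune.training import lora_fsdp_wrap_policy

# Build a model where only the LoRA adapters require gradients.
model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

# Wrap each LoRA adapter in its own FSDP unit, and additionally wrap the
# model hierarchically at transformer-layer granularity.
wrap_policy = lora_fsdp_wrap_policy(modules_to_wrap={TransformerDecoderLayer})

# Assumes torch.distributed has already been initialized (e.g. via torchrun).
wrapped_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)

Because the adapters sit in their own FSDP units, gradient memory is allocated only for those units, while the frozen base weights are simply sharded and gathered as needed.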