lora_fsdp_wrap_policy¶
- torchtune.utils.lora_fsdp_wrap_policy(modules_to_wrap: Set[Type]) → Callable[[Module, bool, int], bool] [source]¶
A default policy for wrapping models trained with LoRA using FSDP.
FSDP’s default behavior is to allocate gradients at the level of FSDP-wrapped modules. This means that if any parameter in a given FSDP-wrapped module requires gradients, then memory will be allocated for gradients for the entire module.
In the case of LoRA, where only the LoRA A and B matrices are trainable, this means we need to wrap the LoRA A and B submodules in their own FSDP units to maximize memory savings. After this is done, the model will also be hierarchically wrapped based on the nn.Module types specified in modules_to_wrap. This function assumes that (a) LoRA's A and B matrices are the only trainable weights in the entire model, and (b) we have already set requires_grad = True on the LoRA params.
- Parameters:
modules_to_wrap (Set[Type]) – nn.Module types to recursively wrap
- Returns:
Wrapping policy that can be passed into FullyShardedDataParallel. Please see the documentation for FSDPPolicyType for additional details.
- Return type:
FSDPPolicyType
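To illustrate the decision rule described above, here is a minimal, self-contained sketch of how a policy with this signature behaves. The stand-in classes (`Module`, `LoRALinearA`, `TransformerDecoderLayer`) and the helper `lora_fsdp_wrap_policy_sketch` are hypothetical stand-ins for illustration only, not torchtune's actual implementation; a real FSDP policy receives `torch.nn.Module` instances and is passed to `FullyShardedDataParallel` via `auto_wrap_policy`.

```python
# Hypothetical sketch of the wrap decision an FSDP auto-wrap policy of type
# Callable[[Module, bool, int], bool] makes. Stand-in classes replace
# torch.nn.Module and LoRA layers; names here are illustrative only.

class Module:
    """Minimal stand-in for torch.nn.Module."""
    def __init__(self, requires_grad=False):
        self.requires_grad = requires_grad

class LoRALinearA(Module):
    """Stand-in for a trainable LoRA 'A' matrix submodule."""

class TransformerDecoderLayer(Module):
    """Stand-in for a module type the user lists in modules_to_wrap."""

def lora_fsdp_wrap_policy_sketch(modules_to_wrap):
    def policy(module, recurse, nonwrapped_numel):
        # Always recurse into children so nested LoRA submodules are visited.
        if recurse:
            return True
        # Wrap trainable (LoRA) submodules in their own FSDP unit so
        # gradient memory is allocated only for their parameters.
        if getattr(module, "requires_grad", False):
            return True
        # Otherwise wrap only the nn.Module types listed in modules_to_wrap.
        return isinstance(module, tuple(modules_to_wrap))
    return policy

policy = lora_fsdp_wrap_policy_sketch({TransformerDecoderLayer})
print(policy(LoRALinearA(requires_grad=True), False, 0))  # trainable -> wrap
print(policy(TransformerDecoderLayer(), False, 0))        # listed type -> wrap
print(policy(Module(), False, 0))                         # neither -> don't wrap
```

In real usage, the callable returned by lora_fsdp_wrap_policy is passed as the `auto_wrap_policy` argument when constructing FullyShardedDataParallel.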