lora_fsdp_wrap_policy

torchtune.utils.lora_fsdp_wrap_policy(modules_to_wrap: Set[Type]) → Callable[[Module, bool, int], bool]

A default policy for wrapping models trained with LoRA using FSDP.

FSDP’s default behavior is to allocate gradients at the level of FSDP-wrapped modules. This means that if any parameter in a given FSDP-wrapped module requires gradients, then memory will be allocated for gradients for the entire module.

In the case of LoRA, where only the LoRA A and B matrices are trainable, this means that we need to wrap the LoRA A and B submodules in their own FSDP units to maximize memory savings. After this is done, the model will also be hierarchically wrapped based on the nn.Module types specified in modules_to_wrap. This function assumes that (a) LoRA's A and B matrices are the only trainable weights in the entire model, and (b) requires_grad = True has already been set on the LoRA params.

Parameters:

modules_to_wrap (Set[Type]) – nn.Module types to recursively wrap

Returns:

Wrapping policy that can be passed into FullyShardedDataParallel. Please see the documentation for FSDPPolicyType for additional details.

Return type:

FSDPPolicyType
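
The following is a minimal sketch of how this policy can be passed to FSDP in a LoRA finetuning setup. The TransformerDecoderLayer import path, the get_adapter_params / set_trainable_params helpers, and the build_lora_model() builder are assumptions about the surrounding torchtune API rather than part of this reference.

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torchtune import utils

    # Assumed import paths; these may differ across torchtune versions.
    from torchtune.modules import TransformerDecoderLayer
    from torchtune.modules.peft import get_adapter_params, set_trainable_params

    # Hypothetical builder returning a LoRA-enabled torchtune model in which the
    # LoRA A and B matrices are the only adapter parameters (assumption (a) above).
    model = build_lora_model()

    # Assumption (b) above: mark only the LoRA params as trainable
    # (requires_grad = True) before wrapping with FSDP.
    set_trainable_params(model, get_adapter_params(model))

    # Wrap each transformer layer in its own FSDP unit; the policy additionally
    # wraps every LoRA A/B submodule so gradient memory is allocated only for
    # the trainable LoRA parameters.
    wrap_policy = utils.lora_fsdp_wrap_policy(modules_to_wrap={TransformerDecoderLayer})

    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        device_id=torch.cuda.current_device(),
    )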
