gather_cpu_state_dict

torchtune.training.gather_cpu_state_dict(model: FSDPModule, is_rank_zero: bool, device: Optional[device] = None, adapter_weights_only: bool = False) → Dict[str, Any]

Converts a sharded state dict into a full state dict on CPU. A non-empty result is returned only on rank 0, to avoid a CPU memory peak on every rank. Currently the distributed state dict API can be used to process models without NF4Tensors; otherwise, any NF4 tensors must be gathered manually until all-gather is supported in the NF4Tensor subclass. TODO: add support for NF4Tensor in the distributed state dict API.

Parameters:
  • model (FSDPModule) – Sharded model whose state dict is gathered; also used to generate the fully qualified names for cpu_state_dict

  • is_rank_zero (bool) – whether the current process is rank 0; only rank 0 receives the gathered state dict

  • device (Optional[torch.device]) – device to use when gathering sharded tensors. Default: None

  • adapter_weights_only (bool) – whether to return only trainable (adapter) parameters. Default: False

Returns:

Full state dict on CPU; empty on ranks other than rank 0

Return type:

Dict[str, Any]
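
The typical use is rank-zero checkpointing during or after distributed training. Below is a minimal sketch, assuming the process group is already initialized and model has already been sharded with FSDP; save_full_checkpoint and the checkpoint path are illustrative names, not part of torchtune:

    import torch
    import torch.distributed as dist
    from torchtune import training

    def save_full_checkpoint(model, path="checkpoint.pt"):
        # `model` is assumed to be an FSDPModule, i.e. already sharded
        # with FSDP, running inside an initialized process group.
        rank = dist.get_rank()

        # Gather all shards into a full state dict on CPU. Only rank 0
        # receives a non-empty dict, which keeps peak CPU memory low.
        cpu_state_dict = training.gather_cpu_state_dict(
            model,
            is_rank_zero=(rank == 0),
            device=torch.device("cuda"),  # device used while gathering shards
        )

        # Every other rank gets an empty dict, so only rank 0 writes.
        if rank == 0:
            torch.save(cpu_state_dict, path)

For adapter-based recipes (e.g. LoRA), passing adapter_weights_only=True restricts the gathered dict to trainable parameters, keeping the checkpoint small.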
