gather_cpu_state_dict
- torchtune.training.gather_cpu_state_dict(model: FSDPModule, is_rank_zero: bool, device: Optional[device] = None, adapter_weights_only: bool = False) → Dict[str, Any] [source]
Convert a sharded state dict into a full state dict on CPU. A non-empty result is returned only on rank 0 to avoid peak CPU memory usage. Currently, the distributed state dict API can be used for models without NF4Tensor. Otherwise, NF4 tensors must be gathered manually until all-gather is supported in the NF4Tensor subclass. TODO: add support for NF4Tensor in the distributed state dict API.
- Parameters:
model (FSDPModule) – Sharded model to gather; used to generate the fully qualified names for cpu_state_dict
is_rank_zero (bool) – whether the current process is rank 0
device (Optional[torch.device]) – device to use when gathering sharded tensors. Default: None
adapter_weights_only (bool) – whether to return only trainable (adapter) parameters. Default: False
- Returns:
Full state dict on CPU; non-empty only on rank 0, empty on all other ranks
- Return type:
Dict[str, Any]