gather_cpu_state_dict
- torchtune.training.gather_cpu_state_dict(sharded_sd: Dict[str, DTensor], is_rank_zero: bool, device: Optional[device] = None) → Dict[str, Any]
Converts a sharded state dict into a full state dict on CPU. A non-empty result is returned only on rank 0, to avoid peaking CPU memory.
- Parameters:
sharded_sd (Dict[str, DTensor]) – Sharded state dict of DTensors
is_rank_zero (bool) – Flag indicating whether the current process is rank 0
device (Optional[torch.device]) – device to use for sharded tensors. Default: None
- Returns:
Full state dict on CPU (non-empty only on rank 0)
- Return type:
Dict[str, Any]
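
The rank-0-only behavior described above can be illustrated with a toy, single-process simulation. This is a minimal sketch and not torchtune's implementation: `simulate_gather_cpu_state_dict` is a hypothetical helper, plain Python lists stand in for DTensor shards, and concatenation in rank order stands in for the real gather.

```python
from typing import Any, Dict, List


def simulate_gather_cpu_state_dict(
    shards_per_rank: List[Dict[str, List[float]]],
    is_rank_zero: bool,
) -> Dict[str, Any]:
    """Hypothetical stand-in: rebuild each parameter from its shards.

    Every rank takes part in rebuilding a parameter, but only rank 0
    keeps the rebuilt copy in the returned dict.
    """
    full_sd: Dict[str, Any] = {}
    for name in shards_per_rank[0]:
        # Rebuild the full parameter from all shards, in rank order.
        full_param = [x for shard in shards_per_rank for x in shard[name]]
        if is_rank_zero:
            # Only rank 0 stores the full parameter.
            full_sd[name] = full_param
        # On other ranks the full parameter is dropped here, so they
        # never hold the whole state dict at once.
    return full_sd


# Two simulated ranks, each holding half of "w" and one element of "b".
shards = [
    {"w": [1.0, 2.0], "b": [0.1]},
    {"w": [3.0, 4.0], "b": [0.2]},
]
rank0_sd = simulate_gather_cpu_state_dict(shards, is_rank_zero=True)
rank1_sd = simulate_gather_cpu_state_dict(shards, is_rank_zero=False)
```

In this sketch `rank0_sd` holds the full parameters while `rank1_sd` is an empty dict. In a real distributed run, `is_rank_zero` would presumably be derived from the process rank (for example, `torch.distributed.get_rank() == 0`), and all ranks would call the function because gathering shards is a collective step.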