
gather_cpu_state_dict

torchtune.training.gather_cpu_state_dict(sharded_sd: Dict[str, DTensor], is_rank_zero: bool, device: Optional[device] = None) → Dict[str, Any]

Converts a sharded state dict into a full state dict on CPU. Returns a non-empty result only on rank 0, to avoid peaking CPU memory on the other ranks.

Parameters:
  • sharded_sd (Dict[str, DTensor]) – Sharded state dict of DTensors

  • is_rank_zero (bool) – flag indicating whether the current process is rank 0

  • device (Optional[torch.device]) – device to use for sharded tensors. Default: None

Returns:

State dict on CPU

Return type:

Dict[str, Any]
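The contract above (gather everywhere, keep the result only on rank 0) can be illustrated with a small single-process sketch. It does not use torchtune or torch: `ToyShardedTensor` and `gather_cpu_state_dict_sketch` are hypothetical names invented here, and plain Python lists stand in for DTensors sharded along dim 0. The assumption it encodes is that gathering a full tensor from a DTensor is a collective operation, so every rank must run the loop even though only rank 0 returns anything. In a real FSDP job one would presumably pass each rank's `model.state_dict()` and `torch.distributed.get_rank() == 0` to the actual function.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyShardedTensor:
    """Hypothetical stand-in for a DTensor sharded along dim 0.

    `shards` holds every rank's local slice so the example runs in one
    process; a real DTensor stores only its local slice and gathers the
    rest over the process group.
    """
    shards: List[List[float]]

    def full_tensor(self) -> List[float]:
        # Stand-in for an all-gather: concatenate the dim-0 shards.
        # In a distributed job this step is collective, so all ranks call it.
        full: List[float] = []
        for shard in self.shards:
            full.extend(shard)
        return full


def gather_cpu_state_dict_sketch(
    sharded_sd: Dict[str, ToyShardedTensor],
    is_rank_zero: bool,
) -> Dict[str, Any]:
    """Hypothetical sketch: every rank gathers, only rank 0 keeps the result."""
    cpu_state_dict: Dict[str, Any] = {}
    for name, sharded_param in sharded_sd.items():
        # All ranks must participate in the gather to avoid a hang.
        full_param = sharded_param.full_tensor()
        if is_rank_zero:
            # Rank 0 keeps the full copy (the real API keeps it on CPU).
            cpu_state_dict[name] = full_param
        else:
            # Other ranks drop it immediately so host memory stays flat.
            del full_param
    return cpu_state_dict


if __name__ == "__main__":
    sd = {
        "w": ToyShardedTensor([[1.0, 2.0], [3.0, 4.0]]),
        "b": ToyShardedTensor([[0.5], [0.25]]),
    }
    print(gather_cpu_state_dict_sketch(sd, is_rank_zero=True))
    # {'w': [1.0, 2.0, 3.0, 4.0], 'b': [0.5, 0.25]}
    print(gather_cpu_state_dict_sketch(sd, is_rank_zero=False))
    # {}
```

The design choice this mirrors is that skipping the gather on non-zero ranks is not an option (a collective needs all participants), so the memory saving comes from discarding the gathered tensor rather than from not computing it.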
