reparametrize_as_dtype_state_dict_post_hook

torchtune.modules.common_utils.reparametrize_as_dtype_state_dict_post_hook(model: Module, state_dict: Dict[str, Any], *args: Any, dtype: torch.dtype = torch.bfloat16, offload_to_cpu: bool = True, **kwargs: Any)

A state_dict hook that replaces NF4 tensors with their restored higher-precision weights and optionally offloads those weights to CPU. Use this hook to avoid increased peak GPU memory usage during checkpoint save when training with QLoRA.

This function is meant to be used with PyTorch’s nn.Module._register_state_dict_hook, i.e.

>>> from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
>>> m = MyModule()
>>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)

If registered as above, the hook runs _after_ the module's state_dict() method is called. It replaces every NF4Tensor instance in the state dict by dequantizing it to the requested higher-precision dtype, and optionally offloads the restored weight to CPU.
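
As a concrete illustration, the hedged sketch below registers the hook on a single quantized layer and checks the result. It assumes torchtune's FrozenNF4Linear (torchtune.modules.low_precision.FrozenNF4Linear), a frozen linear layer whose weight is stored as a torchao NF4Tensor; substitute your own QLoRA model in practice.

>>> import torch
>>> from torchao.dtypes.nf4tensor import NF4Tensor
>>> from torchtune.modules.low_precision import FrozenNF4Linear
>>> from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
>>> # Assumption: FrozenNF4Linear stores its frozen weight as an NF4Tensor.
>>> m = FrozenNF4Linear(512, 512)
>>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)
>>> sd = m.state_dict()
>>> # All NF4Tensors were dequantized; with the defaults, the restored
>>> # weight is bfloat16 and has been offloaded to CPU.
>>> assert not any(isinstance(v, NF4Tensor) for v in sd.values())
>>> assert sd["weight"].dtype == torch.bfloat16
>>> assert sd["weight"].device.type == "cpu"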

Parameters:
  • model (nn.Module) – the model to take state_dict() on

  • state_dict (Dict[str, Any]) – the state dict to modify

  • *args (Any) – Unused args passed when running this as a state_dict hook.

  • dtype (torch.dtype) – the dtype to restore the weight to. Default is torch.bfloat16 (see the sketch after this parameter list for overriding the defaults).

  • offload_to_cpu (bool) – whether to offload the restored weight to CPU. Default is True.

  • **kwargs (Any) – Unused keyword args passed when running this as a state_dict hook.
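
nn.Module._register_state_dict_hook invokes hooks with a fixed positional signature (module, state_dict, prefix, local_metadata) and does not forward extra keyword arguments, so non-default values for dtype and offload_to_cpu are easiest to bind ahead of time with functools.partial. A minimal sketch, reusing the MyModule stand-in from the example above:

>>> from functools import partial
>>> import torch
>>> m = MyModule()
>>> m._register_state_dict_hook(
...     partial(
...         reparametrize_as_dtype_state_dict_post_hook,
...         dtype=torch.float32,    # restore to fp32 instead of the bf16 default
...         offload_to_cpu=False,   # keep restored weights on their current device
...     )
... )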
