reparametrize_as_dtype_state_dict_post_hook
torchtune.modules.common_utils.reparametrize_as_dtype_state_dict_post_hook(model: Module, state_dict: Dict[str, Any], *args: Any, dtype: dtype = torch.bfloat16, offload_to_cpu: bool = True, **kwargs: Any)
A state_dict hook that replaces NF4 tensors with their restored higher-precision weight and optionally offloads the restored weight to CPU. Use this hook to avoid increased peak GPU memory usage during checkpoint save when training with QLoRA.
This function is meant to be used with PyTorch's nn.Module._register_state_dict_hook, i.e.

>>> m = MyModule()
>>> m._register_state_dict_hook(reparametrize_as_dtype_state_dict_post_hook)
If the hook is registered per the above process, this hook will be called _after_ the module's state_dict method is called. The hook will replace all NF4Tensor instances by unquantizing them to the original dtype, and optionally offload the restored weight to CPU.

Parameters:
- model (nn.Module) – the model to take state_dict() on
- state_dict (Dict[str, Any]) – the state dict to modify
- *args (Any) – Unused args passed when running this as a state_dict hook.
- dtype (torch.dtype) – the dtype to restore the weight to. Default is torch.bfloat16.
- offload_to_cpu (bool) – whether to offload the restored weight to CPU. Default is True.
- **kwargs (Any) – Unused keyword args passed when running this as a state_dict hook.
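
Since nn.Module._register_state_dict_hook invokes the hook with positional arguments only, non-default values for dtype and offload_to_cpu can be bound with functools.partial before registration. The sketch below is a minimal example of this pattern; build_my_qlora_model() is a hypothetical helper standing in for any model whose weights include NF4Tensor instances, and the checkpoint path is illustrative.

>>> from functools import partial
>>> import torch
>>> from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
>>> # Hypothetical QLoRA model; any nn.Module holding NF4Tensor weights works.
>>> model = build_my_qlora_model()
>>> # Bind keyword arguments up front, since the hook itself is
>>> # called with positional arguments only.
>>> hook = partial(
...     reparametrize_as_dtype_state_dict_post_hook,
...     dtype=torch.bfloat16,
...     offload_to_cpu=True,  # keep restored weights off the GPU
... )
>>> model._register_state_dict_hook(hook)
>>> # state_dict() now returns bf16 tensors in place of NF4 weights,
>>> # so saving the checkpoint does not spike peak GPU memory.
>>> torch.save(model.state_dict(), "qlora_checkpoint.pt")

Because offload_to_cpu defaults to True, the restored bf16 weights land on CPU, so serializing the state dict avoids the increase in peak GPU memory that unquantizing in place would otherwise cause.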