# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)
from torch.utils.checkpoint import checkpoint

# Uses the PyTorch Distributed (PTD) FSDP activation checkpointing wrapper.
# Currently, full and selective per-layer checkpointing are supported.
def checkpoint_wrapper(module, ac_mode, ac_style):
    if ac_mode == "full":
        return ptd_checkpoint_wrapper(
            module,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            checkpoint_fn=checkpoint,
            use_reentrant=False,
            preserve_rng_state=False,
        )
    # Selective per-layer checkpointing, with coercion in case we receive
    # '2' (a string) instead of 2 (an int).
    elif ac_mode == "selective":
        """Enables selective checkpointing of candidate layers.
        Usage:
            'selective_ac_option' with a non-negative 'int' value in the config
            controls which layers to checkpoint.
            1 == checkpoint every layer (all).
            2 == checkpoint every 2nd layer.
        """
        every_x_layer = int(ac_style)
        if every_x_layer < 0:
            raise ValueError(
                f"Selective layer AC policy (every_x_layer) expects a non-negative integer, received {every_x_layer}"
            )
        # A function attribute serves as a call counter shared across layers,
        # so that every Nth wrapped block can be checkpointed.
        checkpoint_wrapper.__dict__.setdefault("_count", 0)
        checkpoint_wrapper._count += 1
        # `not every_x_layer` treats 0 as "checkpoint all" and avoids a
        # modulo-by-zero below.
        if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
            return ptd_checkpoint_wrapper(
                module,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                checkpoint_fn=checkpoint,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        # Skip activation checkpointing and store activations for this layer.
        else:
            return module
    else:
        raise NotImplementedError(
            "Unknown AC mode or AC config. Only full and selective per-layer AC are implemented currently."
        )
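
# Worked illustration of the selective counter above (an editorial sketch in
# comments, not part of the library): with ac_style=2, the function attribute
# ``_count`` advances once per call, so in a fresh process every 2nd block
# passed in is checkpoint-wrapped and the others are returned unchanged.
#
#   checkpoint_wrapper(block, "selective", 2)  # _count=1 -> returned as-is
#   checkpoint_wrapper(block, "selective", 2)  # _count=2 -> wrapped
#   checkpoint_wrapper(block, "selective", 2)  # _count=3 -> returned as-is
#   checkpoint_wrapper(block, "selective", 2)  # _count=4 -> wrapped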


def apply_selective_activation_checkpointing(
    model: nn.Module,
    ac_mode: str,
    ac_option: Optional[Union[int, str]],
) -> None:
"""Utility to setup activation checkpointing and wrap the model for checkpointing.
Args:
model (nn.Module): Model to setup activation checkpointing.
ac_mode (str): Activation checkpointing mode. ['none', 'full', 'selective']
ac_option (Optional[Union[int, str]]): Activation checkpointing option. If ac_mode is
"selective", ac_option can be an integer or a string representing the number of layers
to checkpoint. If ac_mode is "selective" and ac_option is "op", then selective op ac is run.
If ac_mode is "none" or "full", ac_option is ignored.
"""
    for layer_id, transformer_block in enumerate(model.layers):
        if ac_mode in ("full", "selective"):
            transformer_block = checkpoint_wrapper(
                transformer_block,
                ac_mode,
                ac_option,
            )
            model.layers[layer_id] = transformer_block
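

if __name__ == "__main__":
    # Minimal usage sketch (an editorial addition, not part of the original
    # module): builds a toy model whose transformer blocks live in an
    # indexable ``model.layers`` -- the structure the loop above expects --
    # applies selective AC to every 2nd block, and runs a forward/backward
    # pass so the wrapped blocks recompute their activations.
    import torch

    class _ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                x = layer(x)
            return x

    model = _ToyModel()
    apply_selective_activation_checkpointing(model, ac_mode="selective", ac_option=2)
    # In a fresh process the shared counter wraps the 2nd and 4th blocks.
    print([type(layer).__name__ for layer in model.layers])
    out = model(torch.randn(2, 16))
    out.sum().backward()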