
Source code for torchtune.training.activations

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Union

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
    CheckpointImpl,
)
from torch.utils.checkpoint import checkpoint


# Uses the PTD FSDP activation checkpointing (AC) wrapper.
# Currently "full" and selective per-layer checkpointing are supported.
def checkpoint_wrapper(module, ac_mode, ac_style):

    if ac_mode == "full":
        return ptd_checkpoint_wrapper(
            module,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            checkpoint_fn=checkpoint,
            use_reentrant=False,
            preserve_rng_state=False,
        )

    # Selective per-layer checkpointing. ac_style may arrive as a string (e.g. '2') or an int (e.g. 2).
    elif ac_mode == "selective":
        """enables selective checkpointing of candidate layers.
        Usage:
        'selective_ac_option' with a positive 'int' value in config controls which layers to checkpoint.
        1 == checkpointing every one (all).
        2 == checkpoint every 2nd one
        """
        every_x_layer = int(ac_style)

        if every_x_layer < 0:
            raise ValueError(
                f"Selective layer AC policy (every_x_layer) expects a non-negative integer, received {every_x_layer}"
            )

        # The running layer count is stored on the function object so it persists across calls.
        checkpoint_wrapper.__dict__.setdefault("_count", 0)

        checkpoint_wrapper._count += 1
        # Checkpoint this layer if every_x_layer == 0 (checkpoint all) or this is the Nth layer in the cycle.
        if not every_x_layer or checkpoint_wrapper._count % every_x_layer == 0:
            return ptd_checkpoint_wrapper(
                module,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                checkpoint_fn=checkpoint,
                use_reentrant=False,
                preserve_rng_state=False,
            )
        else:
            # Skip activation checkpointing for this layer and store its activations as usual.
            return module

    else:
        raise NotImplementedError(
            f"Unknown AC mode or AC config: {ac_mode}. Only 'full' and selective per-layer AC are currently implemented."
        )
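
A minimal usage sketch for the selective policy (illustration only, not part of the torchtune source; it relies on the imports at the top of this module and assumes a fresh process so the wrapper's internal counter starts at zero):

layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
wrapped = [checkpoint_wrapper(layer, ac_mode="selective", ac_style=2) for layer in layers]

# With ac_style=2 and a fresh counter, the 2nd and 4th layers come back as
# CheckpointWrapper modules; the 1st and 3rd are returned unchanged and keep
# their activations in memory.
for idx, layer in enumerate(wrapped):
    print(idx, type(layer).__name__)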


def apply_selective_activation_checkpointing(
    model: nn.Module,
    ac_mode: str,
    ac_option: Optional[Union[int, str]],
) -> None:
    """Utility to set up activation checkpointing and wrap the model for checkpointing.

    Args:
        model (nn.Module): Model to set up activation checkpointing on.
        ac_mode (str): Activation checkpointing mode. One of ['none', 'full', 'selective'].
        ac_option (Optional[Union[int, str]]): Activation checkpointing option.
            If ac_mode is "selective", ac_option can be an integer or a string
            representing the number of layers to checkpoint.
            If ac_mode is "selective" and ac_option is "op", then selective op AC is run.
            If ac_mode is "none" or "full", ac_option is ignored.
    """
    for layer_id, transformer_block in enumerate(model.layers):
        if ac_mode in ("full", "selective"):
            transformer_block = checkpoint_wrapper(
                transformer_block,
                ac_mode,
                ac_option,
            )
            model.layers[layer_id] = transformer_block
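
A minimal usage sketch (the toy model below is an assumption for illustration, not part of this module): any model that exposes its transformer blocks as a model.layers ModuleList can be passed to apply_selective_activation_checkpointing, e.g. a torchtune TransformerDecoder.

import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 6):
        super().__init__()
        # Stand-in for transformer blocks; only the `.layers` attribute matters here.
        self.layers = nn.ModuleList(nn.Linear(32, 32) for _ in range(n_layers))


model = ToyModel()
# Checkpoint every 2nd block; the remaining blocks keep their activations in memory.
apply_selective_activation_checkpointing(model, ac_mode="selective", ac_option=2)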
