Automatic Mixed Precision package - torch.amp

torch.amp provides convenience methods for mixed precision, where some operations use the torch.float32 (float) datatype and other operations use lower precision floating point datatype (lower_precision_fp): torch.float16 (half) or torch.bfloat16. Some ops, like linear layers and convolutions, are much faster in lower_precision_fp. Other ops, like reductions, often require the dynamic range of float32. Mixed precision tries to match each op to its appropriate datatype.
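
To make the trade-off between these datatypes concrete, the following minimal sketch prints their numeric limits with torch.finfo. The variable names (fp16_max, bf16_max) are illustrative only.

```python
import torch

# Compare range and precision of the three floating-point dtypes discussed above.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.4e} "
          f"smallest_normal={info.tiny:.4e} eps={info.eps:.4e}")

# float16 has a narrow range; bfloat16 keeps roughly the float32 range
# but has a coarser precision (larger eps) than float16.
fp16_max = torch.finfo(torch.float16).max
bf16_max = torch.finfo(torch.bfloat16).max
```

The narrow float16 range is why float16 training is usually paired with gradient scaling, while the bfloat16 CPU examples in this page use autocast alone.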

Ordinarily, “automatic mixed precision training” with datatype of torch.float16 uses torch.autocast and torch.cuda.amp.GradScaler together, as shown in the CUDA Automatic Mixed Precision examples and CUDA Automatic Mixed Precision recipe. However, torch.autocast and torch.cuda.amp.GradScaler are modular, and may be used separately if desired. As shown in the CPU example section of torch.autocast, “automatic mixed precision training/inference” on CPU with datatype of torch.bfloat16 only uses torch.autocast.

For CUDA and CPU, APIs are also provided separately:

  • torch.autocast("cuda", args...) is equivalent to torch.cuda.amp.autocast(args...).

  • torch.autocast("cpu", args...) is equivalent to torch.cpu.amp.autocast(args...). For CPU, torch.bfloat16 is currently the only supported lower precision floating point datatype.

torch.autocast and torch.cpu.amp.autocast are new in version 1.10.

Autocasting

class torch.autocast(device_type, dtype=None, enabled=True, cache_enabled=None)[source]

Instances of autocast serve as context managers or decorators that allow regions of your script to run in mixed precision.

In these regions, ops run in an op-specific dtype chosen by autocast to improve performance while maintaining accuracy. See the Autocast Op Reference for details.

When entering an autocast-enabled region, Tensors may be any type. You should not call half() or bfloat16() on your model(s) or inputs when using autocasting.

autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

Example for CUDA Devices:

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with torch.autocast(device_type="cuda"):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()

See the CUDA Automatic Mixed Precision examples for usage (along with gradient scaling) in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

autocast can also be used as a decorator, e.g., on the forward method of your model:

class AutocastModel(nn.Module):
    ...
    @torch.autocast(device_type="cuda")
    def forward(self, input):
        ...

Floating-point Tensors produced in an autocast-enabled region may be float16. After returning to an autocast-disabled region, using them with floating-point Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s) produced in the autocast region back to float32 (or other dtype if desired). If a Tensor from the autocast region is already float32, the cast is a no-op, and incurs no additional overhead. CUDA Example:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    # torch.mm is on autocast's list of ops that should run in float16.
    # Inputs are float32, but the op runs in float16 and produces float16 output.
    # No manual casts are required.
    e_float16 = torch.mm(a_float32, b_float32)
    # Also handles mixed input types
    f_float16 = torch.mm(d_float32, e_float16)

# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float())
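
For readers without a GPU, here is a rough CPU analog of the example above, assuming bfloat16 autocast. It also shows the kind of type mismatch error that can occur after leaving the region when the cast back to float32 is omitted. The variable mismatch_raised is illustrative only.

```python
import torch

a_float32 = torch.rand((8, 8))
b_float32 = torch.rand((8, 8))
d_float32 = torch.rand((8, 8))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # torch.mm is on the CPU list of ops that autocast to bfloat16.
    e_bfloat16 = torch.mm(a_float32, b_float32)

# Outside the region, mixing float32 and bfloat16 operands in mm fails.
mismatch_raised = False
try:
    torch.mm(d_float32, e_bfloat16)
except RuntimeError:
    mismatch_raised = True

# Casting the autocast output back to float32 resolves the mismatch.
g_float32 = torch.mm(d_float32, e_bfloat16.float())
```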

CPU Training Example:

# Creates model and optimizer in default precision
model = Net()
optimizer = optim.SGD(model.parameters(), ...)

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            output = model(input)
            loss = loss_fn(output, target)

        loss.backward()
        optimizer.step()

CPU Inference Example:

# Creates model in default precision
model = Net().eval()

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    for input in data:
        # Runs the forward pass with autocasting.
        output = model(input)

CPU Inference Example with Jit Trace:

class TestModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, num_classes)
    def forward(self, x):
        return self.fc1(x)

input_size = 2
num_classes = 2
model = TestModel(input_size, num_classes).eval()

# For now, we suggest disabling the JIT autocast pass
# because of this issue: https://github.com/pytorch/pytorch/issues/75956
torch._C._jit_set_autocast_mode(False)

with torch.cpu.amp.autocast(cache_enabled=False):
    model = torch.jit.trace(model, torch.randn(1, input_size))
model = torch.jit.freeze(model)
# Run the frozen model a few times
for _ in range(3):
    model(torch.randn(1, input_size))

Type mismatch errors in an autocast-enabled region are a bug; if this is what you observe, please file an issue.

autocast(enabled=False) subregions can be nested in autocast-enabled regions. Locally disabling autocast can be useful, for example, if you want to force a subregion to run in a particular dtype. Disabling autocast gives you explicit control over the execution type. In the subregion, inputs from the surrounding region should be cast to dtype before use:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with torch.autocast(device_type="cuda"):
    e_float16 = torch.mm(a_float32, b_float32)
    with torch.autocast(device_type="cuda", enabled=False):
        # Calls e_float16.float() to ensure float32 execution
        # (necessary because e_float16 was created in an autocasted region)
        f_float32 = torch.mm(c_float32, e_float16.float())

    # No manual casts are required when re-entering the autocast-enabled region.
    # torch.mm again runs in float16 and produces float16 output, regardless of input types.
    g_float16 = torch.mm(d_float32, f_float32)
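
The same nesting pattern can be tried on CPU. This is a minimal sketch assuming bfloat16 autocast; the short tensor names are illustrative.

```python
import torch

a = torch.rand((8, 8))
b = torch.rand((8, 8))
c = torch.rand((8, 8))
d = torch.rand((8, 8))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    e = torch.mm(a, b)  # runs in bfloat16
    with torch.autocast(device_type="cpu", enabled=False):
        # Autocast is off here, so the explicit cast decides the execution type.
        f = torch.mm(c, e.float())  # runs in float32
    # Back in the enabled region, mm autocasts again regardless of input types.
    g = torch.mm(d, f)
```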

The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator must be invoked in that thread. This affects torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel when used with more than one GPU per process (see Working with Multiple GPUs).
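
A small sketch of the thread-local behavior, assuming CPU bfloat16 autocast and Python's threading module. The worker functions and the seen dictionary are hypothetical names used only for illustration.

```python
import threading
import torch

a = torch.rand((4, 4))
b = torch.rand((4, 4))
seen = {}

def worker_plain():
    # The new thread does not inherit the caller's autocast state.
    seen["plain"] = torch.mm(a, b).dtype

def worker_autocast():
    # Autocast must be entered inside the thread to take effect there.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        seen["autocast"] = torch.mm(a, b).dtype

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    seen["main"] = torch.mm(a, b).dtype
    for target in (worker_plain, worker_autocast):
        t = threading.Thread(target=target)
        t.start()
        t.join()
```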

Parameters
  • device_type (str, required) – Device type to use. Possible values are: ‘cuda’, ‘cpu’, ‘xpu’ and ‘hpu’. The type is the same as the type attribute of a torch.device. Thus, you may obtain the device type of a tensor using Tensor.device.type.

  • enabled (bool, optional) – Whether autocasting should be enabled in the region. Default: True

  • dtype (torch.dtype, optional) – Whether to use torch.float16 or torch.bfloat16.

  • cache_enabled (bool, optional) – Whether the weight cache inside autocast should be enabled. Default: True
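
As a sketch of how device_type can be derived from a tensor rather than hard-coded: the choice of float16 on CUDA and bfloat16 elsewhere below is an assumption made for this example, not a rule imposed by the API.

```python
import torch

x = torch.rand((4, 4), device="cpu")

# Tensor.device.type yields the string expected by torch.autocast.
device_type = x.device.type

# Assumed convention for this sketch: float16 on CUDA, bfloat16 otherwise.
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

with torch.autocast(device_type=device_type, dtype=amp_dtype):
    y = torch.mm(x, x)
```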

class torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True)[source]

See torch.autocast.

torch.cuda.amp.autocast(args...) is equivalent to torch.autocast("cuda", args...)

torch.cuda.amp.custom_fwd(fwd=None, *, cast_inputs=None)[source]

Create a helper decorator for forward methods of custom autograd functions.

Autograd functions are subclasses of torch.autograd.Function. See the example page for more detail.

Parameters

cast_inputs (torch.dtype or None, optional, default=None) – If not None, when forward runs in an autocast-enabled region, casts incoming floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected), then executes forward with autocast disabled. If None, forward’s internal ops execute with the current autocast state.

Note

If the decorated forward is called outside an autocast-enabled region, custom_fwd is a no-op and cast_inputs has no effect.

torch.cuda.amp.custom_bwd(bwd)[source]

Create a helper decorator for backward methods of custom autograd functions.

Autograd functions are subclasses of torch.autograd.Function. Ensures that backward executes with the same autocast state as forward. See the example page for more detail.
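
A minimal sketch of how the two decorators are typically combined on a custom autograd function. MyMM is a hypothetical example class. The usage below runs on CPU outside any autocast region, where custom_fwd is a no-op; newer PyTorch releases may emit a deprecation warning for the torch.cuda.amp spelling.

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MyMM(torch.autograd.Function):
    """Hypothetical matrix multiply that always runs its forward in float32."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        # Gradients of a @ b with respect to a and b.
        return grad.mm(b.t()), a.t().mm(grad)

a = torch.rand((3, 4), requires_grad=True)
b = torch.rand((4, 5), requires_grad=True)
out = MyMM.apply(a, b)
out.sum().backward()
```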

class torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True)[source]

See torch.autocast. torch.cpu.amp.autocast(args...) is equivalent to torch.autocast("cpu", args...)

Gradient Scaling

If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.

To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.

Each parameter’s gradient (.grad attribute) should be unscaled before the optimizer updates the parameters, so the scale factor does not interfere with the learning rate.
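
The underflow problem and its remedy can be illustrated with a single value. This is a toy sketch, not how GradScaler is implemented; true_grad and scale are values chosen for the example.

```python
import torch

true_grad = 1e-8   # a small gradient magnitude, chosen for illustration
scale = 2.0 ** 16  # an example scale factor

# Stored directly in float16, the value flushes to zero.
unscaled_fp16 = torch.tensor(true_grad, dtype=torch.float16)

# Scaled first, the value lands inside the float16 representable range.
scaled_fp16 = torch.tensor(true_grad * scale, dtype=torch.float16)

# Unscaling in float32 recovers approximately the original value.
recovered = scaled_fp16.float() / scale
```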

Note

AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in the fp16 numerical range of max 65504 and will cause gradients to overflow instead of underflow. In this case, the scale factor may decrease below 1 in an attempt to bring gradients to a number representable in the fp16 dynamic range. While one may expect the scale to always be above 1, GradScaler does NOT make this guarantee, because checking the value would cost performance. If you encounter NaNs in your loss or gradients when running with AMP/fp16, verify your model is compatible.

class torch.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)[source]

An instance scaler of GradScaler helps perform the steps of gradient scaling conveniently.

  • scaler.scale(loss) multiplies a given loss by scaler’s current scale factor.

  • scaler.step(optimizer) safely unscales gradients and calls optimizer.step().

  • scaler.update() updates scaler’s scale factor.

Example:

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # scaler.step() first unscales gradients of the optimizer's params.
        # If gradients don't contain infs/NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

See the Automatic Mixed Precision examples for usage (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, and multiple losses/optimizers.
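
For reference, here is a self-contained sketch that combines autocast with the scaler in one loop. The tiny linear model, the random data, and the fallback to CPU bfloat16 with a disabled scaler when CUDA is unavailable are all assumptions made so the sketch runs anywhere.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device_type = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

model = nn.Linear(4, 1).to(device_type)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# With enabled=False the scaler methods pass through, so the loop is unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]
weight_before = model.weight.detach().clone()

for input, target in data:
    input, target = input.to(device_type), target.to(device_type)
    optimizer.zero_grad()

    # Forward pass and loss under autocast.
    with torch.autocast(device_type=device_type, dtype=amp_dtype):
        output = model(input)
        loss = loss_fn(output, target)

    # Backward on the scaled loss, then unscale + step, then adjust the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```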

scaler dynamically estimates the scale factor each iteration. To minimize gradient underflow, a large scale factor should be used. However, float16 values can “overflow” (become inf or NaN) if the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used without incurring inf or NaN gradient values. scaler approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every scaler.step(optimizer) (or optional separate scaler.unscale_(optimizer), see unscale_()).

  • If infs/NaNs are found, scaler.step(optimizer) skips the underlying optimizer.step() (so the params themselves remain uncorrupted) and update() multiplies the scale by backoff_factor.

  • If no infs/NaNs are found, scaler.step(optimizer) runs the underlying optimizer.step() as usual. If growth_interval unskipped iterations occur consecutively, update() multiplies the scale by growth_factor.

The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its value calibrates. scaler.step will skip the underlying optimizer.step() for these iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
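
The update rule described above can be sketched in plain Python. simulate_scale is a hypothetical helper written for this illustration and is not part of PyTorch; it only mirrors the documented backoff and growth behavior.

```python
def simulate_scale(found_inf_per_step, init_scale=2.0 ** 16, growth_factor=2.0,
                   backoff_factor=0.5, growth_interval=2000):
    """Return the scale after each iteration, given whether infs/NaNs were found."""
    scale = init_scale
    unskipped = 0  # consecutive iterations without infs/NaNs
    history = []
    for found_inf in found_inf_per_step:
        if found_inf:
            # Step skipped: shrink the scale and restart the growth count.
            scale *= backoff_factor
            unskipped = 0
        else:
            unskipped += 1
            if unskipped == growth_interval:
                # Enough clean iterations in a row: grow the scale.
                scale *= growth_factor
                unskipped = 0
        history.append(scale)
    return history

# Two early overflows, then clean iterations, with a short growth interval.
history = simulate_scale([True, True, False, False, False, False],
                         growth_interval=3)
print(history)
```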

Parameters
  • init_scale (float, optional, default=2.**16) – Initial scale factor.

  • growth_factor (float, optional, default=2.0) – Factor by which the scale is multiplied during update() if no inf/NaN gradients occur for growth_interval consecutive iterations.

  • backoff_factor (float, optional, default=0.5) – Factor by which the scale is multiplied during update() if inf/NaN gradients occur in an iteration.

  • growth_interval (int, optional, default=2000) – Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by growth_factor.

  • enabled (bool, optional) – If False, disables gradient scaling. step() simply invokes the underlying optimizer.step(), and other methods become no-ops. Default: True

get_backoff_factor()[source]

Return a Python float containing the scale backoff factor.

Return type

float

get_growth_factor()[source]

Return a Python float containing the scale growth factor.

Return type

float

get_growth_interval()[source]

Return a Python int containing the growth interval.

Return type

int

get_scale()[source]

Return a Python float containing the current scale, or 1.0 if scaling is disabled.

Warning

get_scale() incurs a CPU-GPU sync.

Return type

float

is_enabled()[source]

Return a bool indicating whether this instance is enabled.

Return type

bool

load_state_dict(state_dict)[source]

Load the scaler state.

If this instance is disabled, load_state_dict() is a no-op.

Parameters

state_dict (dict) – scaler state. Should be an object returned from a call to state_dict().

scale(outputs: Tensor) → Tensor[source]
scale(outputs: List[Tensor]) → List[Tensor]
scale(outputs: Tuple[Tensor, ...]) → Tuple[Tensor, ...]
scale(outputs: Iterable[Tensor]) → Iterable[Tensor]

Multiplies (‘scales’) a tensor or list of tensors by the scale factor.

Returns scaled outputs. If this instance of GradScaler is not enabled, outputs are returned unmodified.

Parameters

outputs (Tensor or iterable of Tensors) – Outputs to scale.

set_backoff_factor(new_factor)[source]

Set a new scale backoff factor.

Parameters

new_factor (float) – Value to use as the new scale backoff factor.

set_growth_factor(new_factor)[source]

Set a new scale growth factor.

Parameters

new_factor (float) – Value to use as the new scale growth factor.

set_growth_interval(new_interval)[source]

Set a new growth interval.

Parameters

new_interval (int) – Value to use as the new growth interval.

state_dict()[source]

Return the state of the scaler as a dict.

It contains five entries:

  • "scale" - a Python float containing the current scale

  • "growth_factor" - a Python float containing the current growth factor

  • "backoff_factor" - a Python float containing the current backoff factor

  • "growth_interval" - a Python int containing the current growth interval

  • "_growth_tracker" - a Python int containing the number of recent consecutive unskipped steps.

If this instance is not enabled, returns an empty dict.

Note

If you wish to checkpoint the scaler’s state after a particular iteration, state_dict() should be called after update().
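
A minimal checkpointing sketch. The checkpoint dictionary layout is an assumption for illustration. The scaler is enabled only when CUDA is available, so on a CPU-only machine this exercises the disabled path, where state_dict() returns an empty dict and load_state_dict() is a no-op.

```python
import torch

enabled = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=enabled)

# ... a training iteration ending in scaler.update() would go here ...

# Save the scaler state alongside whatever else is checkpointed.
checkpoint = {"scaler": scaler.state_dict()}

# Later: rebuild a scaler with the same settings and restore its state.
restored = torch.cuda.amp.GradScaler(enabled=enabled)
restored.load_state_dict(checkpoint["scaler"])
```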

Return type

Dict[str, Any]

step(optimizer, *args, **kwargs)[source]

Invoke unscale_(optimizer), followed by a parameter update if the gradients contain no infs/NaNs.

step() carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_() was explicitly called for optimizer earlier in the iteration). As part of the unscale_(), gradients are checked for infs/NaNs.

  2. If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.

*args and **kwargs are forwarded to optimizer.step().

Returns the return value of optimizer.step(*args, **kwargs).
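
Conceptually, the two operations can be sketched as follows. manual_scaled_step is a hypothetical helper for illustration only; it is not GradScaler's actual implementation, which avoids the per-tensor Python checks used here.

```python
import torch

def manual_scaled_step(optimizer, scale):
    """Hypothetical helper mimicking the two phases of step()."""
    params = [p for group in optimizer.param_groups
              for p in group["params"] if p.grad is not None]
    # Phase 1: unscale gradients in place and check them for infs/NaNs.
    found_inf = False
    for p in params:
        p.grad.div_(scale)
        if not torch.isfinite(p.grad).all():
            found_inf = True
    # Phase 2: update the params only if every gradient is finite.
    if found_inf:
        return None
    return optimizer.step()

w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt = torch.optim.SGD([w], lr=0.1)

# An iteration whose scaled gradients overflowed: the step is skipped.
w.grad = torch.tensor([65536.0, float("inf")])
manual_scaled_step(opt, scale=65536.0)
after_bad = w.detach().clone()

# An iteration with finite scaled gradients: the step uses unscaled values.
w.grad = torch.tensor([65536.0, 131072.0])
manual_scaled_step(opt, scale=65536.0)
after_good = w.detach().clone()
```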

Parameters
  • optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.

  • args (Any) – Any arguments.

  • kwargs (Any) – Any keyword arguments.

Return type

Optional[float]

Warning

Closure use is not currently supported.

unscale_(optimizer)[source]

Divides (“unscales”) the optimizer’s gradient tensors by the scale factor.

unscale_() is optional, serving cases where you need to modify or inspect gradients between the backward pass(es) and step(). If unscale_() is not called explicitly, gradients will be unscaled automatically during step().

Simple example, using unscale_() to enable clipping of unscaled gradients:

...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()

Parameters

optimizer (torch.optim.Optimizer) – Optimizer that owns the gradients to be unscaled.

Note

unscale_() does not incur a CPU-GPU sync.

Warning

unscale_() should only be called once per optimizer per step() call, and only after all gradients for that optimizer’s assigned parameters have been accumulated. Calling unscale_() twice for a given optimizer between each step() triggers a RuntimeError.

Warning

unscale_() may unscale sparse gradients out of place, replacing the .grad attribute.

update(new_scale=None)[source]

Update the scale factor.

If any optimizer steps were skipped the scale is multiplied by backoff_factor to reduce it. If growth_interval unskipped iterations occurred consecutively, the scale is multiplied by growth_factor to increase it.

Passing new_scale sets the new scale value manually. (new_scale is not used directly; it is used to fill GradScaler’s internal scale tensor. So if new_scale was a tensor, later in-place changes to that tensor will not further affect the scale GradScaler uses internally.)

Parameters

new_scale (float or torch.cuda.FloatTensor, optional, default=None) – New scale factor.

Warning

update() should only be called at the end of the iteration, after scaler.step(optimizer) has been invoked for all optimizers used this iteration.

Warning

For performance reasons, we do not check the scale factor value to avoid synchronizations, so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or you are seeing NaNs in your gradients or loss, something is likely wrong. For example, bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.

Autocast Op Reference

Op Eligibility

Ops that run in float64 or non-floating-point dtypes are not eligible, and will run in these types whether or not autocast is enabled.

Only out-of-place ops and Tensor methods are eligible. In-place variants and calls that explicitly supply an out=... Tensor are allowed in autocast-enabled regions, but won’t go through autocasting. For example, in an autocast-enabled region a.addmm(b, c) can autocast, but a.addmm_(b, c) and a.addmm(b, c, out=d) cannot. For best performance and stability, prefer out-of-place ops in autocast-enabled regions.
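
The difference between the three call styles can be observed on CPU. This is a sketch assuming bfloat16 autocast; the tensor names are illustrative.

```python
import torch

a = torch.rand((4, 4))
b = torch.rand((4, 4))
c = torch.rand((4, 4))
d = torch.empty((4, 4))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Out-of-place: eligible, so the result comes back in bfloat16.
    out_of_place = a.addmm(b, c)
    # Explicit out= Tensor: allowed, but bypasses autocasting.
    torch.addmm(a, b, c, out=d)
    # In-place: allowed, but bypasses autocasting; a keeps its dtype.
    a.addmm_(b, c)
```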

Ops called with an explicit dtype=... argument are not eligible, and will produce output that respects the dtype argument.

CUDA Op-Specific Behavior

The following lists describe the behavior of eligible ops in autocast-enabled regions. These ops always go through autocasting whether they are invoked as part of a torch.nn.Module, as a function, or as a torch.Tensor method. If functions are exposed in multiple namespaces, they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type defined by their inputs. However, autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops.

If an op is unlisted, we assume it’s numerically stable in float16. If you believe an unlisted op is numerically unstable in float16, please file an issue.

CUDA Ops that can autocast to float16

__matmul__, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell

CUDA Ops that can autocast to float32

__pow__, __rdiv__, __rpow__, __rtruediv__, acos, asin, binary_cross_entropy_with_logits, cosh, cosine_embedding_loss, cdist, cosine_similarity, cross_entropy, cumprod, cumsum, dist, erfinv, exp, expm1, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, log10, log1p, log2, margin_ranking_loss, mse_loss, multilabel_margin_loss, multi_margin_loss, nll_loss, norm, normalize, pdist, poisson_nll_loss, pow, prod, reciprocal, rsqrt, sinh, smooth_l1_loss, soft_margin_loss, softmax, softmin, softplus, sum, renorm, tan, triplet_margin_loss

CUDA Ops that promote to the widest input type

These ops don’t require a particular dtype for stability, but take multiple inputs and require that the inputs’ dtypes match. If all of the inputs are float16, the op runs in float16. If any of the inputs is float32, autocast casts all inputs to float32 and runs the op in float32.

addcdiv, addcmul, atan2, bilinear, cross, dot, grid_sample, index_put, scatter_add, tensordot

Some ops not listed here (e.g., binary ops like add) natively promote inputs without autocasting’s intervention. If inputs are a mixture of float16 and float32, these ops run in float32 and produce float32 output, regardless of whether autocast is enabled.

Prefer binary_cross_entropy_with_logits over binary_cross_entropy

The backward passes of torch.nn.functional.binary_cross_entropy() (and torch.nn.BCELoss, which wraps it) can produce gradients that aren’t representable in float16. In autocast-enabled regions, the forward input may be float16, which means the backward gradient must be representable in float16 (autocasting float16 forward inputs to float32 doesn’t help, because that cast must be reversed in backward). Therefore, binary_cross_entropy and BCELoss raise an error in autocast-enabled regions.

Many models use a sigmoid layer right before the binary cross entropy layer. In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits() or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogitsLoss are safe to autocast.
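
The substitution can be sketched as follows. The example runs in plain float32 on CPU, outside autocast, only to show that the fused call computes the same loss as the sigmoid followed by binary cross entropy; the random logits and targets are placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(16)                       # raw model outputs (pre-sigmoid)
targets = torch.randint(0, 2, (16,)).float()   # binary labels

# Two-step formulation: sigmoid, then binary cross entropy.
loss_two_step = F.binary_cross_entropy(torch.sigmoid(logits), targets)

# Fused formulation that takes logits directly.
loss_fused = F.binary_cross_entropy_with_logits(logits, targets)
```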

CPU Op-Specific Behavior

The following lists describe the behavior of eligible ops in autocast-enabled regions. These ops always go through autocasting whether they are invoked as part of a torch.nn.Module, as a function, or as a torch.Tensor method. If functions are exposed in multiple namespaces, they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting. They run in the type defined by their inputs. However, autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops.

If an op is unlisted, we assume it’s numerically stable in bfloat16. If you believe an unlisted op is numerically unstable in bfloat16, please file an issue.

CPU Ops that can autocast to bfloat16

conv1d, conv2d, conv3d, bmm, mm, baddbmm, addmm, addbmm, linear, matmul, _convolution

CPU Ops that can autocast to float32

conv_transpose1d, conv_transpose2d, conv_transpose3d, avg_pool3d, binary_cross_entropy, grid_sampler, grid_sampler_2d, _grid_sampler_2d_cpu_fallback, grid_sampler_3d, polar, prod, quantile, nanquantile, stft, cdist, trace, view_as_complex, cholesky, cholesky_inverse, cholesky_solve, inverse, lu_solve, orgqr, ormqr, pinverse, max_pool3d, max_unpool2d, max_unpool3d, adaptive_avg_pool3d, reflection_pad1d, reflection_pad2d, replication_pad1d, replication_pad2d, replication_pad3d, mse_loss, ctc_loss, kl_div, multilabel_margin_loss, fft_fft, fft_ifft, fft_fft2, fft_ifft2, fft_fftn, fft_ifftn, fft_rfft, fft_irfft, fft_rfft2, fft_irfft2, fft_rfftn, fft_irfftn, fft_hfft, fft_ihfft, linalg_matrix_norm, linalg_cond, linalg_matrix_rank, linalg_solve, linalg_cholesky, linalg_svdvals, linalg_eigvals, linalg_eigvalsh, linalg_inv, linalg_householder_product, linalg_tensorinv, linalg_tensorsolve, fake_quantize_per_tensor_affine, eig, geqrf, lstsq, _lu_with_info, qr, solve, svd, symeig, triangular_solve, fractional_max_pool2d, fractional_max_pool3d, adaptive_max_pool3d, multilabel_margin_loss_forward, linalg_qr, linalg_cholesky_ex, linalg_svd, linalg_eig, linalg_eigh, linalg_lstsq, linalg_inv_ex

CPU Ops that promote to the widest input type

These ops don’t require a particular dtype for stability, but take multiple inputs and require that the inputs’ dtypes match. If all of the inputs are bfloat16, the op runs in bfloat16. If any of the inputs is float32, autocast casts all inputs to float32 and runs the op in float32.

cat, stack, index_copy

Some ops not listed here (e.g., binary ops like add) natively promote inputs without autocasting’s intervention. If inputs are a mixture of bfloat16 and float32, these ops run in float32 and produce float32 output, regardless of whether autocast is enabled.
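
Both kinds of promotion can be checked directly. This is a small sketch assuming CPU bfloat16 autocast; the tensor names are illustrative.

```python
import torch

x_bf16 = torch.ones(4, dtype=torch.bfloat16)
y_f32 = torch.ones(4)

# Native type promotion: add yields float32 with or without autocast.
sum_plain = x_bf16 + y_f32

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    sum_autocast = x_bf16 + y_f32
    # cat is on the "promote to the widest input type" list.
    cat_mixed = torch.cat([x_bf16, y_f32])   # one float32 input: runs in float32
    cat_bf16 = torch.cat([x_bf16, x_bf16])   # all bfloat16: stays bfloat16
```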
