
Source code for torch.cuda.amp.grad_scaler

import torch
from torch.amp.grad_scaler import OptState

__all__ = ["GradScaler", "OptState"]


class GradScaler(torch.amp.GradScaler):
    r"""
    See :class:`torch.amp.GradScaler`.
    ``torch.cuda.amp.GradScaler(args...)`` is equivalent to
    ``torch.amp.GradScaler("cuda", args...)``
    """

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        super().__init__(
            "cuda",
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
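
Since this class simply forwards to the device-generic ``torch.amp.GradScaler``
with ``"cuda"`` pre-filled, either spelling can be used interchangeably. Below
is a minimal sketch of typical gradient-scaling usage; it assumes a CUDA device
is available, and the toy model, optimizer, and data are illustrative only:

import torch

# Equivalent constructions: the torch.cuda.amp spelling is a thin shim
# over the device-generic torch.amp.GradScaler.
device = "cuda"
scaler = torch.amp.GradScaler(device)  # same as torch.cuda.amp.GradScaler()

# Hypothetical toy setup for illustration.
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
inputs = torch.randn(32, 8, device=device)
targets = torch.randn(32, 1, device=device)

for _ in range(10):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale per growth/backoff settings

The constructor defaults above (``init_scale=2.0**16``, ``growth_factor=2.0``,
``backoff_factor=0.5``, ``growth_interval=2000``) govern how ``update()`` grows
the scale after ``growth_interval`` consecutive finite-gradient steps and
shrinks it by ``backoff_factor`` whenever infs/NaNs are encountered.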
