
torchtnt.utils.flops.FlopTensorDispatchMode

class torchtnt.utils.flops.FlopTensorDispatchMode(module: Module)

A context manager to measure flops of a module. Requires PyTorch 1.13+.

Flop count implementation based on https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
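The referenced approach roughly works like this: every operator the module executes is intercepted (through `__torch_dispatch__`), looked up in a table of shape-based flop formulas, and its count is credited to each module that is currently running. The pure-Python sketch below mimics that bookkeeping without PyTorch so the data flow is easy to follow. `ToyFlopCounter`, `enter_module`, `exit_module`, `on_op` and the formula table are hypothetical names, not torchtnt API. Shapes are passed in by hand instead of being read from tensors. Counting one multiply-accumulate as one flop is also an assumption; some counters report twice that value.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def mm_flop(shapes: Sequence[Shape]) -> int:
    """Multiply-accumulates for an (m, k) @ (k, n) matrix product."""
    (m, k), (k2, n) = shapes
    assert k == k2, "inner dimensions must match"
    return m * n * k


def bmm_flop(shapes: Sequence[Shape]) -> int:
    """Multiply-accumulates for a batched (b, m, k) @ (b, k, n) product."""
    (b, m, k), (b2, k2, n) = shapes
    assert b == b2 and k == k2, "batch and inner dimensions must match"
    return b * m * n * k


# Hypothetical formula table keyed by op name; ops missing here are ignored.
FLOP_FORMULAS: Dict[str, Callable[[Sequence[Shape]], int]] = {
    "mm": mm_flop,
    "bmm": bmm_flop,
}


class ToyFlopCounter:
    """Toy stand-in for a dispatch-based flop counter (not torchtnt's code)."""

    def __init__(self) -> None:
        # module name -> op name -> flops; "" stands for the whole model
        self.flop_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        self._parents: List[str] = [""]

    def __enter__(self) -> "ToyFlopCounter":
        return self

    def __exit__(self, *exc: object) -> None:
        return None

    def enter_module(self, name: str) -> None:
        # A real counter would do this from a forward pre-hook.
        self._parents.append(name)

    def exit_module(self) -> None:
        # ...and this from a forward post-hook.
        self._parents.pop()

    def on_op(self, op: str, shapes: Sequence[Shape]) -> None:
        # Plays the role of __torch_dispatch__: called once per executed op.
        formula = FLOP_FORMULAS.get(op)
        if formula is None:
            return  # unsupported op: not counted
        flops = formula(shapes)
        # Credit the flops to every module that is currently running.
        for parent in self._parents:
            self.flop_counts[parent][op] += flops

    def reset(self) -> None:
        self._parents = [""]
        self.flop_counts.clear()


with ToyFlopCounter() as counter:
    counter.enter_module("fc")
    counter.on_op("mm", [(8, 16), (16, 4)])  # 8 * 4 * 16 = 512
    counter.exit_module()
    counter.on_op("relu", [(8, 4)])  # no formula, so skipped

print(dict(counter.flop_counts[""]))  # {'mm': 512}
```

Keeping a stack of parent module names means a single op is counted both for the innermost module and for every enclosing one, so the `""` entry always holds the total for the whole model.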

Examples:

>>> import copy
>>> import torch
>>> import torchvision.models as models
>>> from torchtnt.utils.flops import FlopTensorDispatchMode

>>> module = models.resnet18()
>>> module_input = torch.randn(1, 3, 224, 224)
>>> with FlopTensorDispatchMode(module) as ftdm:
...     # count forward flops
...     res = module(module_input).mean()
...     flops_forward = copy.deepcopy(ftdm.flop_counts)
...
...     # reset count before counting backward flops
...     ftdm.reset()
...     res.backward()
...     flops_backward = copy.deepcopy(ftdm.flop_counts)
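Because `flop_counts` is nested, totals have to be summed out of it. The helpers below assume the layout that torchtnt appears to use at the time of writing: a mapping from module name (with `""` for the whole module) to a mapping from operator name to flop count. Verify this against your installed version. `total_flops` and `per_op_flops` are hypothetical helpers, not part of torchtnt.

```python
from collections import Counter
from typing import Dict, Mapping

# Assumed layout: {module_name: {op_name: flop_count}}, "" being the root module.
FlopCounts = Mapping[str, Mapping[str, int]]


def total_flops(counts: FlopCounts, module_name: str = "") -> int:
    """Sum the flops of every op recorded for one module (the root by default)."""
    return sum(counts.get(module_name, {}).values())


def per_op_flops(*snapshots: FlopCounts, module_name: str = "") -> Dict[str, int]:
    """Merge several snapshots (e.g. forward and backward) into per-op totals."""
    merged: Counter = Counter()
    for snapshot in snapshots:
        # Counter.update adds the values of a mapping instead of replacing them.
        merged.update(snapshot.get(module_name, {}))
    return dict(merged)
```

Applied to the snapshots from the example, `total_flops(flops_forward) + total_flops(flops_backward)` gives the combined forward and backward count for the whole model, and `per_op_flops(flops_forward, flops_backward)` breaks that total down by operator.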

__init__(module: Module) → None

Initializes a FlopTensorDispatchMode context manager object.

Parameters: module – The module to count flops on.

Methods

__init__(module)         Initializes a FlopTensorDispatchMode context manager object.
push(*args, **kwargs)
reset()                  Resets current flop count.
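The example snapshots `ftdm.flop_counts` with `copy.deepcopy` before calling `reset()`. Assuming `reset()` empties the counter dictionary in place (which is how current torchtnt versions appear to implement it), a plain reference taken beforehand would be wiped as well and would then pick up the backward counts. The sketch below reproduces this with a nested `defaultdict` standing in for `flop_counts`; it does not import torchtnt, and the clearing step is only a stand-in for `reset()`.

```python
import copy
from collections import defaultdict

# Stand-in for ftdm.flop_counts: module name -> op name -> flop count.
flop_counts = defaultdict(lambda: defaultdict(int))
flop_counts[""]["mm"] += 512  # pretend forward-pass count

alias = flop_counts                    # same object, not a snapshot
snapshot = copy.deepcopy(flop_counts)  # independent copy of every level

# Stand-in for reset(), assuming it clears the dictionary in place.
flop_counts.clear()

# Counting resumes, e.g. during the backward pass.
flop_counts[""]["mm"] += 1024

print(alias[""]["mm"])     # 1024: the alias now shows backward counts only
print(snapshot[""]["mm"])  # 512: the deep copy kept the forward counts
```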
