torchtnt.utils.flops.FlopTensorDispatchMode
class torchtnt.utils.flops.FlopTensorDispatchMode(module: Module)
A context manager that counts the floating-point operations (flops) performed by a module. Requires PyTorch 1.13+.
The flop-counting implementation is based on https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
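The linked post intercepts each ATen operator call through `__torch_dispatch__` and looks up a per-operator formula that turns input shapes into a flop count. The sketch below illustrates only that lookup step in plain Python. The function names, the registry keys, and the convention of counting one flop per multiply-accumulate are assumptions for illustration, not torchtnt's actual code.

```python
from typing import Callable, Dict, Sequence

Shape = Sequence[int]


def mm_flops(a: Shape, b: Shape) -> int:
    # (m, k) @ (k, n): m * n outputs, each needing k multiply-accumulates.
    m, k = a
    k2, n = b
    assert k == k2, "inner dimensions must match"
    return m * k * n


def bmm_flops(a: Shape, b: Shape) -> int:
    # Batched version: (batch, m, k) @ (batch, k, n).
    batch, m, k = a
    batch2, k2, n = b
    assert batch == batch2 and k == k2, "shapes are incompatible"
    return batch * m * k * n


# Hypothetical registry: operator name -> formula over input shapes.
FLOP_FORMULAS: Dict[str, Callable[..., int]] = {
    "mm": mm_flops,
    "bmm": bmm_flops,
    # addmm(bias, a, b) computes bias + a @ b; the bias addition is ignored here.
    "addmm": lambda bias, a, b: mm_flops(a, b),
}


def count_op_flops(op_name: str, *input_shapes: Shape) -> int:
    # Operators without a registered formula contribute zero in this sketch.
    formula = FLOP_FORMULAS.get(op_name)
    return formula(*input_shapes) if formula is not None else 0


print(count_op_flops("mm", (2, 3), (3, 4)))  # 24
```

Other counters report two flops per multiply-accumulate, so totals from different tools can differ by a factor of two depending on the convention they use.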
Examples:
>>> import copy
>>> import torch
>>> import torchvision.models as models
>>> from torchtnt.utils.flops import FlopTensorDispatchMode
>>> module = models.resnet18()
>>> module_input = torch.randn(1, 3, 224, 224)
>>> with FlopTensorDispatchMode(module) as ftdm:
...     # count forward flops
...     res = module(module_input).mean()
...     flops_forward = copy.deepcopy(ftdm.flop_counts)
...     # reset count before counting backward flops
...     ftdm.reset()
...     res.backward()
...     flops_backward = copy.deepcopy(ftdm.flop_counts)
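Because `flops_forward` and `flops_backward` are separate snapshots, they can be summarized or merged after the `with` block exits. The sketch below assumes that `flop_counts` is a nested mapping from a module name to a mapping from operator name to flop count. That layout, the helper names `per_module_totals` and `combine_passes`, and the module and operator names in the sample data are all hypothetical, so check the structure returned by your installed torchtnt version before relying on it.

```python
from collections import Counter
from typing import Dict, Mapping

# Assumed layout: {module_name: {operator_name: flop_count}}.
FlopCounts = Mapping[str, Mapping[str, int]]


def per_module_totals(flop_counts: FlopCounts) -> Dict[str, int]:
    # Total each module separately; different modules are not added together
    # because a parent module's entry may already include its children's flops.
    return {name: sum(ops.values()) for name, ops in flop_counts.items()}


def combine_passes(forward: FlopCounts, backward: FlopCounts) -> Dict[str, int]:
    # Add the forward and backward totals module by module.
    totals = Counter(per_module_totals(forward))
    totals.update(per_module_totals(backward))
    return dict(totals)


# Placeholder snapshots in the assumed layout (illustrative numbers only).
flops_forward = {"parent": {"op_a": 10, "op_b": 2}, "parent.child": {"op_b": 2}}
flops_backward = {"parent": {"op_c": 20}}
print(combine_passes(flops_forward, flops_backward))
```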
__init__(module: Module) -> None
Initializes a FlopTensorDispatchMode context manager object.
Parameters: module – The module to count flops on.
Methods
__init__(module): Initializes a FlopTensorDispatchMode context manager object.
push(*args, **kwargs)
reset(): Resets the current flop count.
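The example takes `copy.deepcopy` snapshots of `flop_counts` before calling `reset()` and before leaving the context. The toy class below is a hypothetical stand-in, not torchtnt's implementation: it assumes the counts live in a nested `defaultdict` that is updated in place and emptied in place by `reset()`. Under that assumption, a plain reference is emptied by the reset, and a shallow copy still shares the inner dictionaries that later counting keeps mutating. Only a deep copy preserves the forward-pass numbers.

```python
import copy
from collections import defaultdict
from typing import DefaultDict


class ToyFlopCounter:
    """Hypothetical stand-in that mimics the bookkeeping, not the real class."""

    def __init__(self) -> None:
        self.flop_counts: DefaultDict[str, DefaultDict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )

    def record(self, module: str, op: str, flops: int) -> None:
        self.flop_counts[module][op] += flops

    def reset(self) -> None:
        # Clears the outer mapping in place, so existing references see it empty.
        self.flop_counts.clear()


counter = ToyFlopCounter()
counter.record("fc", "op_a", 8)

alias = counter.flop_counts                # same object
shallow = dict(counter.flop_counts)        # copies only the outer mapping
deep = copy.deepcopy(counter.flop_counts)  # copies the nested mappings too

counter.record("fc", "op_a", 8)  # later counting mutates the shared inner dict
counter.reset()

print(len(alias), shallow["fc"]["op_a"], deep["fc"]["op_a"])  # 0 16 8
```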