Fake Tensor

Fake tensors, like meta tensors, carry no data. Unlike meta tensors, which report meta as their device, fake tensors act as if they were allocated on a real device. The following example shows how the two tensor types differ:

>>> import torch
>>>
>>> from torchdistx.fake import fake_mode
>>>
>>> # Meta tensors are always "allocated" on the `meta` device.
>>> a = torch.ones([10], device="meta")
>>> a
tensor(..., device='meta', size=(10,))
>>> a.device
device(type='meta')
>>>
>>> # Fake tensors are always "allocated" on the specified device.
>>> with fake_mode():
...     b = torch.ones([10])
...
>>> b
tensor(..., size=(10,), fake=True)
>>> b.device
device(type='cpu')

Fake tensors, like meta tensors, rely on the meta backend for their operation. In that sense meta tensors and fake tensors can be considered close cousins. Fake tensors are just an alternative interface to the meta backend and have mostly the same tradeoffs as meta tensors.
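As a rough illustration of what the meta backend does, the sketch below uses plain PyTorch meta tensors only (no torchdistx). It shows shapes and dtypes being propagated through an operation without any data being allocated; the tensor sizes are arbitrary values chosen for the example.

```python
import torch

# Meta tensors record shape and dtype but never allocate storage, so even
# very large operands cost almost no memory.
a = torch.empty(4096, 8192, device="meta")
b = torch.empty(8192, 1024, device="meta")

# The meta backend only performs shape and dtype inference for the matmul.
c = a @ b

print(c.shape)   # shape inferred from the operands
print(c.device)  # the result stays on the meta device
```

Fake tensors rely on the same mechanism; the difference is that they report a real device such as cpu instead of meta.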

API

The API consists mainly of the fake_mode() function, which acts as a Python context manager. Any tensor constructed within its scope is created as a fake tensor.

torchdistx.fake.fake_mode()

Instantiates all tensors within its context as fake.

Return type

Iterator[None]

There are also two convenience functions offered as part of the API:

torchdistx.fake.is_fake(tensor)

Indicates whether tensor is fake.

Parameters

tensor (torch.Tensor) – The tensor to check.

Return type

bool

torchdistx.fake.meta_like(fake)

Returns a meta tensor with the same properties as fake.

This function has the same Autograd behavior as detach(), meaning the returned tensor won’t be part of the Autograd graph.

Parameters

fake (torch.Tensor) – The fake tensor to copy from.

Return type

torch.Tensor

Use Cases

Fake tensors were originally meant as a building block for Deferred Module Initialization. However, they are not bound to that use case and can also be used for other purposes. For instance, they serve as a surprisingly good learning tool for inspecting large model architectures that cannot fit on a consumer-grade PC:

>>> import torch
>>>
>>> from transformers import BlenderbotModel, BlenderbotConfig
>>>
>>> from torchdistx.fake import fake_mode
>>>
>>> # Instantiate Blenderbot on a personal laptop with 8GB RAM.
>>> with fake_mode():
...     m = BlenderbotModel(BlenderbotConfig())
...
>>> # Check out the model layers and their parameters.
>>> m
BlenderbotModel(...)
