
torch.nn.utils.parametrize.cached

torch.nn.utils.parametrize.cached()

Context manager that enables the caching system within parametrizations registered with register_parametrization().

The values of the parametrized objects are computed and cached the first time they are required while this context manager is active. The cached values are discarded when leaving the context manager.
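The sketch below makes the caching behaviour concrete. It assumes a recent PyTorch release. `CountingSymmetric` is a hypothetical parametrization, not part of PyTorch. It makes a square weight symmetric and counts how often it is evaluated. The `calls` counter exists only for this demonstration. Registration may itself evaluate the parametrization as a sanity check, so the sketch measures call counts as differences rather than absolute values.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class CountingSymmetric(nn.Module):
    """Hypothetical demo parametrization: returns a symmetric version of a
    square matrix and counts how many times forward() has run."""

    def __init__(self):
        super().__init__()
        self.calls = 0  # demo-only counter

    def forward(self, X):
        self.calls += 1
        # Keep the upper triangle and mirror it below the diagonal
        return X.triu() + X.triu(1).transpose(-1, -2)


layer = nn.Linear(4, 4)
param = CountingSymmetric()
P.register_parametrization(layer, "weight", param)

# Without the cache, every access to layer.weight re-runs the parametrization
before = param.calls
w_a, w_b = layer.weight, layer.weight
uncached_calls = param.calls - before  # 2

# With the cache, the first access computes the value and later ones reuse it
before = param.calls
with P.cached():
    w1, w2 = layer.weight, layer.weight
    same_object = w1 is w2  # both names refer to the cached tensor
cached_calls = param.calls - before  # 1

# After leaving the context manager the cache is discarded, so it recomputes
before = param.calls
w3 = layer.weight
after_exit_calls = param.calls - before  # 1
```

Inside the context, both accesses return the same cached tensor object. Once the context exits, the next access evaluates the parametrization again.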

This is useful when using a parametrized parameter more than once in the forward pass. An example of this is when parametrizing the recurrent kernel of an RNN or when sharing weights.

The simplest way to activate the cache is by wrapping the forward pass of the neural network

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

in training and evaluation. One may also wrap only the parts of the modules that use the parametrized tensors several times. For example, the loop of an RNN with a parametrized recurrent kernel:

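# Inside the module's forward(): self.rnn_cell uses the parametrized
# recurrent kernel, and out_rnn carries the hidden state across time steps
with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)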
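Here is a self-contained sketch of the RNN case. It assumes `nn.RNNCell` as the cell and a recent PyTorch release. `OrthogonalExp` and `SimpleRNN` are hypothetical names introduced for illustration. `OrthogonalExp` maps the recurrent kernel `weight_hh` to the matrix exponential of a skew-symmetric matrix, which is orthogonal. It also counts its own evaluations through a demo-only `calls` attribute, which makes the effect of the cache visible.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class OrthogonalExp(nn.Module):
    """Hypothetical parametrization: exp of a skew-symmetric matrix, which is
    orthogonal. Counts its evaluations for demonstration purposes."""

    def __init__(self):
        super().__init__()
        self.calls = 0  # demo-only counter

    def forward(self, X):
        self.calls += 1
        A = X.triu(1)
        return torch.matrix_exp(A - A.transpose(-1, -2))


class SimpleRNN(nn.Module):
    """Hypothetical module unrolling an RNNCell with a parametrized kernel."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        # Parametrize the recurrent (hidden-to-hidden) kernel of the cell
        P.register_parametrization(self.rnn_cell, "weight_hh", OrthogonalExp())

    def forward(self, xs, out_rnn):
        # The parametrized kernel is computed once and reused for every step
        with P.cached():
            for x in xs:
                out_rnn = self.rnn_cell(x, out_rnn)
        return out_rnn


torch.manual_seed(0)
model = SimpleRNN(input_size=3, hidden_size=5)
param = model.rnn_cell.parametrizations.weight_hh[0]

xs = torch.randn(10, 2, 3)  # 10 time steps, batch of 2, 3 features
h0 = torch.zeros(2, 5)

before = param.calls
h = model(xs, h0)
calls_during_forward = param.calls - before  # 1 instead of 10
```

Without the `P.cached()` block, the parametrization (including the `torch.matrix_exp` call) would run once per time step, that is, 10 times in this sketch.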
