# functorch.experimental.functionalize

functorch.experimental.functionalize(func, *, remove='mutations')

functionalize is a transform that can be used to remove (intermediate) mutations and aliasing from a function, while preserving the function’s semantics.

functionalize(func) returns a new function with the same semantics as func, but with all intermediate mutations removed. Every in-place operation performed on an intermediate tensor, intermediate.foo_(), is replaced by its out-of-place equivalent, intermediate_updated = intermediate.foo().
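
The following is a toy illustration of this rewrite rule, not how functorch implements it. It uses plain Python lists as stand-in "tensors" and two hypothetical helpers, add_ (in-place) and add (out-of-place), which mirror PyTorch's trailing-underscore naming convention. Both versions of the function compute the same result. The functional version binds the updated value to a new name instead of overwriting the intermediate.

```python
# Toy model only: Python lists stand in for tensors; add_ and add are
# hypothetical helpers mirroring PyTorch's in-place / out-of-place naming.

def add_(t, value):
    """In-place: mutates t and returns it (like intermediate.foo_())."""
    for i in range(len(t)):
        t[i] += value
    return t


def add(t, value):
    """Out-of-place: returns a new list, t is untouched (like intermediate.foo())."""
    return [x + value for x in t]


def f_with_mutation(a):
    intermediate = add(a, 1)  # a fresh intermediate, not an input
    add_(intermediate, 1)     # in-place update of the intermediate
    return intermediate


def f_functionalized(a):
    intermediate = add(a, 1)
    intermediate_updated = add(intermediate, 1)  # out-of-place replacement
    return intermediate_updated


inpt = [1.0, 2.0]
print(f_with_mutation(inpt), f_functionalized(inpt), inpt)  # [3.0, 4.0] [3.0, 4.0] [1.0, 2.0]
```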

functionalize is useful for shipping a PyTorch program off to backends or compilers that cannot easily represent mutations or aliasing operators.

Parameters
• func (Callable) – A Python function that takes one or more arguments.

• remove (str) – An optional string argument that takes either the value ‘mutations’ or ‘mutations_and_views’. If ‘mutations’ is passed in, all mutating operators are replaced with their non-mutating equivalents. If ‘mutations_and_views’ is passed in, all aliasing operators are additionally replaced with their non-aliasing equivalents. Default: ‘mutations’.
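
The difference between the two modes can be shown with a stdlib-only analogy. This is an assumption made for illustration, not PyTorch code. A memoryview over an array plays the role of an aliasing view, because writes through it are visible in the base. Copying into a fresh array plays the role of a non-aliasing "_copy" variant such as view_copy.

```python
from array import array

base = array("d", [1.0, 2.0])

alias = memoryview(base)      # aliasing: shares storage with base (like view)
non_alias = array("d", base)  # non-aliasing: fresh storage (like view_copy)

alias[0] += 10.0              # mutation through the alias is visible in base
non_alias[1] += 10.0          # mutation of the copy is not

print(list(base))       # [11.0, 2.0]
print(list(non_alias))  # [1.0, 12.0]
```

Once views are replaced by copies, a later mutation can no longer propagate to the base automatically. That is why functionalization must track aliases itself.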

Returns

Returns a new “functionalized” function. It takes the same inputs as func and has the same behavior, but any mutations (and optionally aliasing) performed on intermediate tensors in the function are removed.

functionalize will also remove mutations (and views) that were performed on function inputs. However, to preserve semantics, functionalize “fixes up” those mutations after the transform has finished running: it detects whether any tensor inputs “should have” been mutated, and copies the new data back to those inputs if necessary.
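
The input fix-up can be pictured with another toy sketch that uses plain Python lists and hypothetical function names. The functionalized body computes the updated value without mutating anything. A single epilogue then copies that value back into the input, playing the role of a copy_ call. Unlike the real transform, this sketch hardcodes which input "should have" been mutated rather than detecting it.

```python
def f(a):
    """Original program: mutates its input in place."""
    for i in range(len(a)):
        a[i] += 1
    return a


def f_functionalized(a):
    """Toy functionalized version: pure body plus a copy-back epilogue."""
    a_updated = [x + 1 for x in a]  # functional body: a is not touched here
    a[:] = a_updated                # epilogue: copy the new data back into the input
    return a_updated


x1, x2 = [1.0, 2.0], [1.0, 2.0]
out1, out2 = f(x1), f_functionalized(x2)
print(out1 == out2, x1 == x2)  # True True
```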

Example:

>>> import torch
>>> from functorch import make_fx
>>> from functorch.experimental import functionalize
>>>
>>> # A function that uses mutations and views, but only on intermediate tensors.
>>> def f(a):
...     b = a + 1
...     c = b.view(-1)
...     c.add_(1)
...     return b
...
>>> inpt = torch.randn(2)
>>>
>>> out1 = f(inpt)
>>> out2 = functionalize(f)(inpt)
>>>
>>> # semantics are the same (outputs are equivalent)
>>> print(torch.allclose(out1, out2))
True
>>>
>>> f_traced = make_fx(f)(inpt)
>>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
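>>> print(f_traced.code)

def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1])
    add_ = torch.ops.aten.add_(view, 1);  view = None
    return add

>>> print(f_no_mutations_traced.code)

def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view = torch.ops.aten.view(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view, 1);  view = None
    view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
    return view_1

>>> print(f_no_mutations_and_views_traced.code)

def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1);  a_1 = None
    view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
    add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
    return view_copy_1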

>>> # A function that mutates its input tensor
>>> def f(a):
...     b = a.view(-1)
...     b.add_(1)
...     return a
...
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> # All mutations and views have been removed,
>>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
>>> # after the function has completed.
>>> print(f_no_mutations_and_views_traced.code)

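def forward(self, a_1):
    view_copy = torch.ops.aten.view_copy(a_1, [-1])
    add = torch.ops.aten.add(view_copy, 1);  view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
    copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
    return view_copy_1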

3. resize_() has some limitations: functionalize will only work on programs that use resize_() as long as the tensor being resized is not a view.

Finally, a helpful mental model for understanding functionalization is that most user PyTorch programs are written with the public torch API. When executed, torch operators are generally decomposed into our internal C++ “ATen” API. The logic for functionalization happens entirely at the level of ATen. Functionalization knows how to map every aliasing operator in ATen to its non-aliasing equivalent (e.g. tensor.view({-1}) -> at::view_copy(tensor, {-1})), and every mutating operator in ATen to its non-mutating equivalent (e.g. tensor.add_(1) -> at::add(tensor, 1)). Meanwhile, it tracks aliases and mutations out-of-line so it knows when to fix things up. Information about which ATen operators are aliasing or mutating all comes from https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
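
The operator-level mapping can be sketched as a lookup table applied to a flat list of op names. This is a deliberately simplified, hypothetical model. The tables below hold only a few hand-written entries, whereas the real mapping is derived from native_functions.yaml. The sketch also ignores the alias tracking that lets real functionalization re-apply views after a mutation and copy updated data back to inputs.

```python
# Hypothetical, hand-written tables; the real information comes from native_functions.yaml.
MUTATING_TO_FUNCTIONAL = {"add_": "add", "mul_": "mul"}
ALIASING_TO_COPY = {"view": "view_copy"}


def functionalize_ops(ops, remove="mutations"):
    """Rewrite a list of ATen-style op names: mutating -> functional,
    and optionally aliasing -> copying."""
    if remove not in ("mutations", "mutations_and_views"):
        raise ValueError(f"unexpected remove={remove!r}")
    rewritten = []
    for op in ops:
        op = MUTATING_TO_FUNCTIONAL.get(op, op)
        if remove == "mutations_and_views":
            op = ALIASING_TO_COPY.get(op, op)
        rewritten.append(op)
    return rewritten


trace = ["add", "view", "add_"]
print(functionalize_ops(trace))                                # ['add', 'view', 'add']
print(functionalize_ops(trace, remove="mutations_and_views"))  # ['add', 'view_copy', 'add']
```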