python.control-flow
dynamic_shape_if_guard
Original source code:
# mypy: allow-untyped-defs
import torch
class DynamicShapeIfGuard(torch.nn.Module):
    """
    `if` statement with backed dynamic shape predicate will be specialized into
    one particular branch and generate a guard. However, export will fail if the
    dimension is marked as dynamic shape from a higher-level API.
    """

    def forward(self, x):
        """Return ``x.cos()`` when the leading dimension is 3, else ``x.sin()``."""
        # The predicate depends only on a size, not on tensor values, so export
        # traces the branch taken by the example input and guards on the size.
        if x.shape[0] == 3:
            return x.cos()
        return x.sin()
# Build the module first; it defines no __init__ of its own.
model = DynamicShapeIfGuard()
# One sample input whose leading dimension is 3, so tracing follows the
# `x.cos()` branch of forward.
example_args = (torch.randn(3, 2, 2),)
# Category tags for this example.
tags = {"python.control-flow", "torch.dynamic-shape"}
torch.export.export(model, example_args)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2, 2]"):
cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x); x = None
return (cos,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}
list_unpack
Original source code:
# mypy: allow-untyped-defs
from typing import List
import torch
class ListUnpack(torch.nn.Module):
    """
    Lists are treated as static constructs; therefore, unpacking should be
    erased after tracing.
    """

    def forward(self, args: List[torch.Tensor]):
        """
        Add the first tensor in ``args`` to the second one.

        ``args`` must contain at least two tensors. Any elements after the
        second are ignored. The starred unpacking is plain Python over a list,
        so it does not appear as an op in the exported graph.
        """
        x, *y = args
        return x + y[0]
# Build the module first; it defines no __init__ of its own.
model = ListUnpack()
# A single positional argument: a list holding one float tensor and two
# scalar integer tensors.
example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
# Category tags for this example.
tags = {"python.data-structure", "python.control-flow"}
torch.export.export(model, example_args)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1); args_0 = args_1 = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}
static_for_loop
Original source code:
# mypy: allow-untyped-defs
import torch
class StaticForLoop(torch.nn.Module):
    """
    A for loop with constant number of iterations should be unrolled in the exported graph.
    """

    def forward(self, x):
        """Return a list of ten tensors: ``x`` shifted by 0, 1, ..., 9."""
        # The trip count is a Python constant, so tracing unrolls it into
        # ten independent add ops.
        return [offset + x for offset in range(10)]
# Build the module first; it defines no __init__ of its own.
model = StaticForLoop()
# One 3x2 sample input.
example_args = (torch.randn(3, 2),)
# Category tag for this example.
tags = {"python.control-flow"}
torch.export.export(model, example_args)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2]"):
add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9); x = None
return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
Range constraints: {}
static_if
Original source code:
# mypy: allow-untyped-defs
import torch
class StaticIf(torch.nn.Module):
    """
    `if` statement with static predicate value should be traced through with the
    taken branch.
    """

    def forward(self, x):
        """Add a broadcast ``ones(1, 1, 1)`` to 3-D inputs; return others unchanged."""
        # The rank of x is fixed at trace time, so only one branch is recorded.
        if len(x.shape) != 3:
            return x
        return x + torch.ones(1, 1, 1)
# Build the module first; it defines no __init__ of its own.
model = StaticIf()
# One 3-D sample input, so tracing follows the branch that adds ones.
example_args = (torch.randn(3, 2, 2),)
# Category tag for this example.
tags = {"python.control-flow"}
torch.export.export(model, example_args)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[3, 2, 2]"):
ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones); x = ones = None
return (add,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}