Shortcuts

torch.dynamic-shape

cond_branch_class_method

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class MySubModule(torch.nn.Module):
    """Helper module whose forward pass returns the cosine of its input."""

    def foo(self, x):
        """Return ``x.cos()``."""
        cosine = x.cos()
        return cosine

    def forward(self, x):
        """Delegate to :meth:`foo`."""
        result = self.foo(x)
        return result

class CondBranchClassMethod(torch.nn.Module):
    """
    This example demonstrates using class methods as the branches of cond().

    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - the returned tensors must have the same tensor metadata, e.g. shape and dtype
      - a branch function can be a free function, nested function, lambda, or class method
      - a branch function cannot have closure variables
      - no in-place mutations on inputs or global variables

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self) -> None:
        super().__init__()
        # Submodule whose bound ``forward`` is passed to cond() as the true branch.
        self.subm = MySubModule()

    def bar(self, x):
        """Return ``x.sin()``; passed to cond() as the false branch."""
        sine = x.sin()
        return sine

    def forward(self, x):
        """Pick ``self.subm.forward`` when ``x.shape[0] <= 2``, else ``self.bar``."""
        pred = x.shape[0] <= 2
        true_fn = self.subm.forward
        false_fn = self.bar
        return cond(pred, true_fn, false_fn, [x])

# Sample positional inputs handed to torch.export.export below. The tensor has
# length 3, so the predicate `x.shape[0] <= 2` in forward() is False for it.
example_args = (torch.randn(3),)
# Tags attached to this example.
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
# Module instance that gets exported.
model = CondBranchClassMethod()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3]"):
            sin: "f32[3]" = torch.ops.aten.sin.default(x);  x = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {}

cond_branch_nested_function

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using nested function in cond().

    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x):
        def true_fn(x):
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])

# Sample positional inputs handed to torch.export.export below. The tensor has
# length 3, so the predicate `x.shape[0] < 10` in forward() is True for it.
example_args = (torch.randn(3),)
# Tags attached to this example.
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
# Module instance that gets exported.
model = CondBranchNestedFunction()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3]"):
            add: "f32[3]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

cond_branch_nonlocal_variables

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondBranchNonlocalVariables(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond.
    - both branches must return a single tensor
    - returned tensor must have the same tensor metadata, e.g. shape and dtype
    - branch function can be free function, nested function, lambda, class methods
    - branch function can not have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )

# Sample positional inputs handed to torch.export.export below. The tensor has
# length 6, so the predicate `x.shape[0] > 5` in forward() is True for it.
example_args = (torch.randn(6),)
# Tags attached to this example.
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
# Module instance that gets exported.
model = CondBranchNonlocalVariables()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"):
            add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100)

            lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None

            add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add);  x = add = None
            add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach_);  add_1 = detach_ = None
            return (add_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}

cond_operands

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from torch.export import Dim

# Sample inputs for the example; they are bundled into `example_args` further down.
x = torch.randn(3, 2)
y = torch.randn(2)
# Named Dim used in `dynamic_shapes` further down for dimension 0 of `x`.
dim0_x = Dim("dim0_x")

class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must be:
    - a list of tensors
    - match arguments of `true_fn` and `false_fn`

    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        return torch.cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

# Sample positional inputs handed to torch.export.export below.
example_args = (x, y)
# Tags attached to this example.
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
# A second input set whose first tensor has leading dimension 2 instead of 3;
# it is not used by the export call shown here.
extra_inputs = (torch.randn(2, 2), torch.randn(2))
# Dimension 0 of `x` is tied to `dim0_x`; `y` gets no dynamic-shape spec (None).
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
# Module instance that gets exported.
model = CondOperands()


# Export the model using the sample inputs and the dynamic-shape spec.
torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
            #
            sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)

            gt: "Sym(s0 > 2)" = sym_size_int_1 > 2;  sym_size_int_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, y]);  gt = true_graph_0 = false_graph_0 = x = y = None
            getitem: "f32[s0, 2]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
                add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                return (add,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
                sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {s0: VR[0, int_oo]}

cond_predicate

Note

Tags: torch.cond, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x):
        pred = x.dim() > 2 and x.shape[2] > 10

        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

# Sample positional inputs handed to torch.export.export below. The tensor is
# 3-dimensional with x.shape[2] == 3, so `x.shape[2] > 10` is False for it.
example_args = (torch.randn(6, 4, 3),)
# Tags attached to this example.
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
# Module instance that gets exported.
model = CondPredicate()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[6, 4, 3]"):
            sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x);  x = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {}

dynamic_shape_constructor

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class DynamicShapeConstructor(torch.nn.Module):
    """
    Tensor constructors should be captured with dynamic shape inputs rather
    than being baked in with a static shape.
    """

    def forward(self, x):
        """Return a zeros tensor whose length is twice ``x.shape[0]``."""
        doubled_length = x.shape[0] * 2
        return torch.zeros(doubled_length)

# Sample positional inputs handed to torch.export.export below.
example_args = (torch.randn(3, 2),)
# Tags attached to this example.
tags = {"torch.dynamic-shape"}
# Module instance that gets exported.
model = DynamicShapeConstructor()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device = device(type='cpu'), pin_memory = False)
            return (zeros,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {}

dynamic_shape_if_guard

Note

Tags: python.control-flow, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class DynamicShapeIfGuard(torch.nn.Module):
    """
    An `if` statement with a backed dynamic shape predicate will be specialized
    into one particular branch and generate a guard. However, export will fail
    if the dimension is marked as dynamic shape from a higher level API.
    """

    def forward(self, x):
        """Return ``x.cos()`` when ``x.shape[0] == 3``, otherwise ``x.sin()``."""
        return x.cos() if x.shape[0] == 3 else x.sin()

# Sample positional inputs handed to torch.export.export below. The leading
# dimension is 3, so the `x.shape[0] == 3` branch in forward() is taken.
example_args = (torch.randn(3, 2, 2),)
# Tags attached to this example.
tags = {"torch.dynamic-shape", "python.control-flow"}
# Module instance that gets exported.
model = DynamicShapeIfGuard()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

dynamic_shape_map

Note

Tags: torch.map, torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import map

class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def forward(self, xs, y):
        """Apply ``body`` through map() with ``xs`` as the mapped input and ``y`` as an extra argument."""

        def body(item, extra):
            return item + extra

        mapped = map(body, xs, y)
        return mapped

# Sample positional inputs (xs, y) handed to torch.export.export below.
example_args = (torch.randn(3, 2), torch.randn(2))
# Tags attached to this example.
tags = {"torch.dynamic-shape", "torch.map"}
# Module instance that gets exported.
model = DynamicShapeMap()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]);  body_graph_0 = xs = y = None
            getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[2]", y: "f32[2]"):
                add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

dynamic_shape_round

Note

Tags: torch.dynamic-shape, python.builtin

Support Level: NOT_SUPPORTED_YET

Original source code:

# mypy: allow-untyped-defs
import torch

from torch._export.db.case import SupportLevel
from torch.export import Dim

class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def forward(self, x):
        """Return the leading ``round(x.shape[0] / 2)`` entries of ``x`` along dim 0."""
        stop = round(x.shape[0] / 2)
        return x[:stop]

# Sample input tensor for the example.
x = torch.randn(3, 2)
# Named Dim used in `dynamic_shapes` below for dimension 0 of `x`.
dim0_x = Dim("dim0_x")
# Sample positional inputs handed to torch.export.export below.
example_args = (x,)
# Tags attached to this example.
tags = {"torch.dynamic-shape", "python.builtin"}
# This case is recorded as not supported yet.
support_level = SupportLevel.NOT_SUPPORTED_YET
# Dimension 0 of `x` is tied to `dim0_x`.
dynamic_shapes = {"x": {0: dim0_x}}
# Module instance that gets exported.
model = DynamicShapeRound()


# Export the model using the sample inputs and the dynamic-shape spec.
torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".

dynamic_shape_slicing

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class DynamicShapeSlicing(torch.nn.Module):
    """
    Slices with dynamic shape arguments should be captured into the graph
    rather than being baked in.
    """

    def forward(self, x):
        """Slice dim 0 up to ``x.shape[0] - 2`` and dim 1 from ``x.shape[1] - 1`` with step 2."""
        row_stop = x.shape[0] - 2
        col_start = x.shape[1] - 1
        return x[:row_stop, col_start::2]

# Sample positional inputs handed to torch.export.export below.
example_args = (torch.randn(3, 2),)
# Tags attached to this example.
tags = {"torch.dynamic-shape"}
# Module instance that gets exported.
model = DynamicShapeSlicing()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1);  x = None
            slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
            return (slice_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}

dynamic_shape_view

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class DynamicShapeView(torch.nn.Module):
    """
    Dynamic shapes should be propagated to view arguments instead of being
    baked into the exported graph.
    """

    def forward(self, x):
        """View ``x`` with its last dim replaced by ``(2, 5)``, then permute with ``(0, 2, 1)``."""
        leading_dims = x.size()[:-1]
        target_shape = leading_dims + (2, 5)
        reshaped = x.view(*target_shape)
        return reshaped.permute(0, 2, 1)

# Sample positional inputs handed to torch.export.export below.
example_args = (torch.randn(10, 10),)
# Tags attached to this example.
tags = {"torch.dynamic-shape"}
# Module instance that gets exported.
model = DynamicShapeView()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]"):
            view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]);  x = None

            permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
            return (permute,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
Range constraints: {}

list_contains

Note

Tags: torch.dynamic-shape, python.assert, python.data-structure

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

class ListContains(torch.nn.Module):
    """
    List containment relation can be checked on a dynamic shape or constants.
    """

    def forward(self, x):
        """Check three list-membership assertions, then return ``x + x``."""
        # Membership test on a size taken from the input.
        last_dim = x.size(-1)
        assert last_dim in [6, 2]
        # Negated membership test on another size taken from the input.
        first_dim = x.size(0)
        assert first_dim not in [4, 5, 6]
        # Negated membership test on constants only.
        assert "monkey" not in ["cow", "pig"]
        doubled = x + x
        return doubled

# Sample positional inputs handed to torch.export.export below. With shape
# (3, 2), both size assertions in forward() hold.
example_args = (torch.randn(3, 2),)
# Tags attached to this example.
tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
# Module instance that gets exported.
model = ListContains()


# Export the model using the sample inputs.
torch.export.export(model, example_args)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

scalar_output

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch

from torch.export import Dim

# Sample input tensor for the example; it is bundled into `example_args` further down.
x = torch.randn(3, 2)
# Named Dim used in `dynamic_shapes` further down for dimension 1 of `x`.
dim1_x = Dim("dim1_x")

class ScalarOutput(torch.nn.Module):
    """
    Returning scalar values from the graph is supported, in addition to Tensor
    outputs. Symbolic shapes are captured and rank is specialized.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        """Return ``x.shape[1] + 1``."""
        second_dim = x.shape[1]
        return second_dim + 1

# Sample positional inputs handed to torch.export.export below.
example_args = (x,)
# Tags attached to this example.
tags = {"torch.dynamic-shape"}
# Dimension 1 of `x` is tied to `dim1_x`.
dynamic_shapes = {"x": {1: dim1_x}}
# Module instance that gets exported.
model = ScalarOutput()


# Export the model using the sample inputs and the dynamic-shape spec.
torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, s0]"):
            #
            sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 1);  x = None

            add: "Sym(s0 + 1)" = sym_size_int_1 + 1;  sym_size_int_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
Range constraints: {s0: VR[0, int_oo]}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources