
torch.escape-hatch

assume_constant_result

Note

Tags: torch.escape-hatch

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch
import torch._dynamo as torchdynamo



class AssumeConstantResult(torch.nn.Module):
    """
    Apply the `assume_constant_result` decorator to treat the result of non-traceable code as a constant and burn it into the graph.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        return y.int().item()

    def forward(self, x, y):
        return x[: self.get_item(y)]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "i64[]"):
            slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4);  x = None
            return (slice_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)])
Range constraints: {}

constrain_as_size_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch



class ConstrainAsSizeExample(torch.nn.Module):
    """
    If a value is not known at tracing time, you can provide a hint so that
    tracing can continue. See the torch._check and torch._check_is_size APIs.
    torch._check_is_size is used for values that are needed to construct a
    tensor (for example, as a dimension size).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        a = x.item()
        torch._check_is_size(a)
        torch._check(a <= 5)
        return torch.zeros((a, 5))

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]"):
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item)

            # No stacktrace found for following nodes
            sym_constrain_range = torch.ops.aten.sym_constrain_range.default(item, min = 0, max = 5)
            mul: "Sym(-u0)" = -1 * item
            le: "Sym(-u0 <= 0)" = mul <= 0;  mul = None
            _assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u0 <= 0 on node 'le_1'");  le = None
            le_1: "Sym(u0 <= 5)" = item <= 5
            _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_2'");  le_1 = None

            zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False);  item = None
            return (zeros,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}

constrain_as_value_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch



class ConstrainAsValueExample(torch.nn.Module):
    """
    If a value is not known at tracing time, you can provide a hint so that
    tracing can continue. See the torch._check and torch._check_is_size APIs.
    torch._check is used for values that are not needed to construct a
    tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a <= 5)

        if a < 6:
            return y.sin()
        return y.cos()
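Why does the exported graph keep only the `sin` branch? The two `torch._check` calls bound the unknown value to [0, 5], and over that whole range `a < 6` is true, so the tracer can pick the branch without knowing `a`. The `ToyRange` class below is a hypothetical, greatly simplified illustration of that value-range reasoning; it is not PyTorch's implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyRange:
    """Hypothetical closed interval [lo, hi] for a value unknown at trace time."""
    lo: float = float("-inf")
    hi: float = float("inf")

    def check_ge(self, bound: int) -> "ToyRange":
        # Mirrors torch._check(a >= bound): raise the lower end.
        return ToyRange(max(self.lo, bound), self.hi)

    def check_le(self, bound: int) -> "ToyRange":
        # Mirrors torch._check(a <= bound): lower the upper end.
        return ToyRange(self.lo, min(self.hi, bound))

    def lt(self, bound: int) -> Optional[bool]:
        # Decide `a < bound` for every value in the range, if possible.
        if self.hi < bound:
            return True
        if self.lo >= bound:
            return False
        return None  # data-dependent: no branch can be chosen at trace time


unknown = ToyRange()                         # a = x.item(): nothing known yet
print(unknown.lt(6))                         # None
bounded = unknown.check_ge(0).check_le(5)
print(bounded)                               # ToyRange(lo=0, hi=5)
print(bounded.lt(6))                         # True: only y.sin() is traced
```

Without the two checks the comparison stays undecided, which is the situation where export cannot choose a branch for a data-dependent value.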

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[5, 5]"):
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

            # No stacktrace found for following nodes
            sym_constrain_range = torch.ops.aten.sym_constrain_range.default(item, min = 0, max = 5)
            mul: "Sym(-u0)" = -1 * item
            le: "Sym(-u0 <= 0)" = mul <= 0;  mul = None
            _assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u0 <= 0 on node 'le_1'");  le = None
            le_1: "Sym(u0 <= 5)" = item <= 5;  item = None
            _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_2'");  le_1 = None

            sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}
