Shortcuts

torch.dynamic-value

constrain_as_size_example

Note

Tags: torch.dynamic-value, torch.escape-hatch

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch



class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide hint so that we
    can trace further. Please look at torch._check and torch._check_is_size APIs.
    torch._check_is_size is used for values that NEED to be used for constructing
    tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        a = x.item()
        torch._check_is_size(a)
        torch._check(a <= 5)
        return torch.zeros((a, 5))

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]"):
             # File: /opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_export/db/examples/constrain_as_size_example.py:26 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

             # 
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item)

            # No stacktrace found for following nodes
            sym_constrain_range = torch.ops.aten.sym_constrain_range.default(item, min = 0, max = 5)
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression 0 <= u0 on node 'ge'");  ge = None
            le: "Sym(u0 <= 5)" = item <= 5
            _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le = None

             # File: /opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_export/db/examples/constrain_as_size_example.py:29 in forward, code: return torch.zeros((a, 5))
            zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False);  item = None
            return (zeros,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}

constrain_as_value_example

Note

Tags: torch.dynamic-value, torch.escape-hatch

Support Level: SUPPORTED

Original source code:

# mypy: allow-untyped-defs
import torch



class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide hint so that we
    can trace further. Please look at torch._check and torch._check_is_size APIs.
    torch._check is used for values that don't need to be used for constructing
    tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a <= 5)

        if a < 6:
            return y.sin()
        return y.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[5, 5]"):
             # File: /opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_export/db/examples/constrain_as_value_example.py:26 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

            # No stacktrace found for following nodes
            sym_constrain_range = torch.ops.aten.sym_constrain_range.default(item, min = 0, max = 5)
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression 0 <= u0 on node 'ge_1'");  ge = None
            le: "Sym(u0 <= 5)" = item <= 5;  item = None
            _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le = None

             # File: /opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_export/db/examples/constrain_as_value_example.py:31 in forward, code: return y.sin()
            sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources