torch.escape-hatch ====================== assume_constant_result ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>` Support Level: SUPPORTED Original source code: .. code-block:: python # mypy: allow-untyped-defs import torch import torch._dynamo as torchdynamo class AssumeConstantResult(torch.nn.Module): """ Apply the `assume_constant_result` decorator to treat the result of non-traceable code as a constant. """ @torchdynamo.assume_constant_result def get_item(self, y): return y.int().item() def forward(self, x, y): return x[: self.get_item(y)] example_args = (torch.randn(3, 2), torch.tensor(4)) tags = {"torch.escape-hatch"} model = AssumeConstantResult() torch.export.export(model, example_args) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]", y: "i64[]"): slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4); x = None return (slice_1,) Graph signature: # inputs x: USER_INPUT y: USER_INPUT # outputs slice_1: USER_OUTPUT Range constraints: {} constrain_as_size_example ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-value <torch.dynamic-value>`, :doc:`torch.escape-hatch <torch.escape-hatch>` Support Level: SUPPORTED Original source code: .. code-block:: python # mypy: allow-untyped-defs import torch class ConstrainAsSizeExample(torch.nn.Module): """ If the value is not known at tracing time, you can provide hints so that tracing can continue. See the torch._check and torch._check_is_size APIs. torch._check_is_size is used for values that NEED to be used for constructing tensors. """ def forward(self, x): a = x.item() torch._check_is_size(a) torch._check(a <= 5) return torch.zeros((a, 5)) example_args = (torch.tensor(4),) tags = { "torch.dynamic-value", "torch.escape-hatch", } model = ConstrainAsSizeExample() torch.export.export(model, example_args) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "i64[]"): item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None # sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None ge_1: "Sym(u0 >= 0)" = item >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None le_1: "Sym(u0 <= 5)" = item <= 5 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False); item = None return (zeros,) Graph signature: # inputs x: USER_INPUT # outputs zeros: USER_OUTPUT Range constraints: {u0: VR[0, 5], u1: VR[0, 5]} constrain_as_value_example ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-value <torch.dynamic-value>`, :doc:`torch.escape-hatch <torch.escape-hatch>` Support Level: SUPPORTED Original source code: .. code-block:: python # mypy: allow-untyped-defs import torch class ConstrainAsValueExample(torch.nn.Module): """ If the value is not known at tracing time, you can provide hints so that tracing can continue. See the torch._check and torch._check_is_size APIs. torch._check is used for values that don't need to be used for constructing tensors. """ def forward(self, x, y): a = x.item() torch._check(a >= 0) torch._check(a <= 5) if a < 6: return y.sin() return y.cos() example_args = (torch.tensor(4), torch.randn(5, 5)) tags = { "torch.dynamic-value", "torch.escape-hatch", } model = ConstrainAsValueExample() torch.export.export(model, example_args) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "i64[]", y: "f32[5, 5]"): item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None ge_1: "Sym(u0 >= 0)" = item >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None le_1: "Sym(u0 <= 5)" = item <= 5; item = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None sin: "f32[5, 5]" = torch.ops.aten.sin.default(y); y = None return (sin,) Graph signature: # inputs x: USER_INPUT y: USER_INPUT # outputs sin: USER_OUTPUT Range constraints: {u0: VR[0, 5], u1: VR[0, 5]}