python.assert
=================

dynamic_shape_assert
^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.assert <python.assert>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch

    class DynamicShapeAssert(torch.nn.Module):
        """
        A basic usage of python assertion.
        """

        def forward(self, x):
            # assertion with error message
            assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
            # assertion without error message
            assert x.shape[0] > 1
            return x

    example_args = (torch.randn(3, 2),)
    tags = {"python.assert"}
    model = DynamicShapeAssert()


    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                return (x,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='x'), target=None)])
    Range constraints: {}



list_contains
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.data-structure <python.data-structure>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`python.assert <python.assert>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    # mypy: allow-untyped-defs
    import torch

    class ListContains(torch.nn.Module):
        """
        List containment relation can be checked on a dynamic shape or constants.
        """

        def forward(self, x):
            assert x.size(-1) in [6, 2]
            assert x.size(0) not in [4, 5, 6]
            assert "monkey" not in ["cow", "pig"]
            return x + x

    example_args = (torch.randn(3, 2),)
    tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
    model = ListContains()


    torch.export.export(model, example_args)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x); x = None
                return (add,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {}