Shortcuts

Source code for torch.export.dynamic_shapes

import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES

from .exported_program import ExportedProgram

if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source

    from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

# Public API of this module. `Constraint` is a Union alias over the two
# constraint dataclasses; `Dim`, `dims`, and `dynamic_dim` are factory functions.
__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]


class _Dim(type):
    """
    Metaclass for :func:`Dim` types.
    """

    @staticmethod
    def readable(name, min_, max_):
        if min_ == 2:
            min_ = None
        if max_ == sys.maxsize - 1:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Only increasing linear expressions with integer coefficients are supported,
    i.e. a derived Dim always has the form Ax + B, where x is a regular
    (non-derived) Dim, A and B are integers, and A is positive. Positivity of A
    guarantees x < y => Ax + B < Ay + B. These restrictions keep the metatheory
    simple: computing ranges for derived Dims, solving for the underlying regular
    Dim, deciding equalities between derived Dims, and so on.

    The map ``lambda x: Ax + B`` is stored as ``fn`` and the regular Dim ``x`` as
    ``root``. A derived Dim's range is obtained by mapping ``fn`` over the range
    of its ``root``.
    """

    @property
    def min(self):
        # Because fn is assumed increasing, fn(root.min) is the lower bound.
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        root = self.root  # type: ignore[attr-defined]
        lower = self.fn(Integer(root.min))  # type: ignore[attr-defined]
        assert lower >= 2, (
            f"Expected derived min value of {self.__name__} to be >= 2. "
            f"Please specify an appropriate min value for {root.__name__} "
            f"(currently {root.min})."
        )
        return int(lower)

    @property
    def max(self):
        # Because fn is assumed increasing, fn(root.max) is the upper bound.
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        root = self.root  # type: ignore[attr-defined]
        upper = self.fn(Integer(root.max))  # type: ignore[attr-defined]
        assert upper <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {root.__name__} "
            f"(currently {root.max})."
        )
        return int(upper)

    def _derive(self, fn):
        # Nesting (e.g., 2*dim + 1) composes fn after this Dim's own fn on the
        # same root, so roots are always regular Dims, never derived ones.
        derived_name = self._derived_name(fn)
        namespace = {
            "root": self.root,  # type: ignore[attr-defined]
            "fn": lambda x: fn(self.fn(x)),  # type: ignore[attr-defined]
        }
        return _DerivedDim(derived_name, (int,), namespace)


def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive).
            Defaults to 2; values below 2 are raised to 2.
        max (Optional[int]): Maximum possible value of given symbol (inclusive).
            Defaults to ``sys.maxsize - 1``; larger values are lowered to it.

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
        Its ``__module__`` is set to the module of the calling code
        (``"__main__"`` if that module cannot be determined).

    Raises:
        AssertionError: if the clamped ``max`` is not strictly greater than the
            clamped ``min``.
    """
    # `min`/`max` shadow the builtins in this scope, hence `builtins.` below.
    _min = 2 if min is None else builtins.max(min, 2)
    _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
    dim = _Dim(name, (int,), {"min": _min, "max": _max})

    # Attribute the new type to the caller's module. Only the immediate caller
    # frame is needed, so step back one frame rather than calling
    # inspect.stack(), which builds FrameInfo (with source context) for every
    # frame on the stack.
    caller = inspect.currentframe()
    if caller is not None:
        caller = caller.f_back
    try:
        module = inspect.getmodule(caller) if caller is not None else None
    finally:
        # Release the frame reference promptly to avoid a reference cycle.
        del caller
    dim.__module__ = getattr(module, "__name__", "__main__")
    return dim
def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types.

    Each positional name yields one :func:`Dim` with the shared ``min`` and
    ``max`` bounds; the Dims are returned as a tuple in the order given.
    """
    created = []
    for dim_name in names:
        created.append(Dim(dim_name, min=min, max=max))
    return tuple(created)
@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents input tensor dimensions.  Don't create this
    class directly; instead, use :func:`dynamic_dim`.

    A target identifies one dimension (``dim``) of one example input tensor,
    referenced both weakly (``w_tensor``) and by its ``id`` (``t_id``).
    """

    w_tensor: Any  # weakref to torch.Tensor
    # TODO: We don't need t_id; we can get it off of w_tensor
    t_id: int
    dim: int


class _ConstraintFactory(type):
    """
    Metaclass that ensures a private constructor for :class:`_Constraint`

    Calling the class directly raises ``TypeError``; internal code builds
    instances through :meth:`_create` (wrapped by :func:`_create_constraint`).
    """

    def __call__(cls, *args, **kwargs):
        # Block the public constructor, e.g. ``_Constraint(...)``.
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
            f"Please use torch.export.dynamic_dim() to create one"
        )

    def _create(
        cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
    ):
        # Private constructor: delegate to type.__call__, bypassing the
        # overridden __call__ above. Arguments are positional, so their order
        # must match the dataclass field order of _Constraint.
        return super().__call__(
            w_tensor, t_id, dim, constraint_range, shared, debug_name
        )


def _create_constraint(
    w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
):
    """Construct a :class:`_Constraint` through its private metaclass factory."""
    return _Constraint._create(
        w_tensor, t_id, dim, constraint_range, shared, debug_name
    )


@dataclasses.dataclass
class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    """

    .. warning::
        Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead.

    This represents constraints on input tensor dimensions, e.g., requiring
    them to be fully polymorphic or within some range.

    Comparison operators (``>=``, ``>``, ``<=``, ``<``, ``==``) do not return
    booleans: each returns a new :class:`_Constraint` with a narrowed range
    (or, for ``==``, a shared target). ``bool()`` on a constraint raises.

    """

    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
    constraint_range: "StrictMinMaxConstraint"
    # Represent that `constraint_range` is shared with another _ConstraintTarget, which
    # typically arises because of a specified equality with another dynamic dimension.
    shared: Optional[_ConstraintTarget] = None
    # Optional human-readable name (e.g. the name of the Dim this came from).
    debug_name: Optional[str] = None

    def _clone_with_range(self, lower=2, upper=math.inf):
        # Return a new constraint whose range is the intersection of the current
        # range with [lower, upper]; `self` is left unchanged.
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.value_ranges import ValueRanges

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            self.shared,
            self.debug_name,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        # Strict bound turned inclusive by +1 (assumes `lower` is an integer).
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        # Strict bound turned inclusive by -1 (assumes `upper` is an integer).
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-exporting pass
        # that converts constraints to runtime assertion. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }

    def __eq__(self, other):
        # Returns a new _Constraint (not a bool): the ranges of both sides are
        # intersected and `other`'s target is recorded in `shared`.
        if not isinstance(other, _Constraint):
            raise TypeError(
                "A dynamic dim can be specified equal only to another dynamic dim. "
                f"Equality with {type(other)} is not supported."
            )

        # import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & other.constraint_range.vr,
            warn_only=False,
        )
        # Keep whichever debug name is present; if both are, they must agree.
        if self.debug_name is None:
            debug_name = other.debug_name
        else:
            assert other.debug_name is None or self.debug_name == other.debug_name
            debug_name = self.debug_name
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
            debug_name=debug_name,
        )


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    # Name of the root Dim; also used to name the phantom symbol.
    name: str
    # Allowed range of the root Dim.
    constraint_range: "StrictMinMaxConstraint"
    # Root value solved from an example input shape.
    val: int


@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).
    """

    # NOTE: This is not currently a subclass of _Constraint because we do not support
    # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for
    # legacy constraints based on `dynamic_dim`: equality can be expressed simply by
    # reusing the same (derived or normal) `Dim`.
    root: Union[_ConstraintTarget, _PhantomRoot]
    # Map from the root value to this dimension's value (input = fn(root)).
    fn: Callable
    constraint_range: "StrictMinMaxConstraint"
    debug_name: Optional[str] = None

    @property
    def shared(self):
        # Some code paths expect a union of _Constraint and _DerivedConstraint.
        # Thus we expose a `shared` field that is always None.
        # TODO(avik): clean this up
        return None

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


# Either kind of constraint produced for an input dimension.
Constraint = Union[_Constraint, _DerivedConstraint]
[docs]def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): """ .. warning:: (This feature is DEPRECATED. See :func:`Dim` instead.) :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to ``constraints`` argument of :func:`export`. Args: t (torch.Tensor): Example input tensor that have dynamic dimension size(s) index (int): Index of dynamic dimension Returns: A :class:`_Constraint` object that describes shape dynamism. It can be passed to :func:`export` so that :func:`export` does not assume static size of specified tensor, i.e. keeping it dynamic as a symbolic size rather than specializing according to size of example tracing input. Specifically :func:`dynamic_dim` can be used to express following types of dynamism. - Size of a dimension is dynamic and unbounded:: t0 = torch.rand(2, 3) t1 = torch.rand(3, 4) # First dimension of t0 can be dynamic size rather than always being static size 2 constraints = [dynamic_dim(t0, 0)] ep = export(fn, (t0, t1), constraints=constraints) - Size of a dimension is dynamic with a lower bound:: t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive) # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive) constraints = [ dynamic_dim(t0, 0) >= 5, dynamic_dim(t1, 1) > 2, ] ep = export(fn, (t0, t1), constraints=constraints) - Size of a dimension is dynamic with an upper bound:: t0 = torch.rand(10, 3) t1 = torch.rand(3, 4) # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive) # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive) constraints = [ dynamic_dim(t0, 0) <= 16, dynamic_dim(t1, 1) < 8, ] ep = export(fn, (t0, t1), constraints=constraints) - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension:: t0 = 
torch.rand(10, 3) t1 = torch.rand(3, 4) # Sizes of second dimension of t0 and first dimension are always equal constraints = [ dynamic_dim(t0, 1) == dynamic_dim(t1, 0), ] ep = export(fn, (t0, t1), constraints=constraints) - Mix and match all types above as long as they do not express conflicting requirements """ from torch._dynamo.exc import UserError, UserErrorType if not isinstance(t, torch.Tensor): raise UserError( UserErrorType.DYNAMIC_DIM, f"Expected tensor as input to dynamic_dim but got {type(t)}", ) if t.dim() < 1: raise UserError( UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic" ) if index >= t.dim(): raise UserError( UserErrorType.DYNAMIC_DIM, f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]" f" but got {index}, which is out of bounds for the given tensor.", ) # Import sympy locally import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges return _create_constraint( weakref.ref(t), id(t), index, StrictMinMaxConstraint( vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False ), debug_name=debug_name, )
def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.

    Args:
        constraint: A `_Constraint` or `_DerivedConstraint` for one input dimension.
        get_sources: Maps `(t_id, dim)` to the list of sources for that tensor
            dimension; the first source returned is treated as the primary one.
        shape_env: Shape environment in which phantom-root symbols are created.
        source_pairs: Appended in place with `(source, other_source)` pairs.
        derived_equalities: Appended in place with `(source, root, fn)` triples.
        phantom_symbols: Phantom-root symbols keyed by root name; read and
            updated in place so each phantom root gets a single symbol.

    Returns:
        None; all results are written into the mutable arguments.
    """

    source, *other_sources = get_sources(constraint.t_id, constraint.dim)
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if not isinstance(constraint, _DerivedConstraint):
        if constraint.shared is not None:
            # Moreover, when t.size()[dim] is specified equal to t'.size()[dim']
            # and t'.size()[dim'] maps to src1', ..., srcN', we add
            # constraints that also make src0 "equal" to src1', ..., srcN'.
            other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
    else:
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))


def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
    """
    Convert a user-facing `dynamic_shapes` specification into constraints.

    `dynamic_shapes` mirrors the structure of the arguments of `f` (keyed by
    argument name when it is a Mapping, positional otherwise). Each tensor leaf
    is paired with a per-dimension spec (dict or tuple/list of `Dim`/None), and
    every `Dim` becomes a `_Constraint` or `_DerivedConstraint`.

    Returns:
        None if `dynamic_shapes` is None or empty; otherwise a list of constraints.

    Raises:
        UserError: if `dynamic_shapes` does not match the structure of the inputs,
            contains unexpected items, defines the same Dim name with different
            bounds, or an input shape cannot be expressed by a derived Dim.
    """
    from collections import defaultdict
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
        # Walk inputs and specs in lockstep, yielding (tensor, spec) for every
        # tensor leaf; raise UserError on any structural mismatch.
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            # Equal sizes do not imply equal keys (e.g. a misspelled argument
            # name); report that as a UserError instead of a bare KeyError.
            unexpected = [k for k in dynamic_shapes if k not in combined_args]
            if unexpected:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected the keys of dynamic_shapes {dynamic_shapes} to match "
                    f"{list(combined_args.keys())}, but found unexpected key(s) {unexpected}",
                )
            for k, shape in dynamic_shapes.items():
                yield from tree_zip(combined_args[k], shape)
        elif type(combined_args) in SUPPORTED_NODES:
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
    phantom_roots: Dict[str, _PhantomRoot] = {}
    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
        # Build the constraint for dimension `i` of `tensor` described by `dim`.
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def root_value():
            # given tensor.shape[i] is the value of dim = fn(root),
            # find the value of root
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is not None:
                return int(solution[1])  # type: ignore[call-overload]
            else:
                raise UserError(  # noqa: TRY200
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )

        if isinstance(dim, _DerivedDim):
            # generate a _DerivedConstraint where the root is:
            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
            # - or a _PhantomRoot (otherwise)
            dim_root = dim.root  # type: ignore[attr-defined]
            if dim_root.__name__ in symbols:
                # root represents an input shape dimension
                root_constraint = symbols[dim_root.__name__][0]
                root = _ConstraintTarget(
                    root_constraint.w_tensor,
                    root_constraint.t_id,
                    root_constraint.dim,
                )
            elif dim_root.__name__ not in phantom_roots:
                # create a phantom root
                root = _PhantomRoot(  # type: ignore[assignment]
                    name=dim_root.__name__,
                    constraint_range=StrictMinMaxConstraint(
                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
                        warn_only=False,
                    ),
                    val=root_value(),
                )
                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
            else:
                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
            constraint = _DerivedConstraint(
                weakref.ref(tensor),
                id(tensor),
                i,
                root,
                dim.fn,  # type: ignore[attr-defined]
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max),
                    warn_only=False,
                ),
                debug_name=dim.__name__,
            )
            if isinstance(root, _PhantomRoot):
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        else:
            constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
            # Only narrow the range when the Dim's bounds differ from the defaults.
            if dim.min != 2:
                constraint = constraint >= dim.min
            if dim.max != sys.maxsize - 1:
                constraint = constraint <= dim.max
        return constraint

    # Bounds recorded for each Dim name the first time it is seen.
    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        # A Dim name reused across inputs must always carry the same bounds.
        if dim.__name__ in symbols:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )

        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def update_symbols(tensor, shape):
        # Record a constraint for every Dim in a tensor's per-dimension spec;
        # non-Dim entries must be None.
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        else:
            if shape is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Unexpected dynamic_shape {shape} of Tensor, " "try None instead",
                )

    import inspect

    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    combined_args = signature.bind(*args, **kwargs).arguments

    # This means user didn't specify dynamic shapes with argument names.
    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
        update_symbols(tensor, shape)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source. This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols are, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        if all(isinstance(c, _DerivedConstraint) for c in dynamic_dims):
            constraints.extend(dynamic_dims)
        else:
            # Regular Dims reused across inputs become equalities with the first use.
            primary, *others = dynamic_dims
            if others:
                for other in others:
                    constraints.append(primary == other)  # type: ignore[arg-type]
            else:
                constraints.append(primary)

    return constraints  # type: ignore[return-value]


def _process_constraints(
    fake_mode,
    graph_module: torch.fx.GraphModule,
    num_lifted_params_buffers: int,
    example_inputs: List[torch.Tensor],
) -> Dict:
    """
    Process the constraints stored in the graph module to return something more readable.

    Args:
        fake_mode: Fake tensor mode whose ``shape_env.var_to_range`` supplies
            ranges for root symbols that lack a direct input constraint.

        graph_module (torch.fx.GraphModule): GraphModule returned from
            dynamo.export, which contains the "input_shape_constraints" and
            "inline_constraints" metadata

        num_lifted_params_buffers (int): Number of leading placeholders that are
            lifted parameters/buffers rather than user inputs.

        example_inputs: Flattened list of example inputs used to export the graph module

    Returns:
        range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
            symbols (from SymInts) appearing in the fake tensors in
            node.meta["val"] to their range constraints, which are a tuple
            containing (lower, upper) constraints.
    """

    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        InputDim,
    )

    # Import sympy locally
    from torch.fx.experimental.symbolic_shapes import SymInt
    from torch.utils._sympy.value_ranges import ValueRanges

    input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
    inline_constraints = graph_module.meta.get("inline_constraints", [])

    # Create dict mapping tensor_id to node names
    tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
    # Create dict mapping placeholder node names to their nodes
    placeholder_nodes: Dict[str, torch.fx.Node] = {}
    for i, node in enumerate(graph_module.graph.nodes):
        if node.op != "placeholder":
            # All placeholder nodes should be together in the beginning of the
            # graph
            break
        if i >= num_lifted_params_buffers:
            example_input = example_inputs[i - num_lifted_params_buffers]
            tensor_id_to_nodes[id(example_input)].append(node.name)
            placeholder_nodes[node.name] = node

    # Create dict mapping (node name, dim) a list of range (lower, upper)
    # constraints
    multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
    for constraint in input_shape_constraints:
        for node in tensor_id_to_nodes[constraint["t_id"]]:
            node_dim = InputDim(node, constraint["dim"])

            # Accumulate range constraints
            multi_range_constraints[node_dim].append(
                ValueRanges(constraint["min"], constraint["max"])
            )

    # Create dict mapping symbol to a singular range (lower, upper),
    # seeded with the inline constraints.
    range_constraints: Dict[Any, ValueRanges] = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }

    free_symbols: Set["Symbol"] = set()
    # Add input range constraints to range_constraints
    for input_dim, multi_range_constraint in multi_range_constraints.items():  # type: ignore[assignment]
        # Simplify the range constraints into a single range constraint
        # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
        min_vals = [rc.lower for rc in multi_range_constraint]
        max_vals = [rc.upper for rc in multi_range_constraint]
        min_val = max(min_vals)  # type: ignore[type-var]
        max_val = min(max_vals)  # type: ignore[type-var]
        assert min_val <= max_val  # type: ignore[operator]

        # Add input node range constraints
        val = placeholder_nodes[input_dim.input_name].meta["val"]
        assert isinstance(val, FakeTensor)
        symint = val.shape[input_dim.dim]
        assert isinstance(
            symint, SymInt
        ), f"Expected SymInt but got {symint}: {type(symint)}"
        symbol = symint.node.expr
        range_constraints[symbol] = ValueRanges(min_val, max_val)
        free_symbols.update(symbol.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]

    return range_constraints

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources