# Source code for torch.fx.experimental.symbolic_shapes
# mypy: ignore-errors
"""
``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
our symbolic shapes reasoning system that is used heavily in torch.compile. Although
this is not generally considered public API, when writing framework code in PyTorch
as well as extensions to PyTorch (e.g., in custom operator implementations), you may
need to make use of these APIs to setup dynamic shapes support appropriately.
"""
import builtins
import collections
import functools
import inspect
import itertools
import logging
import math
import operator
import re
import sys
import threading
import traceback
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
import atexit
from typing import (
Any,
cast,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
TYPE_CHECKING
)
from typing_extensions import TypeAlias
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
from torch.fx.experimental import _config as config
from torch.fx.experimental.recording import (
FakeTensorMeta,
ShapeEnvEvent,
record_shapeenv_event,
replay_shape_env_events,
shape_env_check_state_equal
)
from torch.fx.experimental.sym_node import SymNode, SymTypes
from torch._logging import trace_structured, structured
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import SymBool, SymFloat, SymInt
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import (
FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt
)
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event
from torch._subclasses.meta_utils import is_sparse_any
import torch.utils._pytree as pytree
from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type
from torch._logging import LazyString
if TYPE_CHECKING:
from torch._dynamo.source import TensorPropertySource
InputList = List
DimList = List
log = logging.getLogger(__name__)
class GuardOnDataDependentSymNode(RuntimeError):
    """RuntimeError subclass for guards on data-dependent symbolic nodes.

    NOTE(review): the raise sites are outside this excerpt; the meaning is
    taken from the class name only -- confirm against the callers.
    """
class PendingUnbackedSymbolNotFound(RuntimeError):
    """Raised by ``compute_unbacked_bindings`` when freshly allocated (pending)
    unbacked symbols cannot be located anywhere in the example value."""
import sympy
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE
aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
"is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
"guard_size_oblivious", "check_consistent",
"compute_unbacked_bindings", "ConvertIntKey",
"rebind_unbacked", "resolve_unbacked_bindings",
]
# FX node metadata keys for symbolic shape FX graph.
SHAPEENV_EVENT_KEY = "shapeenv_event"
CURRENT_NODE_KEY = "current_node"
def log_lru_cache_stats(wrapped_f):
    """Log the cumulative cache statistics of ``wrapped_f`` at DEBUG level.

    ``wrapped_f`` must expose ``__name__`` and ``cumulative_cache_info()``,
    which is what the ``lru_cache`` wrapper in this module attaches.
    """
    name = wrapped_f.__name__
    stats = wrapped_f.cumulative_cache_info()
    log.debug("lru_cache_stats %s: %s", name, stats)
# Wrapper on lru_cache that reports statistics at process end
def lru_cache(maxsize):
    """Variant of ``functools.lru_cache(maxsize)`` whose hit/miss counters
    survive ``cache_clear()``.

    The decorated function gains two things on top of the stock wrapper:

    * ``cache_clear`` is replaced by a version that adds the current hit and
      miss counts to running totals before clearing the cache;
    * ``cumulative_cache_info()`` returns a ``CacheInfo`` whose hits/misses
      include those totals (``maxsize``/``currsize`` are the live values).

    If DEBUG logging is enabled when the decorator is applied, an ``atexit``
    hook is registered that logs the cumulative statistics at process end.
    """
    def inner(f):
        wrapped_f = functools.lru_cache(maxsize)(f)
        underlying_clear = wrapped_f.cache_clear
        # Counts carried over from caches that have already been cleared.
        carried = {"hits": 0, "misses": 0}

        # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info
        # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not
        # weakref'able on some versions of Python

        def cumulative_cache_info():
            live = wrapped_f.cache_info()
            return functools._CacheInfo(
                carried["hits"] + live.hits,
                carried["misses"] + live.misses,
                live.maxsize,
                live.currsize,
            )

        def new_cache_clear():
            # Bank the live counters first; clearing resets them to zero.
            live = wrapped_f.cache_info()
            carried["hits"] += live.hits
            carried["misses"] += live.misses
            underlying_clear()

        wrapped_f.cache_clear = new_cache_clear
        wrapped_f.cumulative_cache_info = cumulative_cache_info

        if log.isEnabledFor(logging.DEBUG):
            atexit.register(log_lru_cache_stats, wrapped_f)
        return wrapped_f

    return inner
# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement
@lru_cache(None)
def uninteresting_files() -> Set[str]:
    """Return the source-file paths of modules containing generic ShapeEnv
    plumbing.

    These modules are unlikely to identify a particular interesting guard
    statement.  The result is memoized by the ``lru_cache`` decorator.
    """
    # Imported lazily, inside the function, exactly as the original did.
    import torch._inductor.sizevars
    import torch._library.abstract_impl
    import torch._subclasses.meta_utils
    import torch._subclasses.fake_tensor

    generic_modules = (
        sys.modules[__name__],
        torch.fx.experimental.recording,
        torch.fx.experimental.sym_node,
        torch.fx.interpreter,
        torch,
        torch._inductor.sizevars,
        torch._library.abstract_impl,
        torch._subclasses.meta_utils,
        torch._subclasses.fake_tensor,
    )
    return set(map(inspect.getfile, generic_modules))
# We don't bother with the metaclass as all of the dispatching logic happens
# entirely from Python
#
# Didn't bother with ancestors for now, unlikely to have multiple modes for
# symints right now
class ConstraintViolationError(RuntimeError):
    """RuntimeError subclass signalling a violated constraint.

    NOTE(review): the raise sites are outside this excerpt -- confirm the
    exact conditions against the callers.
    """
def has_symbolic_sizes_strides(elem) -> bool:
    """Return the ``_has_symbolic_sizes_strides`` attribute of ``elem``."""
    flag = elem._has_symbolic_sizes_strides
    return flag
Int = Union[torch.SymInt, int]
def create_contiguous(shape: Sequence[Int]) -> List[Int]:
strides: List[Int] = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
    """
    Retrieve the hint for an int (based on the underlying real values as
    observed at runtime).

    For a ``torch.SymInt`` this delegates to ``a.node.require_hint(fallback)``;
    per the original documentation, ``fallback`` is used when no hint is
    available (e.g., data dependent shapes) and an error is raised if it is
    None.  A plain ``int`` is returned unchanged; anything else (including
    ``bool``) fails the type assertion.
    """
    if not isinstance(a, torch.SymInt):
        assert type(a) is int, a
        return a
    return a.node.require_hint(fallback)
Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
def has_hint(a: Scalar) -> bool:
if isinstance(a, SymTypes):
return a.node.has_hint()
return True
def is_concrete_int(a: Union[int, SymInt]) -> bool:
    r"""Check whether ``a`` holds a concrete integer value.

    Returns True for a plain ``int`` and for a ``SymInt`` whose underlying
    expression is a sympy ``Integer``; False for any other ``SymInt``.

    Args:
        a (SymInt or int): Object to test
    """
    assert isinstance(a, (SymInt, int))
    if isinstance(a, int):
        return True
    return isinstance(a.node.expr, sympy.core.numbers.Integer)
# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.
# So make sure only type checker evaluates this alias.
# Xref: https://www.internalfb.com/diff/D53324783
SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean"
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
    """
    Perform a guard on a symbolic boolean expression in a size oblivious way.

    This is typically used when a non-oblivious test would result in a guard
    on a data dependent value whose value is unknown at compile time.  When a
    guard is tested this way, behavior may diverge from how regular PyTorch
    semantics would treat it.  For more information, see
    https://github.com/pytorch/pytorch/pull/118579

    A plain ``bool`` is returned unchanged; any other non-SymBool input fails
    the assertion.
    """
    if not isinstance(expr, torch.SymBool):
        assert isinstance(expr, bool)
        return expr
    return expr.node.guard_size_oblivious("", 0)
[docs]def check_consistent(new, old) -> None:
"""
Test that two "meta" values (typically either Tensor or SymInt) have
the same values, e.g., after retracing. If we don't understand the
quantities in question, we'll just skip the consistency check.
"""
# TODO: do boolean equality test too, see
# https://github.com/pytorch/pytorch/issues/124110
scalar_types = (torch.SymInt, torch.SymFloat, int, float)
if isinstance(new, torch.Tensor):
assert isinstance(old, torch.Tensor)
torch._check(old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)")
# Do this manually so that each individual test is irrefutable
# (TODO: should be a helper for this, maybe sym_eq? That
# gives us a compound expression and I'm not sure it
# simplifies right now)
for i, j in zip(old.shape, new.shape):
torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)")
# NB: bool is subclass of int
elif isinstance(new, scalar_types) and not isinstance(new, bool):
assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}"
torch._check(old == new, lambda: f"{old} != {new} (old != new)")
def resolve_unbacked_bindings(shape_env, bindings):
    """Rewrite the keys of ``bindings`` through ``shape_env.unbacked_renamings``.

    Keys without a renaming entry are kept as they are; values are untouched.
    ``None`` is passed through unchanged.
    """
    if bindings is None:
        return None
    renamings = shape_env.unbacked_renamings
    resolved = {}
    for symbol, keypath in bindings.items():
        resolved[renamings.get(symbol, symbol)] = keypath
    return resolved
def rebind_unbacked(shape_env, n: torch.fx.Node, result):
    """
    Suppose we are retracing a pre-existing FX graph that previously had
    fake tensor propagation (and therefore unbacked SymInts).  When we retrace,
    we re-propagate fake tensors, which results in new unbacked SymInts.
    When this happens, we need to tell the shape environment about the
    equivalence of the old and new unbacked SymInts.

    Args:
        shape_env: shape environment that receives the renamings.
        n: the old ``torch.fx.Node``; its ``meta["unbacked_bindings"]`` holds
            the old binding information.
        result: the new result, from which the new unbacked SymInts are
            extracted by following each recorded key path.
    """
    from torch._dynamo.tensor_version_op import _tensor_version

    # Inputs never need rebinding
    if n.op == "placeholder":
        return

    bindings = resolve_unbacked_bindings(shape_env, n.meta.get("unbacked_bindings"))
    if not bindings:
        return

    for old_symbol, keypath in bindings.items():
        new_value = pytree.key_get(result, keypath)
        # tensor_version ops get specialized after AOTAutograd, it's OK,
        # we don't actually want to do asserts on them.  This is all a bit
        # questionable though
        if isinstance(new_value, int) and n.target is _tensor_version:
            log.info("rebind_unbacked: discard _tensor_version %s %s -> %s", old_symbol, keypath, new_value)
            continue
        new_expr = new_value.node.expr
        # Simplify SymBool binding: recognise Piecewise((1, Eq(x, 1)), (0, True))
        # where x is known to lie in [0, 1].
        if (
            isinstance(new_expr, sympy.Piecewise)
            and len(new_expr.args) == 2
            and new_expr.args[0][0] == 1
            and isinstance(eq := new_expr.args[0][1], sympy.Eq)
            and isinstance(inner_symbol := eq.lhs, sympy.Symbol)
            and shape_env.var_to_range[inner_symbol].issubset(ValueRanges(0, 1))
            and eq.rhs == 1
            and new_expr.args[1] == (0, True)
        ):
            # This is what the pattern match above is testing
            repacked = _sympy_cast_symbool_to_symint_guardless(sympy.Eq(inner_symbol, 1))
            assert repacked == new_expr, f"{repacked} != {new_expr}"
            # Cancel the to_int(to_bool(x)).  This is sound because x in
            # [0, 1]
            new_expr = inner_symbol
        assert isinstance(new_expr, sympy.Symbol)
        # The old and new could be the same if you improperly hit the memo
        # while retracing.  Make sure you updated FakeTensorMode.epoch
        assert old_symbol != new_expr, f"{old_symbol} possible memo disaster"
        # Reuse the OLD symbol name
        shape_env._rename_unbacked_to(new_expr, old_symbol)
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
    r"""Canonicalize a boolean expression by transforming it into a lt / le
    inequality and moving all the non-constant terms to the rhs.

    And / Or / Not expressions are first converted to CNF and their
    sub-expressions are then canonicalized recursively.  Anything that is not
    a relational or one of those logic nodes is returned unchanged.

    nb. sympy.Rel.canonical is not good enough
    https://github.com/sympy/sympy/issues/25924

    Args:
        expr (sympy.Expr): Expression to canonicalize
    """
    handled_kinds = (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)
    if isinstance(expr, handled_kinds):
        if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)):
            expr = sympy.logic.boolalg.to_cnf(expr)
        return _canonicalize_bool_expr_impl(expr)
    return expr
def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
    """
    After canonicalization, we are guaranteed to have eliminated Ge/Gt relations
    (rewriting them to Le/Lt, respectively).

    And/Or nodes are rebuilt from their canonicalized arguments.  A relation
    is rewritten as ``rel(symbolic terms, -constant terms)``.
    """
    if isinstance(expr, (sympy.And, sympy.Or)):
        canonical_args = [canonicalize_bool_expr(arg) for arg in expr.args]
        return type(expr)(*canonical_args)

    flipped = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
    if isinstance(expr, tuple(flipped)):
        # a > b  <=>  b - a < 0 (and likewise for >=)
        difference = expr.rhs - expr.lhs
        relation = flipped[type(expr)]
    else:
        assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
        difference = expr.lhs - expr.rhs
        relation = type(expr)

    if not isinstance(difference, sympy.Add):
        return relation(difference, 0)

    # Keep the non-constant terms on the left and move the constants,
    # negated, to the right.
    constants = [term for term in difference.args if term.is_number]
    symbolic = [term for term in difference.args if not term.is_number]
    return relation(sympy.Add(*symbolic), -sympy.Add(*constants))
def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
    r"""Check whether ``a`` holds a concrete boolean value.

    Returns True for a plain ``bool`` and for a ``SymBool`` whose underlying
    expression is sympy's ``BooleanTrue`` or ``BooleanFalse``; False for any
    other ``SymBool``.

    Args:
        a (SymBool or bool): Object to test
    """
    assert isinstance(a, (SymBool, bool))
    if isinstance(a, bool):
        return True
    constant_kinds = (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)
    return isinstance(a.node.expr, constant_kinds)
def is_nested_int(s):
    """Return whether ``s`` is a ``torch.SymInt`` whose node reports itself as
    a nested int; any non-SymInt yields False."""
    if not isinstance(s, torch.SymInt):
        return False
    return s.node.is_nested_int()
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
    """Yield every sympy expression reachable from ``val``.

    Handles symbolic scalars, raw sympy objects, plain numbers, ``None``,
    (nested) tuples/lists and tensors.  Sparse tensors contribute their sizes
    only; other tensors contribute sizes, strides and storage offset.
    Any other type raises ``AssertionError``.
    """
    if isinstance(val, SymTypes):
        # This also applies to the jagged layout NestedTensor case as
        # nested ints are not symbolic
        if is_symbolic(val):
            yield val.node.expr
        return
    if isinstance(val, sympy.Basic):
        yield val
        return
    if isinstance(val, (int, float, bool)):
        return
    if isinstance(val, (tuple, list)):
        for item in val:
            yield from _iterate_exprs(item)
        return
    if is_sparse_any(val):
        yield from _iterate_exprs(val.size())
        return
    if isinstance(val, torch.Tensor):
        yield from _iterate_exprs(val.size())
        yield from _iterate_exprs(val.stride())
        yield from _iterate_exprs(val.storage_offset())
        return
    if val is None:
        return
    raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
def free_symbols(val: Union[SymInt, sympy.Expr, torch.Tensor]) -> Set[sympy.Symbol]:
    """Return the union of the free symbols of every expression found in
    ``val`` (an empty set for ``None`` or when there are no expressions)."""
    if val is None:
        return set()
    exprs = _iterate_exprs(val)
    # ``union`` needs a receiver, so the first expression is pulled off by
    # hand and the empty case is handled explicitly.
    nothing = object()
    head = next(exprs, nothing)
    if head is nothing:
        return set()
    return head.free_symbols.union(*(expr.free_symbols for expr in exprs))
def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
    """Faster version of bool(free_symbols(val)): True as soon as one
    expression in ``val`` is not a number."""
    return any(not expr.is_number for expr in _iterate_exprs(val))
# Like free_symbols, but filtered to only report unbacked symbols
def free_unbacked_symbols(x):
    """Like ``free_symbols``, but keeps only symbols typed as unbacked int or
    unbacked float."""
    # NB: keep synced with is_unbacked_symint
    unbacked_kinds = (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)
    return {sym for sym in free_symbols(x) if symbol_is_type(sym, unbacked_kinds)}
# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
    """Return the symbol bound by ``node``, or None if it binds none.

    A node binds a symbol when it is a placeholder whose ``meta["val"]`` is a
    ``torch.SymInt`` backed by a bare ``sympy.Symbol``.
    """
    if node.op != "placeholder" or "val" not in node.meta:
        return None
    val = node.meta["val"]
    if not isinstance(val, torch.SymInt):
        return None
    if not isinstance(val.node.expr, sympy.Symbol):
        return None
    return val.node.expr
def find_symbol_binding_fx_nodes(graph):
    """Map each symbol bound by a node of ``graph`` to that node.

    When several nodes bind the same symbol, the last one wins.
    """
    bound = {}
    for node in graph.nodes:
        if is_symbol_binding_fx_node(node):
            bound[node.meta["val"].node.expr] = node
    return bound
# Analogous to ConvertIntSource
@dataclass(frozen=True)
class ConvertIntKey:
    """Key-path entry that converts a bool into an int.

    Analogous to ConvertIntSource.
    """

    def __str__(self) -> str:
        return ".cast_symbool_to_symint_guardless()"

    def get(self, b: bool) -> int:
        """Get the int value from bool"""
        converted = cast_symbool_to_symint_guardless(b)
        return converted
@dataclass(frozen=True)
class CallMethodKey:
    """Key-path entry naming a zero-argument method call, rendered as
    ``.<name>()``."""

    # Name of the method, e.g. "size" or "stride".
    name: str

    def __str__(self) -> str:
        return ".{}()".format(self.name)
@dataclass(frozen=True)
class InnerTensorKey:
    """Key-path entry selecting an attribute of an object, rendered as
    ``.<inner_name>``."""

    # Name of the attribute to read.
    inner_name: str

    def __str__(self) -> str:
        return ".{}".format(self.inner_name)

    def get(self, o: Any) -> Any:
        """Get the inner tensor attribute"""
        inner = getattr(o, self.inner_name)
        return inner
@dataclass(frozen=True)
class DivideByKey:
    """Key-path entry for a floor division by a constant, rendered as
    ``.__floordiv__(<divisor>)``."""

    # The integer constant to divide by.
    divisor: int

    def __str__(self) -> str:
        return ".__floordiv__({})".format(self.divisor)
def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, peek=False):
    """
    After having run fake tensor propagation and producing example_value
    result, traverse example_value looking for freshly bound unbacked
    symbols and record their paths for later.  It is an error if
    we have allocated an unbacked SymInt but it cannot be found in
    example_value.  (NB: this means if you have a multi-output
    function, you must call this on the tuple of tensor output, you
    cannot wait!)

    The peek parameter lets you check out what the bindings are without
    changing the affected list.  This is primarily useful for ensuring
    unbacked_var_to_val is promptly populated when propagate_real_tensors is on.

    Args:
        shape_env: shape environment owning ``pending_fresh_unbacked_symbols``;
            may be None, in which case nothing is done.
        example_value: value (possibly nested tuples/lists of tensors and
            symbolic scalars) to search for the pending symbols.
        old_example_value: optional value from a previous trace of the same
            node; when given, newly found symbols are renamed to, or
            eliminated in favour of, what sits at the same path in it.
        peek: if True, neither clear the pending list nor raise on symbols
            that could not be found.

    Returns:
        A dict mapping each pending symbol that was found to the key path at
        which it occurs.  None when ``shape_env`` is None, when fresh unbacked
        symbols are currently being ignored, or when nothing is pending.

    Raises:
        PendingUnbackedSymbolNotFound: a pending symbol is absent from
            ``example_value`` (only when ``peek`` is False).
    """
    if shape_env is None:
        return
    if shape_env._ignore_fresh_unbacked_symbols_tls():
        return
    fs = shape_env.pending_fresh_unbacked_symbols
    # Snapshot: entries are removed from this set as they are located.
    pending = set(fs)
    if pending:
        if not peek:
            log.info("compute_unbacked_bindings %s", fs)
            fs.clear()

        def free_unbacked_symbols_with_path(
            a, path, real=None
        ) -> Dict[sympy.Symbol, pytree.KeyPath]:
            # Walk ``a`` and return {pending symbol: key path}.  ``real`` is
            # the matching real (non-fake) value when one is available; it is
            # used to record a concrete value for each symbol found.
            r = {}
            if isinstance(a, (tuple, list)):
                for i, item in enumerate(a):
                    r.update(
                        free_unbacked_symbols_with_path(
                            item, path + (pytree.SequenceKey(i),),
                            real=real[i] if real is not None else None
                        )
                    )
            elif is_traceable_wrapper_subclass(a):
                # TODO: Determine if this is correct
                attrs, _ = a.__tensor_flatten__()
                for attr in attrs:
                    sub = getattr(a, attr)
                    r.update(
                        free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),))
                    )
            elif isinstance(a, torch.Tensor):
                real_tensor = a.real_tensor
                for method_name in ("size", "stride", "storage_offset"):
                    r.update(
                        free_unbacked_symbols_with_path(
                            getattr(a, method_name)(),
                            path + (CallMethodKey(method_name),),
                            real=(
                                getattr(real_tensor, method_name)()
                                if real_tensor is not None
                                else None
                            ),
                        )
                    )
            # NB: Intentionally access _expr, not expr, do not want
            # simplification!
            elif (
                isinstance(a, (torch.SymInt, torch.SymFloat))
                and isinstance(s := a.node._expr, sympy.Symbol)
                and s in pending
            ):
                r[s] = path
                if real is not None:
                    shape_env.set_unbacked_var_to_val(s, real)
                pending.remove(s)
            # When an unbacked SymInt is perfectly divisible by an integer
            # constant, we replace it with the integer constant to improve
            # reasoning capabilities.  However, in synthetic examples, it is
            # then possible that the factor never is explicitly allocated.
            # Fortunately, we can compute it by division.
            elif (
                isinstance(a, torch.SymInt)
                and isinstance(s := a.node._expr, sympy.Mul)
                and len(s.args) == 2
                and isinstance(lhs := s.args[0], sympy.Integer)
                and isinstance(rhs := s.args[1], sympy.Symbol)
                and rhs in pending
            ):
                # The binding is for the *symbol* ``rhs``; it is recovered
                # from the value at ``path`` by dividing out the constant.
                # TODO: DivideByKey needs to test divisibility at runtime!
                r[rhs] = path + (DivideByKey(int(lhs)),)
                if real is not None:
                    shape_env.set_unbacked_var_to_val(rhs, real // int(lhs))
                pending.remove(rhs)
            # The annoyance here arises from the fact that SymBool is
            # allocated by allocating a SymInt and then testing if it's equal
            # to one.  So you have a complicated binding site logic for this.
            elif (
                isinstance(a, torch.SymBool)
                and isinstance(s := a.node._expr, sympy.Eq)
                # This must match create_unbacked_symbool EXACTLY
                and isinstance(s.lhs, sympy.Symbol)
                and s.rhs == 1
                and s.lhs in pending
            ):
                r[s.lhs] = path + (ConvertIntKey(),)
                if real is not None:
                    # Eq(sym, 1) is ``real``, hence sym == int(real); record
                    # the value under the symbol, not under the Eq expression.
                    shape_env.set_unbacked_var_to_val(s.lhs, int(real))
                pending.remove(s.lhs)

            return r

        symbol_to_path = free_unbacked_symbols_with_path(example_value, ())
        if not peek and pending:
            extra = (
                repr((example_value.stride(), example_value.storage_offset()))
                if isinstance(example_value, torch.Tensor)
                else ""
            )
            raise PendingUnbackedSymbolNotFound(
                f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n"
                "Did you accidentally call new_dynamic_size() or item() more times "
                "than you needed to in your fake implementation?\n"
                "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit"
            )

        # Why do we have to do some rebinding here?  If the original FX node
        # wasn't a binding site because you had a memo hit, but post
        # translation you aren't a memo hit anymore, there's now a new binding
        # site... but we know (because it's the same FX node) that the value
        # is actually the same, they're just not obviously equal anymore.
        #
        # The logic here is written carefully, because unlike the
        # bind_unbacked case, we are not guaranteed to have a symbol for
        # old_sym.  If we have a symbol, do regular rename unbacked to; but if
        # we don't, we need to specially eliminate the fresh unbacked symbol
        # (NB: we are /trusting/ that the memoization is correct, and that we
        # don't need to generate a new runtime assert.  This is load bearing,
        # as repropagation can happen after we've frozen runtime asserts.)
        if old_example_value is not None:
            for keypath in symbol_to_path.values():
                old_sym = pytree.key_get(old_example_value, keypath)
                new_sym = pytree.key_get(example_value, keypath)
                if (
                    isinstance(new_sym, SymTypes) and
                    isinstance(new_s := new_sym.node.expr, sympy.Symbol)
                ):
                    if isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s:
                        if isinstance(old_s, sympy.Symbol):
                            shape_env._rename_unbacked_to(new_s, old_s)
                        else:
                            shape_env._eliminate_unbacked(new_s, old_s)
                    elif not isinstance(old_sym, SymTypes):
                        shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))

        return symbol_to_path
def definitely_true(a):
    """
    Returns True only if we can tell that a is True, possibly introducing
    a guard in the process.  If a depends on some unbacked SymInt, we may
    return False even though there may exist a possible value of the SymInt
    that would cause the expression to return True.

    When is it appropriate to use definitely_true?  First, if you can use
    a higher level combinator like parallel_or/parallel_and, prefer using
    those instead, they are definitely safe (modulo short-circuiting).
    Second, it can be used if the program would behave equivalently if
    definitely_true always returned False (parallel_or/parallel_and are
    examples of this pattern, modulo short-circuiting).  Finally, it may even
    be OK if the program wouldn't behave equivalently, so long as the
    change is semantics preserving.  It can be semantics preserving if
    the program errors in more cases than it did previously (but otherwise
    behaves identically), or if it changes some quantity in a way that
    doesn't matter (e.g., strides often fall in this bucket.)
    """
    if not isinstance(a, SymBool):
        return bool(a)
    # Without a hint there is nothing to guard on, so answer "don't know".
    return guard_bool(a) if a.node.has_hint() else False
def definitely_false(a):
    """
    Returns True only if we can tell that a is False, possibly introducing
    a guard in the process.  If a depends on some unbacked SymInt, we may
    return False even though there may exist a possible value of the SymInt
    that would cause the expression a to be False.  See definitely_true
    for more usage guidance.
    """
    if not isinstance(a, SymBool):
        return not bool(a)
    # Without a hint there is nothing to guard on, so answer "don't know".
    return (not guard_bool(a)) if a.node.has_hint() else False
def statically_known_true(x: Union[bool, SymBool]) -> bool:
    """Returns True if x can be simplified to a constant and is true.

    .. note::
        This function doesn't introduce new guards, so the expression may end
        up evaluating to true at runtime even if this function returns False.

    Any exception raised during static evaluation is logged at DEBUG level
    and treated as "not known".

    Args:
        x (bool, SymBool): The expression to try statically evaluating
    """
    if not isinstance(x, SymBool):
        assert isinstance(x, bool)
        return x

    expr = x.node.expr
    shape_env = x.node.shape_env
    try:
        simplified = shape_env._maybe_evaluate_static(expr)
        if simplified is not None:
            return bool(simplified)
    except Exception:
        log.debug("Could not simplify %s", expr)
    return False
def parallel_or(*args):
    """
    Evaluate the logical OR of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely True.

    Every argument is first checked with ``statically_known_true``, then
    every argument with ``definitely_true``; only if neither pass finds a
    true argument does this fall back to a plain ``any(args)``.
    """
    for arg in args:
        if statically_known_true(arg):
            return True
    for arg in args:
        if definitely_true(arg):
            return True
    return any(args)
def parallel_and(*args):
    """
    Evaluate the logical AND of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely False.

    The negation of every argument is first checked with
    ``statically_known_true``, then every argument with ``definitely_false``;
    only if neither pass finds a false argument does this fall back to a
    plain ``all(args)``.
    """
    for arg in args:
        if statically_known_true(torch.sym_not(arg)):
            return False
    for arg in args:
        if definitely_false(arg):
            return False
    return all(args)
def sym_eq(x, y):
    """
    Like ==, but when run on list/tuple, it will recursively test equality
    and join the results together with ``&``, without guarding.

    Two tuples or two lists of different length compare as False.  Anything
    other than int/SymInt scalars, or two sequences of the same kind, raises
    ``AssertionError``.
    """
    both_tuples = isinstance(x, tuple) and isinstance(y, tuple)
    if both_tuples or (isinstance(x, list) and isinstance(y, list)):
        if len(x) != len(y):
            return False
        combined = True
        for lhs, rhs in zip(x, y):
            combined = combined & sym_eq(lhs, rhs)
        return combined
    if isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
        return x == y
    raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
def guard_scalar(a):
    """Dispatch ``a`` to ``guard_bool``, ``guard_int`` or ``guard_float``
    according to its type; unrecognised types raise ``AssertionError``."""
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(a, (SymBool, bool)):
        return guard_bool(a)
    if isinstance(a, (SymInt, int)):
        return guard_int(a)
    if isinstance(a, (SymFloat, float)):
        return guard_float(a)
    raise AssertionError(f"unrecognized scalar {a}")
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
    """Forward to ``shape_env.constrain_symbol_range(s, compiler_min, compiler_max)``."""
    constrain = shape_env.constrain_symbol_range
    constrain(s, compiler_min, compiler_max)
def _advise_is_size(a):
"""
Don't use this directly; use torch._check_is_size instead.
This is a softer version of _constrain_range_for_size (with min=0,
max=Inf). Instead of forcibly constraining a variable (and erroring if we
failed to constrain it), it will simply advise us that a size is
constrained in some way. We will always defer a runtime assert for this
constraint if we cannot prove it at compile-time, but we we only
*sometimes* learn useful extra information at compile-time with this
information. This is in contrast to constrain_range_for_size, where if
you don't call that on a fresh unbacked symint, chances are we will choke.
TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
code. Right now this is only really used in code with AOTAutograd trace
through, so it is not a big problem that this isn't supported, but in
principle all of this code should be Dynamo'able too.
TODO: I didn't support min/max because I didn't have a use case where this
actually helped. In principle we can support it, it just makes the
implementation below more complicated.
"""
# This must always succeed, because the sole allowed caller _check_is_size
# was responsible for expect_true'ing this
# This assert triggers expensive sym compute, do not do it until its cheap.
# assert a >= 0
# NB: it's important not to constrain range for size for *hinted* SymInts,
# because it is not only unsound, it will immediately trip our asserts
# that hints have to be consistent with static analysis! If you somehow
# have an unbounded SymInt that later constrains to 1, this will be
# inconsistent with the range
if (
isinstance(a, SymInt)
and isinstance(a.node, SymNode)
and isinstance(a.node.expr, sympy.Symbol)
and a.node.shape_env.is_unbacked_symint(a.node.expr)
):
_constrain_range_for_size(a)
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
"""
This function is NOT INTENDED to be used by itself.
"""
if isinstance(a, (SymFloat, SymBool)):
raise ValueError("Constraining SymFloat/SymBool is nyi")
assert isinstance(a, SymInt), "can only constrain range for SymInt"
assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
a.node.shape_env._constrain_range_for_size(a.node.expr, min, max)
# inclusive both ways
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
    """
    Applies a constraint that the passed in SymInt must lie between min-max
    inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning
    that it can be used on unbacked SymInts).  If min/max are None, we assume
    that the dimension is unbounded in that direction.  Repeated application
    of constrain_range intersects the ranges.  This is a fairly low level API
    that doesn't have a lot of safety guarantees (TODO: provide higher level
    APIs).

    Currently, we use this API in the following circumstance: when we allocate
    an unbacked SymInt, denoting an integer quantity which is data dependent,
    we ordinarily do not know anything about what values it may take.  This
    means that any sort of guard on it will immediately fail.  However, in
    many cases, we know something about the unbacked SymInt: for example, we
    know that nonzero(x).size(0) must be >= 0.  We use constrain_range to
    narrow the possible range, declaring that negative symbols are impossible.
    This permits to definitely answer True to queries like 'nnz >= 0', even if
    we don't know what the actual (hinted) value of 'nnz' is.  In fact, we
    actually use constrain_range to unsoundly discharge common guards: for an
    unbacked SymInt produced by nonzero, we will also assume that it is not
    equal to 0/1 (even though these are perfectly possible values at runtime),
    because we generally expect graphs that are valid for N=2 to also be valid
    for N=1.

    Args:
        a: the SymInt to constrain; a plain ``int`` is only range-checked.
        min: inclusive lower bound, or None for ``-sys.maxsize - 1``.
        max: inclusive upper bound, or None for ``sys.maxsize - 1``.

    Raises:
        ValueError: if ``max < min``, or if ``a`` is a plain int outside
            ``[min, max]``.
    """
    if min is None:
        min = -sys.maxsize - 1
    if max is None:
        max = sys.maxsize - 1

    if max < min:
        # Both halves must interpolate: the second literal used to be a plain
        # string, so the message showed "{min}"/"{max}" verbatim.
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    if isinstance(a, int):
        if not (min <= a <= max):
            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
        return

    a.node.shape_env._constrain_range(a.node.expr, min, max)
def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None:
    """
    Given two SymInts, constrain them so that they must be equal.  NB:
    this will not work with SymInts that represent nontrivial expressions
    (yet!)

    The shape environment is taken from ``a`` if it is a SymInt, otherwise
    from ``b``.  If neither is a SymInt the two are simply asserted equal.
    """
    if isinstance(a, SymInt):
        shape_env = a.node.shape_env
    elif isinstance(b, SymInt):
        shape_env = b.node.shape_env
    else:
        assert a == b
        return

    shape_env._constrain_unify(a, b)
# Assume that a boolean is true for the purposes of subsequent symbolic
# reasoning. This will keep track of corresponding runtime checks to verify
# that the result is upheld: either as a regular guard, or as a special set
# of asserts which are triggered when an unbacked SymInt is allocated.
#
# DO NOT use this function for these cases:
#
# - This is inappropriate for "branching" conditions (where both
# true and false result in valid programs). We will always assume
# the condition evaluates true, and so it will never be possible
# to trace the false condition when you use it. For true branching
# on unbacked SymInts, you must use torch.cond; if you incorrectly
# use expect_true in this case, you will make the false branch
# unreachable (as we will simply assume that only the true branch
# is ever exercised).
#
# - This is inappropriate for situations where you know some other system
# invariant guarantees that this property holds, since you don't
# really need to insert a runtime check in that case. Use something
# like constrain_range in that case.
#
# This API has a hitch. To avoid having to reimplement error reporting
# capabilities, this function CAN return False. The invariant is that
# the surrounding code must raise an error when this function returns
# False. This is quite low level, so we recommend using other functions
# like check() which enforce this in a more intuitive way.
#
# By the way, this name is a nod to the __builtin_expect macro,
# which is used similarly (but unlike __builtin_expect, you MUST fail
# in the unlikely branch.) (I think expect is a good name; in recent
# versions of C++, this is replaced with [[likely]], which is weaker
# and not accurate for this function!)
def expect_true(a, skip: int = 0):
    """Assume ``a`` is true for subsequent symbolic reasoning.

    For a ``SymBool`` this delegates to ``a.node.expect_true`` with the file
    name and line number of a calling frame: the direct caller when ``skip``
    is 0, and one frame further out for each additional ``skip``.  A plain
    ``bool`` is returned unchanged; anything else fails the type assertion.
    """
    if not isinstance(a, SymBool):
        assert type(a) is bool, a
        return a

    # TODO: check perf implications of this
    frame = inspect.currentframe()
    for _ in range(skip + 1):  # always run this loop at least once
        frame = frame.f_back
    filename = frame.f_code.co_filename
    lineno = frame.f_lineno
    return a.node.expect_true(filename, lineno)
def guard_bool(a):
    """Guard on ``a`` and return its value as a bool.

    A ``SymBool`` delegates to ``a.node.guard_bool``; a plain ``bool`` is
    returned unchanged; anything else fails the type assertion.
    """
    if not isinstance(a, SymBool):
        assert type(a) is bool, a
        return a
    return a.node.guard_bool("", 0)  # NB: uses Python backtrace
def guard_int(a):
    """Guard on ``a`` and return its value as an int.

    A ``SymInt`` delegates to ``a.node.guard_int``; a plain ``int`` is
    returned unchanged; anything else (including ``bool``) fails the type
    assertion.
    """
    if not isinstance(a, SymInt):
        assert type(a) is int, a
        return a
    return a.node.guard_int("", 0)  # NB: uses Python backtrace
def guard_float(a):
    """Return a concrete ``float`` for ``a``.

    A ``SymFloat`` is resolved through ``a.node.guard_float``; anything else
    must be a ``float`` instance and is returned unchanged.
    """
    if not isinstance(a, SymFloat):
        assert isinstance(a, float), a
        return a
    # NB: uses Python backtrace (hence the placeholder file/line arguments)
    return a.node.guard_float("", 0)
# Given a GraphModule, return all the FakeTensors for all the placeholders
def fx_placeholder_vals(gm):
    """Collect ``meta['val']`` of every placeholder node in ``gm.graph``, in graph order."""
    vals = []
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            vals.append(node.meta['val'])
    return vals
def fx_placeholder_targets(gm):
    """Collect the ``target`` of every placeholder node in ``gm.graph``, in graph order."""
    targets = []
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            targets.append(node.target)
    return targets
# Given a GraphModule and arguments to run it with, evaluate that the guards
# for its associated ShapeEnv are satisfied by the passed arguments. This
# WILL check for duck sizing.
def eval_guards(gm, *args, ignore_static=True):
    """Evaluate the guards of ``gm.shape_env`` against ``args``.

    The placeholder values of ``gm`` are paired with ``args`` and handed to
    ``gm.shape_env.evaluate_guards_for_args``; its result is returned as-is.
    """
    evaluate = gm.shape_env.evaluate_guards_for_args
    return evaluate(fx_placeholder_vals(gm), args, ignore_static=ignore_static)
def bind_symbols(gm, *args):
    """Forward the placeholder values of ``gm`` and ``args`` to ``gm.shape_env.bind_symbols``."""
    bind = gm.shape_env.bind_symbols
    return bind(fx_placeholder_vals(gm), args)
class DimDynamic(Enum):
    """
    Policy for allocating a symbol for a single dimension.

    Defaulting to DYNAMIC is always sound. DUCK and STATIC can give better
    trace-time and compile-time performance because they reduce the number
    of allocated symbols and generally make the graph more static.

    NB: If we notice you've applied a constraint to the dimension, we will
    force it to DYNAMIC for simplicity.

    DimDynamic is controlled by a variety of higher level UX features.
    Currently:

    - In eager mode, the default policy is DUCK.
        - The default is changed to STATIC with assume_static_by_default.
        - An individual dim is marked DYNAMIC if you mark_dynamic_dim.
    - In export mode, the default policy is STATIC.
        - An individual dim is marked DYNAMIC if you mention it as
          dynamic_dim in the constraints kwarg.
    """
    # Fully symbolic dimension
    DYNAMIC = 0
    # Symbolic, but when the hint matches another dynamic dimension the two
    # symbols are unified ("duck sizing")
    DUCK = 1
    # Static dimension, fixed to its hint
    STATIC = 2
# NB: These constraints affect both clients and backends: given some
# constraint C, the client must pass inputs that satisfy the constraint,
# while a backend must not introduce guards BEYOND this constraint.
# For clarity, we document the implications on both sides for both the client
# and the backend.
#
# NB: These constraints are on a *single* dimension. In principle, we could
# also have multi-dimension constraints, but our guess is that this is not
# actually useful and so we are not supporting it right now.
#
# NB: Strict constraints are typically only suitable for export, as in eager
# a backend like inductor may validly introduce extra, discretionary guards
# to improve performance of code. A StrictMinMaxConstraint would be brittle
# under future optimizations performed by inductor; we don't guarantee
# eager code with StrictMinMaxConstraint will keep working in the future!
@dataclass(frozen=True)
class Constraint:
    """Base class of the per-dimension constraint types; carries only ``warn_only``."""

    warn_only: bool
@dataclass(frozen=True)
class StrictMinMaxConstraint(Constraint):
    """
    Strict range constraint on a single dimension.

    For clients: the size at this dimension must be within 'vr' (which
    specifies a lower and upper bound, inclusive-inclusive) AND it
    must be non-negative and should not be 0 or 1 (but see NB below).

    For backends: there must not be any guards on this dimension which
    are not implied by the given lower and upper bound. Regardless of
    the lower bound, the backend can assume the size is non-negative
    and that it is not 0 or 1.

    An unbounded StrictMinMaxConstraint can be thought of as a strict version
    of "RelaxedUnspecConstraint".

    NB: Export will often unsoundly assume that a graph works for 0/1, even
    though at trace time we assumed size is not 0 or 1. The idea is that
    if we produce a graph that works for a range of values, it will be OK
    for N=0/1 too.
    """
    vr: ValueRanges

    def render(self, source: Source):
        """Format the constraint as ``<lower> <= <source name> <= <upper>``."""
        # TODO: better printing for -oo and oo
        lower, upper = self.vr.lower, self.vr.upper
        return f"{lower} <= {source.name()} <= {upper}"
@dataclass(frozen=True)
class RelaxedUnspecConstraint(Constraint):
    """
    "Must stay unspecialized" constraint on a single dimension.

    For clients: no explicit constraint; constraint is whatever is implicitly
    inferred by guards from tracing.

    For backends: there must exist at least TWO possible values for the
    size at this dimension which satisfy the guards for this dimension.

    In other words, this constraint helps us distinguish between "we don't
    care if this dimension specializes or not" versus "this dimension must be
    unspecialized." However, this constraint doesn't say very much about what
    specialization is permitted; for example, if we guard on a size being
    even, this would still be acceptable under an unspec constraint. This
    makes RelaxedUnspecConstraint useful for eager mode, where your backend
    compiler may add constraints to otherwise dynamic dimensions; we can't
    assert that there are NO guards as this is brittle because compilers
    should be able to add extra constraints. If you want to assert that
    there are no guards, use StrictMinMaxConstraint with an unbounded
    ValueRanges.
    """

    def render(self, source: Source):
        """Format the constraint as ``RelaxedUnspecConstraint(<source name>)``."""
        name = source.name()
        return f"RelaxedUnspecConstraint({name})"
# Type of the constraint attached to a single dimension.
#
# NB: None here indicates the client constraint is whatever is implicitly
# inferred by guards from tracing, and that a backend can add whatever guards
# it wants (including fully specializing the value).
DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None]
@dataclass(frozen=True)
class EqualityConstraint(Constraint):
    """
    Represent and decide various kinds of equality constraints between input sources.

    A "source pair" is a pair of input sources for dynamic dimensions that
    are specified equal. We represent `source_pairs` in a union-find forest
    so that we can efficiently check whether two such sources are transitively equal.

    A "derived equality" relates an input source to an expression over a root.
    The root can be another input source, corresponding to some dynamic dimension,
    or a phantom symbol that does not directly represent any dynamic dimension. We
    represent `derived_equalities` involving input sources in a transitively-closed map
    so that we can efficiently check whether an input source is transitively equal to
    a given expression over another input source.

    (NOTE: In contrast, it is easy to decide whether an input source is transitively equal
    to a given expression over a phantom symbol; such expressions are already in canonical
    form and so the problem reduces to symbolic expression equality.)
    """
    source_pairs: List[Tuple[Source, Source]]
    derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]]
    phantom_symbols: List[sympy.Symbol]

    def __post_init__(self):
        """Build the lookup structures behind `is_equal` and `is_derived`.

        Example: Suppose we are given:
          source_pairs [a = b, b = c]
          derived_equalities [d = c + 1, e = d - 1]

        We first construct a union find with source_pairs:
          _parents = {a: a, b: a, c: a}

        Then we compute canonical symbolic expressions, recursively applying
        derived_equalities until we bottom out:
          _defs = {d: c + 1, e: (c + 1) - 1 aka c}
        """
        # The dataclass is frozen, so the helper maps are installed by going
        # around the generated __setattr__.
        #
        # _parents: input source -> input source; conceptually the directed
        # edges of a union-find forest.
        parents: Dict[Source, Source] = {}
        object.__setattr__(self, "_parents", parents)
        # _defs: input source -> "canonical" symbolic expression, i.e., a unary
        # expression whose symbol corresponds to a regular (non-derived) Dim.
        defs: Dict[Source, sympy.Expr] = {}
        object.__setattr__(self, "_defs", defs)

        # Step 1: fold the source pairs into the union-find forest.
        for lhs, rhs in self.source_pairs:
            self._union(self._find(lhs), self._find(rhs))

        # Step 2: build the transitively-closed definition map.
        # NOTE(avik): we reuse the union-find forest for canonicalizing input sources
        for source, root, fn in self.derived_equalities:
            # A phantom symbol is used directly; an input source is first
            # rewritten to its canonical expression.
            operand = root if isinstance(root, sympy.Symbol) else self._rewrite(root)
            self._defs[self._find(source)] = fn(operand)

    def _find(self, source):
        """Follow parent edges until reaching the root of ``source``'s equivalence class."""
        while source in self._parents:
            source = self._parents[source]
        return source

    def _union(self, root1, root2):
        """Merge two equivalence classes by pointing ``root1`` at ``root2`` (no-op if equal)."""
        if root1 != root2:
            self._parents[root1] = root2

    def _rewrite(self, src):
        """Return the canonical expression for ``src``.

        The source is first replaced by the root of its equivalence class.
        If that root has a recorded definition it is returned; otherwise a
        fresh symbol named after the root is created.
        """
        root = self._find(src)
        # NOTE(avik): A plain lookup suffices because definitions are always
        # transitively-closed; otherwise we would have to do recursive rewriting.
        return self._defs[root] if root in self._defs else sympy.Symbol(root.name())

    def is_equal(self, source1, source2):
        """Whether the two sources share a root or are derived-equal via the identity."""
        same_root = self._find(source1) == self._find(source2)
        return same_root or self.is_derived(source1, source2, lambda x: x)

    def is_derived(self, src, symbol_src, fn):
        """Whether ``src`` has the same canonical definition as ``fn`` applied to ``symbol_src``."""
        expected = fn(self._rewrite(symbol_src))
        return self._rewrite(src) == expected
def _assert_symbol_context(symbolic_context):
    """Assert that ``symbolic_context`` is an instance of a concrete SymbolicContext subclass."""
    assert isinstance(symbolic_context, SymbolicContext), (
        "Invalid symbolic_context object"
    )
    # The base class itself is treated as abstract and may not be used directly.
    assert type(symbolic_context) is not SymbolicContext, (
        "Illegal usage of symbolic_context ABC"
    )
def _is_supported_equivalence(expr):
    """Check whether ``expr`` is a supported derived-Dim expression.

    Currently supported Dim ops are linear expressions with integer
    coefficients, so ``expr`` may only contain +, *, ints, and a single
    occurrence of a symbol.
    (See also documentation of dynamic_shapes._DerivedDim.)
    """
    if not isinstance(expr, (sympy.Add, sympy.Mul)):
        # Leaf case: only a bare symbol is acceptable.
        return isinstance(expr, sympy.Symbol)

    if len(expr.args) > 2:
        return False

    lhs, rhs = expr.args
    # Exactly one operand must be an integer, the other a supported sub-expression.
    if isinstance(rhs, sympy.Integer) and _is_supported_equivalence(lhs):
        return True
    return isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs)
@dataclass(frozen=True)
class SymbolicContext:
    """
    Describes how symbols should be created in
    ``create_symbolic_sizes_strides_storage_offset``; e.g., whether they
    should be static or dynamic.

    This is an abstract base class because we are probably going to add
    another version of this that says "use exactly these SymInts, don't
    allocate fresh symbols."
    """
@dataclass(frozen=True)
class StatelessSymbolicContext(SymbolicContext):
    """
    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
    a symbolic_context determination as given by ``DimDynamic`` and
    ``DimConstraint``. This will cause fresh symbols to be allocated.
    """
    dynamic_sizes: DimList[DimDynamic]
    constraint_sizes: DimList[DimConstraint] = None
    # If the tensor is a view, this should be populated for the base. It contains
    # information on how to allocate symbols when recursively fakeifying the base
    # during view fake-ification.
    view_base_context: Optional[SymbolicContext] = None
    # TODO: add storage offset and stride symbolic_context

    def __post_init__(self):
        """Default ``constraint_sizes`` to one ``None`` per entry of ``dynamic_sizes``."""
        if self.constraint_sizes is not None:
            return
        # Frozen dataclass: assignment has to bypass the generated __setattr__.
        unconstrained = [None] * len(self.dynamic_sizes)
        object.__setattr__(self, 'constraint_sizes', unconstrained)
# note [Tensor Fakification and Symbol Caching]
#
# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
# The reason we do this is because there are certain classes of operations, namely,
# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor
# state at the end of a dynamo trace is different than the fake tensor state at the beginning
# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation,
# view relationships, etc.
#
# As we create a new fake mode, we also lose the memoization that comes with it. Rather than
# transfer the memoization cache, we instead transfer the shape env. However, with this
# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in
# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across
# recompilations.
#
# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass
# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext.
# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is
# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors
# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env
# is used.
# TODO(voz): Shape env validation
@dataclass(frozen=True)
class StatefulSymbolicContext(StatelessSymbolicContext):
    """
    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
    will reuse a stored symbol, and a cache miss will write to this cache.

    This behaves like StatelessSymbolicContext, except the cache supersedes the
    other values - dynamic_sizes and constraint_sizes will not be read if we cache
    hit.

    It is the cache owner's responsibility to maintain the lifecycle of the cache
    w/r/t different shape_envs, clearing, etc.
    """
    tensor_source: Source = None
    # Why is this keyed on int first?
    # That integer is actually the id of the shape_env. This cache short-circuits symbol
    # creation, and we must store it per shape env. Now, while tracing invariants are a single
    # shape env per tracing context, and every new frame gets a new shape_env. So where would we have
    # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events
    # is invoked, and creates a new shape_env. Replaying events against this new shape_env will
    # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
    # get recorded in var_to_val, etc.
    # TODO(voz): consider a weakref to the shape_env here
    shape_env_to_source_to_symbol_cache : Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None

    def __post_init__(self):
        """Validate ``tensor_source`` and fill in the defaults that are spelled ``None``.

        Raises:
            AssertionError: if ``tensor_source`` was left as ``None``.
        """
        # Run the parent's normalisation as well; without this call
        # ``constraint_sizes`` would stay ``None`` instead of becoming a list
        # with one entry per dynamic size.
        super().__post_init__()
        # The None default is annoying, but required because of dataclass limitations
        assert self.tensor_source is not None
        # NOTE: an explicitly passed *empty* cache is also replaced by a fresh dict.
        if not self.shape_env_to_source_to_symbol_cache:
            object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {})
@dataclass(frozen=True)
class SubclassSymbolicContext(StatefulSymbolicContext):
    """
    The correct symbolic context for a given inner tensor of a traceable tensor subclass
    may differ from that of the outer symbolic context. This structure allows for this
    flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
    """
    inner_contexts: Dict[str, SymbolicContext] = None

    def __post_init__(self):
        """Run the parent initialisation, then default ``inner_contexts`` to an empty dict."""
        super().__post_init__()
        if self.inner_contexts is None:
            # The dataclass is frozen, so a plain ``self.inner_contexts = {}``
            # would raise FrozenInstanceError; bypass the generated __setattr__
            # the same way the parent classes do.
            object.__setattr__(self, 'inner_contexts', {})
def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
    """Return False for plain Python numbers/bools; otherwise ask ``val.node.is_symbolic()``."""
    if not isinstance(val, (int, float, bool)):
        return val.node.is_symbolic()
    return False
# Tuple of the indicator function classes; at present it contains only
# IsNonOverlappingAndDenseIndicator.
IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
@lru_cache(256)
def safe_expand(r):
    """Return ``sympy.expand(r)``, falling back to ``r`` itself when that is not possible.

    Objects without an ``expand`` attribute are returned untouched. If
    expansion hits a ``RecursionError`` a warning is logged and ``r`` is
    returned unexpanded. Results are memoized (up to 256 entries).
    """
    if not hasattr(r, 'expand'):
        return r
    try:
        return sympy.expand(r)
    except RecursionError:
        log.warning("RecursionError in sympy.expand(%s)", r)
        return r
def error():
    """Unconditionally raise ``AssertionError``; marks code paths that must never run."""
    message = "shouldn't be hit"
    raise AssertionError(message)
# TODO: Deduplicate this with torch/_prims_common/__init__.py
def eval_is_non_overlapping_and_dense(sizes, strides):
    """Return 1 if ``sizes``/``strides`` describe a non-overlapping and dense layout, else 0.

    The (possibly symbolic) verdict of ``_eval_is_non_overlapping_and_dense``
    is made concrete with ``guard_bool`` and converted to ``int``.
    """
    verdict = _eval_is_non_overlapping_and_dense(sizes, strides)
    return int(guard_bool(verdict))
def _eval_is_non_overlapping_and_dense(sizes, strides):
dim = len(sizes)
# Short-circuits for tensors of rank one, which are
# non-overlapping and "dense" if their stride is one
# or it is a 0/1 element tensor
if dim == 1:
return strides[0] == 1 or sizes[0] < 2
# Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
# Sorts (length, stride) pairs by stride
lengths_and_strides = sorted(
zip(sizes, strides), key=operator.itemgetter(1)
)
# Unlike the C++ code, we don't move the 0/1 size dimensions to the
# end. So we have to keep going for this code.
expected_stride = 1
for length, stride in lengths_and_strides:
if length == 1:
continue
if stride != expected_stride:
return False
expected_stride *= length
return True
def _sympy_cast_symbool_to_symint_guardless(x: sympy.Expr) -> sympy.Expr:
    """Encode the boolean expression ``x`` as an integer expression: 1 when ``x`` holds, else 0."""
    when_true = (1, x)
    otherwise = (0, True)
    return sympy.Piecewise(when_true, otherwise)
def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
    """Convert a (Sym)bool into a (Sym)int.

    A plain ``bool`` becomes ``1`` or ``0``. For a ``SymBool`` a new symbolic
    integer node is created in the same shape environment from the
    Piecewise encoding of its expression; the hint is carried over when
    the input has one.
    """
    if isinstance(symbool, bool):
        return int(symbool)

    node = symbool.node
    int_sym = _sympy_cast_symbool_to_symint_guardless(node.expr)
    hint = int(node.require_hint()) if has_hint(symbool) else None
    return node.shape_env.create_symintnode(int_sym, hint=hint)
# Lookup table from an operation name (a string) to the Python callable that
# implements it on concrete values.
# NOTE(review): the keys look like the names of sympy / torch.utils._sympy
# function classes; confirm against the consumer of this table.
SYMPY_INTERP = {
    'Abs': operator.abs,
    'Eq': operator.eq,
    'Ne': operator.ne,
    'Gt': operator.gt,
    'Lt': operator.lt,
    'Le': operator.le,
    'Ge': operator.ge,
    'Min': min,
    'Max': max,
    'Mod': operator.mod,
    'PythonMod': operator.mod,
    'FloorDiv': operator.floordiv,
    'TrueDiv': operator.truediv,
    'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
    'floor': math.floor,
    'ceiling': math.ceil,
    'FloorToInt': math.floor,
    'CeilToInt': math.ceil,
    'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
    'RoundToInt': builtins.round,
    'RoundDecimal': builtins.round,
    'TruncToInt': math.trunc,
    'IntTrueDiv': operator.truediv,
}
def _lru_cache(fn, maxsize=None):
    """
    Wrapper around lru_cache that clears when new info about shapes has been
    updated.

    Use lru_cache if the output is always the same, regardless of the
    constraints we know now (i.e. evaluate_expr)

    Use _lru_cache otherwise.

    Also note that this depends on _update_version_counter being called on the
    shape environment whenever the constraints are updated, otherwise the cache
    will not be cleared.

    Args:
        fn: The method to memoize; its first argument is the object carrying
            ``_version_counter`` (and ``_get_key`` when validation is on).
        maxsize: Forwarded to ``lru_cache``.

    Returns:
        The wrapping function, with ``cache_clear`` and ``cache_info``
        attributes forwarded from the underlying cache.
    """
    cached_fn = lru_cache(maxsize)(fn)
    # Version counter value observed on the most recent cache reset.
    seen_version = 0

    if config.validate_shape_env_version_key:
        # Validating flavour: additionally remember the key and assert that
        # it never changes while the version counter stays the same.
        seen_key = None

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            nonlocal seen_version, seen_key
            if seen_key is None:
                seen_key = self._get_key()

            if seen_version == self._version_counter:
                assert seen_key == self._get_key(), \
                    "ShapeEnv cache key changed without version being updated!"
            else:
                cached_fn.cache_clear()
                seen_version = self._version_counter
                seen_key = self._get_key()

            return cached_fn(self, *args, **kwargs)

    else:

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            nonlocal seen_version
            if seen_version != self._version_counter:
                # Stale: drop every memoized result before serving this call.
                cached_fn.cache_clear()
                seen_version = self._version_counter

            return cached_fn(self, *args, **kwargs)

    wrapper.cache_clear = cached_fn.cache_clear
    wrapper.cache_info = cached_fn.cache_info  # type: ignore[attr-defined]
    return wrapper
@dataclass(frozen=True)
class RuntimeAssert:
    """
    A condition that MUST hold at runtime, together with a message.

    This is pretty similar to ShapeGuard but it also comes with a message,
    and is exclusively used for things that MUST be true (unlike guards,
    which can evaluate False, in which case you just choose not to use
    a particular specialization).
    """
    # The asserted expression
    expr: sympy.Expr
    # Message and stack are kept out of the generated repr
    msg: str = field(repr=False)
    stack: str = field(repr=False)
class ShapeGuardPrinter(StrPrinter):
    """String printer that renders symbols through their sources and uses
    Python's ``not`` / ``and`` / ``or`` for boolean connectives."""

    def __init__(
        self,
        symbol_to_source,
        source_ref,
        var_to_sources,
    ):
        """
        Args:
            symbol_to_source: Mapping from symbol to a list of sources; the
                first source is the one printed.
            source_ref: Callable turning a source into the printed string.
            var_to_sources: Mapping from symbol to sources, consulted only
                for the error message of an unknown symbol.
        """
        super().__init__()
        self.symbol_to_source = symbol_to_source
        self.source_ref = source_ref
        self.var_to_sources = var_to_sources

    def _print_Not(self, expr):
        operand = self.parenthesize(expr.args[0], PRECEDENCE["Not"])
        return f'not {operand}'

    def _print_And(self, expr):
        return self.stringify(expr.args, " and ", PRECEDENCE["And"])

    def _print_Or(self, expr):
        return self.stringify(expr.args, " or ", PRECEDENCE["Or"])

    def _print_Symbol(self, expr) -> str:
        """Print ``expr`` as ``source_ref`` of its first known source."""
        assert isinstance(expr, sympy.Symbol), str(type(expr))

        def repr_symbol_to_source():
            # Only evaluated when the assertion below fails.
            names_by_symbol = {}
            for symbol, sources in self.symbol_to_source.items():
                names_by_symbol[symbol] = [s.name() for s in sources]
            return repr(names_by_symbol)

        assert self.symbol_to_source.get(expr), (
            f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
            f"not in {repr_symbol_to_source()}. If this assert is failing, it could be "
            "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
        )
        first_source = self.symbol_to_source[expr][0]
        return self.source_ref(first_source)
class LoggingShapeGuardPrinter(ShapeGuardPrinter):
    """ShapeGuardPrinter that prints each symbol as the plain name of its first source."""

    def __init__(self, var_to_sources):
        # The same mapping serves both as symbol_to_source and var_to_sources.
        def source_name(source):
            return source.name()

        super().__init__(var_to_sources, source_name, var_to_sources)
class DynamicDimConstraintPrinter(StrPrinter):
    """
    Printer for dynamic dim constraints.
    - Instead of t.size()[d] it prints dynamic_dim(t, d)
    - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc.

    We use this to suggest code for specifying dynamic dim constraints.
    """

    def __init__(self, symbol_to_source, source_name_to_debug_name):
        super().__init__()
        self.symbol_to_source = symbol_to_source
        self.source_name_to_debug_name = source_name_to_debug_name

    def print_source(self, source) -> str:
        """Render ``source``: its plain name when debug names are available,
        otherwise the ``dynamic_dim(<base>, <idx>)`` form."""
        if not self.source_name_to_debug_name:
            return f"dynamic_dim({source.base.name()}, {source.idx})"
        return source.name()

    def _print_Symbol(self, expr) -> str:
        """Print ``expr`` via the first source registered for it."""
        assert isinstance(expr, sympy.Symbol), str(type(expr))
        assert self.symbol_to_source.get(expr), (
            f"Unknown symbol {expr} created by constraints solver"
        )
        first_source = self.symbol_to_source[expr][0]
        return self.print_source(first_source)

    def _print_Relational(self, expr):
        """Print a relational as ``<lhs> <op> <rhs>``, parenthesizing operands as needed."""
        level = precedence(expr)
        lhs = self.parenthesize(expr.lhs, level)
        rhs = self.parenthesize(expr.rhs, level)
        return f'{lhs} {expr.rel_op} {rhs}'
[docs]class DimConstraints:
"""
Custom solver for a system of constraints on symbolic dimensions.
Solutions are "static" values or simplified "dynamic" constraints.
"""
def __init__(
    self,
    symbol_to_source,
    var_to_val,
    marked_dynamic,
    source_name_to_debug_name,
    _allow_complex_guards_as_runtime_asserts=False,
):
    """Set up empty solver state.

    Args:
        symbol_to_source: Mapping from symbol to its sources; handed to the
            solution printer.
        var_to_val: Mapping from symbol to its concrete (hint) value.
        marked_dynamic: Symbols that are marked dynamic.
        source_name_to_debug_name: Handed to the solution printer.
        _allow_complex_guards_as_runtime_asserts: Stored as-is; see the
            comment on the attribute below.
    """
    # We try to solve systems of inequalities with 1 free variable.
    self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
    # Among them, we prioritize solving for a free variable that has equalities.
    # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
    # and removing a symbol from the former => removing it from the latter.
    self._symbols_with_equalities: Set[sympy.Symbol] = set()
    # A solution of a free variable with equalities becomes a substitution.
    # We use these substitutions to simplify other constraints.
    # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
    self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}

    # In general, constraints may have // and % operations.
    # Of course, // can be expressed in terms of / and %.
    # Our inequality solver can handle / but not %. So we need to transform them away.
    # We do so by using the values of variables as hints to evaluate %.
    # For soundness we record additional congruence guards and solve them separately.
    self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val
    # Congruences are grouped per free symbol (this is a defaultdict of sets,
    # not a flat set).
    self._congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)

    # We do not try to (directly) solve inequalities with > 1 free variables.
    # NOTE: free variables in these inequalities cannot also be in _substitutions.
    self._multivariate_inequalities: Set[sympy.Expr] = set()

    # We park external equalities between free variables here.
    self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []

    # Solutions come in two forms:
    # - (static) specializations
    # - (dynamic) inequalities / congruences
    self._static_results: Set[str] = set()
    self._dynamic_results: Set[str] = set()

    # printer for solutions
    self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name)

    # inconsistencies found on substituting with concrete values / static solutions
    self._inconsistencies: List[str] = []

    # symbols that are marked dynamic
    self._marked_dynamic = marked_dynamic

    # for constraints we can't express with the dynamic shapes language, defer as runtime asserts in export
    self._allow_complex_guards_as_runtime_asserts = _allow_complex_guards_as_runtime_asserts
def rewrite_with_congruences(self, s, expr):
    """
    Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
    This leaves rational operators (in particular of the form b / d) that our inequality solver can handle.
    We solve the added congruences separately (using our congruence solver, see below).

    Args:
        s: The free symbol under which new congruences are recorded.
        expr: The expression to rewrite.

    Returns:
        The rewritten expression. New congruences are added to
        ``self._congruences[s]`` as a side effect.
    """

    def reduce_and_record(base, divisor):
        # Shared core of both handlers. With the value of s as a "hint",
        # b % d evaluates to a constant k; we record the guard b % d == k
        # (stored as (b - k) % d) unless it is trivially zero.
        #
        # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is
        # complete IFF the original expression always evaluates to a constant value (i.e., it
        # does not vary with s). In other words,
        # - solutions of s with the rewritten expression are guaranteed to also be solutions
        #   of s with the original expression;
        # - while it may be possible to find solutions of s with the original expression that
        #   are not solutions with the rewritten expression, in that case the original
        #   expression cannot evaluate to the same value for all solutions of s.
        #
        # Should we be worried about this incompleteness? No, because of the following reasons:
        # 1. It unblocks dramatic simplification that would not be otherwise possible with
        #    current tech (i.e., "don't let perfect be the enemy of the good").
        # 2. We already have a tradition of using hints to add guards in the compiler for
        #    making progress.
        # 3. We have not yet seen a counterexample arise in practice! In particular, any
        #    congruence guards we generate (or simplify to) seem to be of the form
        #    b % d == k where k is a constant.
        #
        # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by
        # all s >= 2. With any hint (say) s = k, we'd rewrite this to:
        # 3*s % (s + 1) == k - 2. But, substituting, we would then get k - 2 == s - 2, and
        # thus s = k as the (only, constant) solution!
        base = self.rewrite_with_congruences(s, base)
        divisor = self.rewrite_with_congruences(s, divisor)
        mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val)
        congruence = (base - mod_reduced) % divisor
        if congruence != 0:
            self._congruences[s].add(congruence)
        return base, divisor, mod_reduced

    def mod_handler(*args):
        # b % d  ->  k, while adding the guard b % d == k.
        base, divisor = args
        _, _, mod_reduced = reduce_and_record(base, divisor)
        return mod_reduced

    def floor_div_handler(*args):
        # b // d  ->  (b - k) / d, while adding the guard b % d == k.
        # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
        # and eliminating b % d as above.
        base, divisor = args
        base, divisor, mod_reduced = reduce_and_record(base, divisor)
        # NB: Must not be CleanDiv, it needs to be regular sympy division
        # so inequality solver works. This is sort of problematic for
        # is_integer tests though haha
        return (base - mod_reduced) / divisor

    # Order matters: both kinds of modulus are removed before floor division.
    # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
    # arguments should be OK (for PythonMod).
    rewrites = (
        (Mod, mod_handler),
        (PythonMod, mod_handler),
        (FloorDiv, floor_div_handler),
    )
    for op, handler in rewrites:
        if expr.has(op):
            expr = expr.replace(op, handler)
    return expr
def add(self, expr) -> bool:
    """Add an expression to the set of constraints.

    Return whether the expression is a trivial constraint (i.e., an obvious tautology).
    """
    if expr == sympy.true:
        return True

    orig_expr = expr
    # TODO(avik): https://github.com/pytorch/pytorch/issues/101093
    # It is possible that `expr` will fail the consistency check because of
    # precision errors. Specifically, on substituting its free symbols with
    # their concrete values, we might end up comparing floats. Until we have
    # a fix for this issue, we delay raising such failures. See solve().
    if orig_expr.xreplace(self._var_to_val) == sympy.false:
        self._inconsistencies.append(f"{orig_expr} is inconsistent!")

    if isinstance(expr, sympy.Ne):
        # we're not going to do anything useful with these, so drop them
        return False

    free_symbols = expr.free_symbols
    assert free_symbols, f"Did not expect constraint with no free variables: {expr}"

    if len(free_symbols) > 1:
        # multivariate: record and move on
        self._multivariate_inequalities.add(expr)
        return False

    # univariate: can solve these immediately
    s = next(iter(free_symbols))
    # eliminate // and % (see documentation of `rewrite_with_congruences` above);
    # count congruences before and after to detect whether the rewrite added any
    congruences_before = len(self._congruences[s])
    expr = self.rewrite_with_congruences(s, expr)
    congruences_after = len(self._congruences[s])
    if expr == sympy.true:
        # trivial only if the rewrite did not introduce a new congruence
        return congruences_before == congruences_after

    if expr.xreplace(self._var_to_val) == sympy.false:
        self._inconsistencies.append(
            f"{expr}, obtained by rewriting {orig_expr} with congruences, "
            "is inconsistent!"
        )
    if isinstance(expr, sympy.Eq):
        # special status for symbols that have equalities (see `solve` below)
        self._symbols_with_equalities.add(s)
    self._univariate_inequalities[s].add(expr)
    return False
def add_equality(self, source, expr):
    """Add an equality constraint between ``source`` and ``expr``.

    A numeric ``expr`` is recorded immediately as a static result; anything
    else is parked in ``_symbolic_equivalences``.
    """
    if not expr.is_number:
        # these will resolve to either specializations or dynamic equality constraints
        self._symbolic_equivalences.append((source, expr))
        return
    # specialization, right here
    self._static_results.add(f"{source.name()} == {expr}")
def _reduce_congruences(self):
    """Simplify the recorded congruences, symbol by symbol.

    Returns:
        A dict mapping each symbol in ``self._congruences`` to a set of
        congruence expressions: at most one combined ``(s - r) % m``
        congruence plus every congruence that could not be reduced (or is
        not implied by the combined one).
    """
    reduced_congruences = {}
    for s, congruences in self._congruences.items():
        remainder_modulus_pairs = []
        congruences_to_check = set()
        for congruence in congruences:
            base, divisor = congruence.args
            # We are given a congruence of the form base % divisor == 0 with a free variable s. So:
            # - we transform this into an equation of the form base = divisor * tmp;
            # - we solve this equation for s to get a linear solution with free variable tmp.
            tmp = sympy.Symbol("reduce_congruences_tmp", integer=True)
            symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
            # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
            # for how to interpret the results.
            reduced = False
            if s == symbol:
                # This means the solution is of the form s = modulus*tmp + remainder.
                modulus, remainder = sympy.polys.polytools.div(solution, tmp)
                if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer):
                    # Normalize the remainder by reducing it modulo the modulus.
                    remainder_modulus_pairs.append((remainder % modulus, modulus))
                    reduced = True
            if not reduced:
                # This means that we did not get a unique solution to the equation.
                # No problem, we will check it.
                congruences_to_check.add(congruence)

        if not remainder_modulus_pairs:
            reduced_congruences[s] = congruences_to_check
            continue

        # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i).
        # The solution will be a congruence of the form s = r mod m.
        # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT.
        remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs)
        combined = {(s - remainder) % modulus}
        reduced_congruences[s] = combined
        # Keep only those unreduced congruences the combined solution does not already satisfy.
        substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder}
        combined.update(
            congruence for congruence in congruences_to_check
            if not sympy.checksol(congruence, substitution)
        )

    return reduced_congruences
def _raise_inconsistencies(self):
if self._inconsistencies:
msg = "\n".join(self._inconsistencies)
self._inconsistencies.clear()
raise ValueError(f"The following inconsistencies were found:\n{msg}")
def _force_specialization(self, s):
val = self._var_to_val[s]
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
self._substitutions[s] = val
def _specialize_divisor_symbols(self):
    """Specialize every symbol that appears in a divisor of a multivariate inequality.

    For each ``FloorDiv`` / ``Mod`` atom in the multivariate inequalities, all free
    symbols of its second argument are force-specialized. The inequalities are then
    re-added with the accumulated substitutions applied, inconsistencies are raised,
    and the univariate inequalities and congruences of substituted symbols are dropped.
    """
    for inequality in self._multivariate_inequalities:
        for div_or_mod in inequality.atoms(FloorDiv, Mod):
            _dividend, divisor = div_or_mod.args
            for sym in divisor.free_symbols:
                self._force_specialization(sym)

    # Swap in a fresh set first: ``self.add`` is handed every simplified expression
    # and the old set is only iterated.
    pending, self._multivariate_inequalities = self._multivariate_inequalities, set()
    for inequality in pending:
        self.add(inequality.xreplace(self._substitutions))
    self._raise_inconsistencies()

    # Keep only the entries whose symbol has not been substituted.
    remaining_inequalities = {}
    for sym, exprs in self._univariate_inequalities.items():
        if sym not in self._substitutions:
            remaining_inequalities[sym] = exprs
    self._univariate_inequalities = remaining_inequalities

    remaining_congruences = {}
    for sym, congruences in self._congruences.items():
        if sym not in self._substitutions:
            remaining_congruences[sym] = congruences
    self._congruences = remaining_congruences
def solve(
    self,
    _disable_forced_specializations=False,
):
    """Solve the system of constraint equations to find simplified constraints

    Results are accumulated on the instance: specializations go into
    ``self._static_results`` and ``self._substitutions``, remaining constraints
    go into ``self._dynamic_results``. Nothing is returned.

    Args:
        _disable_forced_specializations: when true, divisor symbols are not
            specialized, unsupported congruences and unsupported symbolic
            equivalences do not force specializations, and equality solutions
            are not recorded for symbols in ``self._marked_dynamic``.

    Raises:
        ValueError: via ``self._raise_inconsistencies()`` when inconsistencies
            have been recorded.
    """
    self._raise_inconsistencies()
    # as long as there are symbols with equalities, solve for them
    # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
    while self._symbols_with_equalities:
        s = self._symbols_with_equalities.pop()
        exprs = self._univariate_inequalities.pop(s)
        solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
        if isinstance(solution, sympy.And):
            # pick the first equality among the conjuncts, if there is one
            solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution)
        assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
        symbol, val = solution.args
        assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
        # really don't force specializations here
        if not (_disable_forced_specializations and s in self._marked_dynamic):
            # because this is univariate, the solution is a specialization
            self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
            # add this as a substitution to simplify other constraints
            self._substitutions[s] = val

            # simplify multivariate inequalities: some of them will now become univariate!
            multivariate_inequalities = self._multivariate_inequalities
            self._multivariate_inequalities = set()
            for expr in multivariate_inequalities:
                self.add(expr.xreplace({s: self._substitutions[s]}))
            self._raise_inconsistencies()

    if not _disable_forced_specializations:
        self._specialize_divisor_symbols()

    # solve linear congruences
    # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
    reduced_congruences = self._reduce_congruences()
    for s, congruences in reduced_congruences.items():
        for congruence in congruences:
            # any congruence that cannot be checked becomes a dynamic constraint as well
            if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
                if self._is_supported_congruence(congruence):
                    base, divisor = congruence.args
                    # the fresh symbol is named after the debug name of s, prefixed with "_"
                    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
                    tmp = sympy.Symbol(tmp_name, integer=True)
                    from torch._dynamo.source import ConstantSource
                    self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
                    # express s in terms of tmp by solving base == divisor * tmp for s
                    r = try_solve(sympy.Eq(base, divisor * tmp), s)
                    self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
                elif not _disable_forced_specializations:
                    self._force_specialization(s)
                    self._univariate_inequalities.pop(s, None)

    # remaining symbols have only pure inequalities (no equalities)
    for s, exprs in self._univariate_inequalities.items():
        try:
            solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
            # because this is univariate, the solution is a dynamic (range) constraint
            if isinstance(solution, sympy.Or):
                # keep the first disjunct that holds under the recorded values
                solution = next(iter(arg for arg in solution.args if arg.xreplace(self._var_to_val)))
            if isinstance(solution, sympy.And):
                for arg in solution.args:
                    self._dynamic_results.add(self._dcp.doprint(arg))
            else:
                self._dynamic_results.add(self._dcp.doprint(solution))
        except (NotImplementedError, AssertionError) as e:
            # fall back to reporting the unreduced inequalities
            log.warning("Failed to reduce inequalities: %s", e)
            for expr in exprs:
                self._dynamic_results.add(self._dcp.doprint(expr))

    # simplify symbolic equivalences: some of them will now become specializations!
    symbolic_equivalences = self._symbolic_equivalences
    self._symbolic_equivalences = []
    for source, expr in symbolic_equivalences:
        if not _disable_forced_specializations and not _is_supported_equivalence(expr):
            for s in expr.free_symbols:
                self._force_specialization(s)
                sexpr = self._dcp._print_Symbol(s)
                # drop dynamic results whose text mentions the symbol just specialized
                self._dynamic_results = {r for r in self._dynamic_results if sexpr not in r}
        self.add_equality(source, expr.xreplace(self._substitutions))

    # remaining symbolic equivalences become dynamic equality constraints
    for source, expr in self._symbolic_equivalences:
        self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")
@classmethod
def _is_supported_congruence(cls, congruence):
    """Return whether ``congruence`` can be expressed with supported Dim ops.

    Args:
        congruence: an expression with exactly two args, ``(base, divisor)``.

    Returns:
        True when ``base`` is a bare symbol, or a two-term sum of a symbol and an
        integer, and ``divisor`` is an integer; False otherwise.
    """
    base, divisor = congruence.args
    # Congruences that can be currently expressed with supported Dim ops are
    # of the form (x + a) % b == 0, where x is a Dim and a and b are constants.
    # This allows us to derive x as b*y - a for some Dim y.
    # (See also documentation of dynamic_shapes._DerivedDim.)
    if isinstance(base, sympy.Add):
        # A sympy sum may carry more than two terms; such a base is not of the
        # supported shape, so reject it instead of failing on tuple unpacking.
        if len(base.args) != 2:
            return False
        lhs, rhs = base.args
        cond = (
            (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or
            (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol))
        )
    else:
        cond = isinstance(base, sympy.Symbol)
    cond = cond and isinstance(divisor, sympy.Integer)
    return cond
def forced_specializations(self):
    """Returns a dictionary of the names of symbols to their specialized value

    Only symbols that are both substituted and marked dynamic are reported. The
    key is the name of the symbol's first source; when debug names are available
    it is rendered as ``"<debug name> = <source name>"``.
    """
    specialized = {}
    for sym, value in self._substitutions.items():
        if sym not in self._marked_dynamic:
            continue
        source_name = self._dcp.symbol_to_source[sym][0].name()
        debug_names = self._dcp.source_name_to_debug_name
        if debug_names:
            key = f"{debug_names[source_name]} = {source_name}"
        else:
            key = source_name
        specialized[key] = value
    return specialized
def remove_redundant_dynamic_results(self):
    """Remove constraints of the form 2 <= dynamic_dim(...) as 2 is the default
    lower bound.

    A constraint ``2 <= dynamic_dim(...)`` is reduced to ``dynamic_dim(...)``, and
    the reduced form is kept only if it does not already appear as part of another
    constraint. ``self._dynamic_results`` is replaced with the filtered set.
    """
    candidates_for_removal = []
    dynamic_results = set()
    for dc in self._dynamic_results:
        # Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
        # There is no change in behavior since 2 is the default lower bound.
        # The lookbehind requires the 2 to be a standalone literal, so a bound
        # such as "12 <= dynamic_dim(...)" is left untouched.
        dc_ = re.sub(r"(?<![\w.-])2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
        if dc != dc_:
            candidates_for_removal.append(dc_)
        else:
            dynamic_results.add(dc_)
    for dc in candidates_for_removal:
        # remove dynamic_dim(t, 0) as a constraint when dynamic_dim(t, 0) also
        # appears as part of another constraint
        if not any(dc in other_dc for other_dc in dynamic_results):
            dynamic_results.add(dc)
    self._dynamic_results = dynamic_results
def _is_derived_dim(self, dim):
    """Return whether ``dim`` is an instance of ``torch.export.dynamic_shapes._DerivedDim``."""
    return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
def _is_dim(self, dim):
    """Return whether ``dim`` is a ``_Dim`` that is not a ``_DerivedDim``."""
    return (
        isinstance(dim, torch.export.dynamic_shapes._Dim)
        and not isinstance(dim, torch.export.dynamic_shapes._DerivedDim)
    )
def _process_derived_dim_roots(
    self,
    results: Dict[str, Dict[str, Any]],
    name_to_dim: Dict[str, Any],
) -> None:
    '''
    Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots,
    and 2) root swapping.

    1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests
    dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final
    suggested fixes handle this correctly, but we can get intermediate results that look like
    {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1", "min": 3, "max": 15}}
    and this routine prettifies this by unifying to a single root, and making each suggestion
    either a derived dim or min/max range, not both.

    2) With suggested fixes for derived dims, roots can be swapped,
    e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name,
    since this leads to messages like "dx - 1 = Dim("dx - 1", ...)".
    Instead we evaluate the new root value, and remove results for its derivations.

    First we find all the original roots (specified in dynamic_shapes), that are found in the
    values of results (i.e. used for computing suggesting fix values). These original roots
    (suppose `dx`) are either specialized, unchanged, refined, or swapped
    (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value
    in results, and remove suggestions for derivations of `dx`, assuming the derived relation
    is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value,
    and then do the same with `dx`'s derivations.

    Assuming the originally specified derived relations are correct is valid, because:
    1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1))
    produce_guards() will catch this and crash before hand.
    2) if the relations are numerically correct but do not match the emitted guard,
    for example:

    def forward(self, x, y):
        return x.reshape([-1]) + y # guard: s0 * 2 = s1
    inputs = (torch.randn(6, 2), torch.randn(12))
    dx = Dim("dx", min=2, max=32)
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op

    then this leads to 2 linear equations, and a) produce_guards() is able to solve for
    the unique solution of dx = 6 and specialize, and b) the export constraint solver will
    raise an issue due to range constraints (a unique solution means not all values in a
    range satisfy a guard) and also force specializations.

    Args:
        results: maps a name to a dict that may hold "eq", "min" and "max" entries.
            Mutated in place: entries are added, rewritten and deleted.
        name_to_dim: maps a name to its dim object. Mutated in place: a Dim is
            added for every newly introduced root.
    '''
    from torch.export.dynamic_shapes import Dim

    def _check_same_range(c, dim):
        # returns True if c & dim are both min/max ranges with same values
        return (
            self._is_dim(dim)
            and ("min" in c or "max" in c)
            and (
                (dim.min < 2 and c.get("min", 2) == 2)
                or dim.min == c.get("min", 2)
            )  # let pass if analysis min = 2 and specified min = 0/1
            and dim.max == c.get("max", sys.maxsize - 1)
        )

    # 1) newly introduced roots
    # this part we handle adding newly introduced roots
    # these arise from guards like "x.shape[0] % 3 == 0"
    # leading to suggested fixes like "dx = 3*_dx"
    # extract _dx, and find appropriate min/max values
    #
    # before, we have something like:
    # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2}
    # we want instead:
    # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3}
    introduced_roots: Dict[str, str] = {}  # map new root -> old root
    # iterate over a snapshot because results gains entries inside the loop
    for k, c in list(results.items()):
        if "eq" in c and isinstance(c["eq"], sympy.Expr):  # derived dim
            root = next(iter(c["eq"].free_symbols))
            if str(root) not in name_to_dim:
                introduced_roots[str(root)] = k
                # calculate necessary min & max
                modulus, remainder = sympy.polys.polytools.div(c["eq"], root)
                c_min = c.get("min", 2)
                min_ = math.ceil((c_min - remainder) / modulus)
                c_max = c.get("max", sys.maxsize - 1)
                max_ = math.floor((c_max - remainder) / modulus)
                # create result & dim
                results[str(root)] = {"min": min_, "max": max_}
                name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_)
                # remove old root min/max bounds
                c.pop("min", None)
                c.pop("max", None)

    # alter derivations that depend on old root, to unify to new root
    # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2
    for old_root in introduced_roots.values():
        for k, c in list(results.items()):
            if (
                "eq" in c
                and isinstance(c["eq"], sympy.Expr)
                and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
            ):  # derived dim with root = old_root
                new_root_expr = results[str(old_root)]["eq"]  # dx=3*_dx+1
                new_expr = c["eq"].subs({symbol: new_root_expr})  # dy=(3*_dx+1)+1
                c["eq"] = new_expr

    # 2) root swapping
    # collect all the original roots that are used for calculating values of suggested fixes
    # this consists of:
    # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim
    # 2) {"dy": "dx + 1"} -> dx: root for suggested fix
    modified_roots: Set[str] = set()
    for k, c in results.items():
        if k not in name_to_dim:  # _dynamo.export() may handle source directly
            continue
        if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c):  # case 1)
            modified_roots.add(k)
        elif "eq" in c and isinstance(c["eq"], sympy.Expr):  # case 2)
            root = next(iter(c["eq"].free_symbols))
            assert root is not None
            modified_roots.add(str(root))

    # exclude newly introduced roots, we've already processed these
    modified_roots = modified_roots.difference(introduced_roots)

    # evaluate the new value for each root
    # this is now either 1) unchanged, 2) refined with a new range,
    # or 3) specialized to a concrete value
    modified_root_values: Dict[str, Dict[str, Any]] = {}
    for root in modified_roots:
        swapped_root = True
        if root in results:
            c = results[root]
            if (
                ("min" in c or "max" in c)  # range
                or isinstance(c["eq"], int)  # specialized
            ):
                # here, the original root is a root Dim or concrete value in results.
                # if it is a derived dim, it is swapped, and we handle that below.
                if not _check_same_range(c, name_to_dim[root]):  # ignore if unchanged
                    modified_root_values[root] = c
                swapped_root = False

        if swapped_root:
            # if the original root has been swapped in results, that means the new root
            # is a range (if it had specialized, the original root would have too).
            # find this new root, and solve for the original root's range.
            for k, c in results.items():
                if k not in name_to_dim:
                    continue
                dim = name_to_dim[k]
                if dim.__class__.__name__ == "_DerivedDim" and dim.root.__name__ == root:
                    # only look for min/max root, otherwise root would have specialized
                    if "min" in c or "max" in c:
                        expr = sympy.sympify(k)
                        s = next(iter(expr.free_symbols))
                        # NOTE(review): both c["min"] and c["max"] are indexed although the
                        # guard above only requires one of them - confirm both are always
                        # present at this point.
                        result = {
                            "min": try_solve(sympy.Eq(expr, c["min"]), s)[1],  # type: ignore[arg-type]
                            "max": try_solve(sympy.Eq(expr, c["max"]), s)[1],  # type: ignore[arg-type]
                        }
                        if not _check_same_range(result, name_to_dim[root]):  # ignore if unchanged
                            modified_root_values[root] = result
                            break

    # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4})
    # we only want to suggest fixes for the root, to avoid derived names.
    # also, remove anything in modified_roots, since we either add new modified values after this,
    # or have decided they are unchanged.
    for k in list(results.keys()):
        if k not in name_to_dim:
            continue
        if self._is_derived_dim(name_to_dim[k]) or k in modified_roots:
            del results[k]

    # update results with modified root values
    # now results has the following properties:
    # - only contains original roots as keys
    # - each root is now either specialized, refined, or derived from another original root
    results.update(modified_root_values)
[docs] def prettify_results(
self,
original_signature: inspect.Signature,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
constraint_violation_error=None,
forced_specializations=None,
):
"""Format a message for constraint violation erros"""
from torch.export.dynamic_shapes import _get_dim_name_mapping
if self._dcp.source_name_to_debug_name:
def transform(s, inverse=False):
for k, v in self._dcp.source_name_to_debug_name.items():
s = s.replace(k, v) if not inverse else s.replace(v, k)
return s
results = defaultdict(dict)
if dynamic_shapes is None:
dynamic_shapes = {}
def flip(op):
if op == "<=":
return ">="
if op == ">=":
return "<="
if op == "<":
return ">"
if op == ">":
return "<"
assert op == "=="
return op
def relation_with_digit(expr, op, digit):
if op == "<=":
results[expr]["max"] = digit
elif op == "<":
results[expr]["max"] = digit - 1
elif op == ">=":
results[expr]["min"] = digit
elif op == ">":
results[expr]["min"] = digit + 1
else:
assert op == "=="
results[expr]["eq"] = digit
# retrieve dynamic shapes
name_to_dim = _get_dim_name_mapping(dynamic_shapes)
for s in self._static_results.union(self._dynamic_results):
t = transform(s)
if t == s:
continue
left, op, right = re.split(r"( == | <= | >= | < | > )", t)
op = op.strip()
if op == "==" and left == right:
continue
if right.isdigit():
relation_with_digit(left, op, int(right))
elif left.isdigit():
relation_with_digit(right, flip(op), int(left))
else:
assert op == "==", t
results[left]["eq"] = sympy.sympify(right)
# order forced specializations based on name
forced_specializations = {
k: forced_specializations[k]
for k in sorted(
forced_specializations.keys(),
key=lambda x: x.split(" = ")[1],
)
}
buf = ""
if forced_specializations:
debug_names = set()
for k in forced_specializations:
dim = name_to_dim[k.split(" = ")[0]]
if self._is_derived_dim(dim):
debug_names.add(dim.root.__name__)
else:
debug_names.add(dim.__name__)
buf += (
f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! "
'For more information, run with TORCH_LOGS="+dynamic".\n'
)
for s, val in forced_specializations.items():
buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n"
self._process_derived_dim_roots(results, name_to_dim)
dims = []
others = []
# order results by source name
results = {
k: results[k] for k in sorted(
results.keys(),
key=lambda x: transform(x, inverse=True),
)
}
for k, c in results.items():
if "eq" in c:
other = c["eq"]
if isinstance(other, int):
others.append(f"{k} = {other}")
elif _is_supported_equivalence(other):
others.append(f"{k} = {other}")
else:
min_ = c.get("min", None)
if min_ == 2:
min_ = None
max_ = c.get("max", None)
if min_ is not None and max_ is not None:
dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
elif min_ is not None:
dims.append(f"{k} = Dim('{k}', min={min_})")
elif max_ is not None:
dims.append(f"{k} = Dim('{k}', max={max_})")
else:
dims.append(f"{k} = Dim('{k}')")
# results will get filtered out if no new suggestions,
# this can happen if guards are too complex.
# in that case don't suggest fix
if dims or others:
buf += "\nSuggested fixes:\n "
buf += "\n ".join(dims + others)
return buf
# Note: Model inputs are wrapped as LocalSource in dynamo.
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def extract_and_rewrite_local(dc):
match = re.search(r"L\['(.+?)'\]", dc)
if match is None:
return
arg = match.expand(r'\1')
dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
return arg, dc
def group(results, args_index):
groups = defaultdict(list)
for dc in results:
local = extract_and_rewrite_local(dc)
if local is None:
# This can happen, e.g., with `assume_constant_result`.
# In that case, we drop the constraint.
# TODO(avik) Maybe we should generate an assertion here?
continue
arg, dc = local
if arg in args_index:
groups[args_index[arg]].append(dc)
else:
# This can happen, e.g., with decorators that change the signature.
# In that case, we drop the constraint. Seems hard to do better. :/
# TODO(avik) Maybe warn that `arg` in not in `signature`?
continue
sorted_groups = []
for idx, dcs in sorted(groups.items()):
_, arg = idx
sorted_groups.append((arg, sorted(dcs)))
return sorted_groups
signature = original_signature.replace(return_annotation=inspect.Signature.empty)
args_index = {}
for i, arg in enumerate(signature.parameters.keys()):
args_index[arg] = (i, arg)
def print_results(grouped, indent, result_fn):
nonlocal buf
space = False
for arg, results in grouped:
if space:
buf += "\n"
else:
space = True
buf += f"\n{indent}# {arg}:"
for result in results:
buf += f"\n{indent}{result_fn(result)}"
buf = ""
if forced_specializations:
buf += (
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify.\n"
)
for s, val in forced_specializations.items():
buf += f" - {s}, which was marked dynamic, must be specialized to {val}.\n"
indent = 4 * " "
if self._static_results:
grouped_static_results = group(self._static_results, args_index)
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += f"\n```\ndef specializations{str(signature)}:"
print_results(
grouped_static_results,
indent,
lambda result: f"assert {result}",
)
buf += "\n```\n"
if self._dynamic_results:
grouped_dynamic_results = group(self._dynamic_results, args_index)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nPlease use the following code to specify the constraints they must satisfy:"
buf += f"\n```\ndef specify_constraints{str(signature)}:"
buf += f"\n{indent}return ["
print_results(
grouped_dynamic_results,
indent * 2,
lambda result: f"{result},",
)
buf += f"\n{indent}]\n```\n"
return buf
# Module-level thread-local storage object, created once at import time.
TLS = threading.local()
@dataclass(frozen=True)
class ShapeEnvSettings:
    """
    Encapsulates all shape env settings that could potentially affect
    FakeTensor dispatch. Used when creating dispatch cache keys.

    The dataclass is frozen: fields cannot be reassigned after construction.
    """

    # Each field below is declared as a plain bool with no default, so all of
    # them must be supplied when constructing an instance.
    allow_scalar_outputs: bool
    allow_dynamic_output_shape_ops: bool
    assume_static_by_default: bool
    specialize_zero_one: bool
    duck_shape: bool
    prefer_deferred_runtime_asserts_over_guards: bool
    _allow_complex_guards_as_runtime_asserts: bool
[docs]class ShapeEnv:
# This is a wrapper over the actual __init__ function.
#
# Where to add a new constructor parameter to ShapeEnv?
# =====================================================
# This __init__ function should be used only for parameters related to event recording.
# These are parameters that we don't wish to pass down the road to new ShapeEnv instances
# created from replaying events.
#
# If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
# recording, do so in the _init function.
def __init__(
    self, *,
    should_record_events: Optional[bool] = None,
    tracked_fakes: Optional[List[Any]] = None,
    **kwargs
) -> None:
    """Set up the event-recording state, delegating all other setup to ``_init``.

    Args:
        should_record_events: whether to record events; when None, recording is
            enabled iff translation validation is on and its bisection is not
            disabled.
        tracked_fakes: list of tracked fakes to keep on the instance.
        **kwargs: forwarded unchanged to ``_init``.
    """
    self._init(**kwargs)

    # Disable event recording when replaying.
    kwargs["should_record_events"] = False

    from torch.fx.experimental.validator import translation_validation_enabled
    self._translation_validation_enabled = translation_validation_enabled()

    # If not specified, enable event recording if both:
    #   - Translation validation is on
    #   - Translation validation bisection is not disabled
    if should_record_events is None:
        should_record_events = (
            self._translation_validation_enabled
            and not config.translation_validation_no_bisect
        )
    self.should_record_events = should_record_events

    # Enable event recording check if both:
    #   - It should record events
    #   - The recording check is enabled
    self.check_recorded_events = (
        self.should_record_events and config.check_shape_env_recorded_events
    )

    # This will make sure we only record the top-level function call.
    self.is_recording = not self.should_record_events
    # Keep track of the list of tracked fakes.
    self.tracked_fakes = tracked_fakes
    # List of events for reconstructing ShapeEnv at arbitrary points in time.
    # When recording, it starts with the event that constructs the ShapeEnv.
    self.events: List[ShapeEnvEvent] = []
    if self.should_record_events:
        self.events.append(ShapeEnvEvent(ShapeEnv, kwargs=kwargs))
# Pro-tip: if you add a new field to ShapeEnv, this affects some accept
# tests. Accept their output with:
#
# EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal
#
def _init(
    self, *,
    allow_scalar_outputs=True,
    allow_dynamic_output_shape_ops=True,
    # NB: These are legacy configuration that help us make good choices
    # when the constraint/dynamic dims are not explicitly passed to us.
    # Ideally we will fix all call sites to be explicit and not have
    # implicit choices, but this apparently was pretty involved.
    assume_static_by_default=False,
    # Note - On 0/1 specialization
    #
    # The following options affect decisions we make about eager
    # specialization. Disabling them will increase trace time (as we do
    # more symbolic reasoning) and can also harm the quality of generated
    # code (because inductor may not be able to specialize for bounds
    # being equal--although if we later respecialize because of a guard,
    # your code may be just as good as it was before.)
    #
    # When True, eagerly specialize input sizes which have 0/1.
    specialize_zero_one=True,
    # When True, assume input sizes which have the same size are
    # symbolically equal.
    duck_shape=True,
    # For debugging
    co_fields=None,
    # When True, whenever safe, we will generate a deferred runtime assert
    # instead of a guard whenever we know that an expression must be True,
    # otherwise it would be an error, even for backed SymInts (where we
    # could ostensibly unconditionally generate guards). This is useful
    # for export, where preventing "error checking" sizes from showing up
    # in guards is helpful, since these guards in some sense are overly
    # pedantic. See also https://github.com/pytorch/pytorch/issues/121749
    prefer_deferred_runtime_asserts_over_guards=False,
    # When True, does not emit or raise constraint violation errors on
    # implicit guards generated by ops, and defers to runtime assertions
    # in the graph instead. For export.
    _allow_complex_guards_as_runtime_asserts=False,
    # XXX Add any new settings that could affect FakeTensor evaluation
    # to: torch._subclasses.fake_tensor._ShapeEnvSettings
):
    """Initialize all state that is unrelated to event recording.

    All parameters are keyword-only. Every parameter except ``co_fields`` is
    stored in ``self.settings`` (a ``ShapeEnvSettings``); ``co_fields`` is stored
    on the instance directly, as an empty dict when falsy. The remaining
    attributes are initialized to empty containers, counters and flags.
    """
    self.settings = ShapeEnvSettings(
        # Not directly used by ShapeEnv; indirectly used by FakeTensor
        allow_scalar_outputs=allow_scalar_outputs,
        allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops,
        # End
        assume_static_by_default=assume_static_by_default,
        specialize_zero_one=specialize_zero_one,
        duck_shape=duck_shape,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
        _allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts,
    )

    self.guards: List[ShapeGuard] = []
    # Maps symbolic ints to their original concrete values
    # Currently populated from tensors
    self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
    # Like var_to_val, but only set when propagate_real_tensors is on.
    # Used as last resort to avoid GuardOnDataDependent error
    self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
    # Maps symbolic ints to their min/max range. These ranges
    # are conservative: the int MUST fall in the range, but the
    # range may contain ints which may not actually appear in
    # practice
    self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
    self.source_name_to_debug_name: Dict[str, str] = {}
    self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
    self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
    # Maps from sympy ints to expressions representing them
    # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
    self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
    self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {}
    # Set holds a % b expressions that evaluate to 0.
    self.divisible: Set[sympy.Expr] = set()
    # Set that holds "size-like" symbols. When we perform
    # "size-oblivious" tests, these can be assumed to be >= 2.
    self.size_like: Set[sympy.Symbol] = set()
    # Duck-shaping says that if two input tensors have the same size,
    # they get assigned the same symbolic variable
    self.val_to_var: Dict[int, sympy.Expr] = {}
    if specialize_zero_one:
        # pre-seed 0 and 1 with constant integers
        self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}
    self.unbacked_symfloat_counter = itertools.count()
    self.unbacked_symint_counter = itertools.count()
    # Similar to guards, but these MUST evaluate to true and can
    # only be evaluated at runtime midway through (i.e., they always
    # involve unbacked symints)
    #
    # For efficiency reasons, we index in the following way. Suppose you have
    # a runtime assert i0 + i1 <= s1. We pick the most recently allocated
    # symbol in the source expression and add the assert to the list for
    # that symbol e.g., {i1: [i0 + i1 <= s1]}.
    #
    # We access the runtime asserts in two situations:
    #
    # - When we are guarding on an expression, we will attempt to
    # statically evaluate it, in case the unbacked SymInts can
    # simplify away. If we have a runtime assert, we may be able
    # to discharge the guard entirely. We only need to attempt
    # runtime asserts that mention freevars of the expression in
    # question.
    #
    # - When we are performing codegen (in Inductor for eager, or
    # when finalizing the export FX graph), we need to know what
    # extra runtime asserts to insert. Whenever an unbacked
    # SymInt comes into scope, all runtime asserts involving it
    # become eligible for insertion (so long as all of their other
    # free unbacked symbols are also in scope). We technically
    # can handle any choice of key by kicking inexpressible asserts
    # to the next unbacked symbol to wait on, but if we choose the
    # latest key, an assert will only show up at the moment when
    # we can actually codegen it.
    self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {}
    # This exists so we can efficiently invalidate the cache (it's used as
    # part of the cache key); otherwise we'd have to iterate through
    # deferred_runtime_asserts to compute its length
    self.num_deferred_runtime_asserts = 0
    self.log = log
    self.log.debug("create_env")
    self.frozen = False
    self.runtime_asserts_frozen = False
    self.dim_constraints: Optional[DimConstraints] = None
    self.counter = collections.Counter()
    # Mapping from sympy.Symbol to the number of guards which mention this
    # symbol
    self.symbol_guard_counter = collections.Counter()
    # A selection of important fields on co_field; solely used for
    # signpost_event
    self.co_fields = co_fields if co_fields else {}

    # Whenever we allocate a fresh unbacked Symbol, we add it to this
    # pending list. Unbacked symbol allocation can occur at unpredictable
    # points during meta tensor propagation, but at some point, the we
    # have to know what the binding site for an unbacked symbol is, and
    # this is computed when we actually place the node in the graph. The
    # important thing is that we always actually handle every unaccounted
    # for unbacked symbol, so this list helps us keep track of them and
    # then make sure they are all accounted for.
    #
    # We could potentially give rise to errors earlier by lexically
    # scoping when we do propagation, and only allowing unbacked symbols
    # to be allocated at this point in time. However this is inconvenient
    # to do in Dynamo, because fake tensor propagation is far from when we
    # analyze binding sites (set_example_value), so we do it in a more
    # mutatey way.
    #
    # NB: fresh unbacked symbols NEVER get substitutions applied to them,
    # they are binding sites!
    self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = []

    # Version counter used to invalidate cached values
    self._prev_cache_key = self._get_key()
    self._version_counter = 0

    # Cache for FX nodes.
    # Maps an already built node a tuple of:
    # 1. node's target
    # 2. list of arguments
    # This drastically reduces the size of the FX graph, avoiding
    # duplicated nodes.
    self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
    self.source_to_symbol: Dict[str, sympy.Symbol] = {}

    # Suppose you want to replace an unbacked symbol with another
    # unbacked symbol. This is error prone because you can cause
    # references to unbacked symbols to time travel backwards. E.g.,
    #
    # u1 = x.item()
    # ... use of u1 ...
    # u2 = y.item()
    # u3 = z.item()
    # torch._check(u1 == u2 + u3)
    #
    # If you replace u1 with u2 + u3, then the use of u1 now
    # references u2 and u3 prior to them actually being bound at
    # runtime.
    #
    # To control for this, we track the order unbacked symbols
    # were allocated, and only allow substitutions if they respect
    # the dependency from this order; an unbacked symbol can only
    # be substituted with unbacked symbols that come before it in the
    # order.
    #
    # This also imposes an ordering on the unbacked symbol binding
    # sites themselves: you are not allowed to reorder unbacked symbol
    # bindings. At the moment, this is not tracked, but we potentially
    # could track this at the IR level using a higher order operator
    # with something like effect token tracking.
    self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {}

    from torch.fx.experimental.validator import translation_validation_enabled
    self._translation_validation_enabled = translation_validation_enabled()

    if self._translation_validation_enabled:
        from torch.fx.experimental.validator import TranslationValidator

        self.validator = TranslationValidator()
        self.graph = torch.fx.Graph()
        # Create an output graph and start inserting before that.
        # This is needed when 'deepcopy'-ing this object.
        self.graph.inserting_before(self.graph.output(None))

        # Mapping of each node name to the node itself.
        #
        # This is useful for matching an FX node from a recorded ShapeEnv.graph
        # to the FX node of the ShapeEnv we are running the event on.
        #
        # Whenever you add a node to self.graph, you must add a mapping to this
        # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
        # not be valid.
        self.name_to_node: Dict[str, torch.fx.Node] = {}
@property
def allow_scalar_outputs(self):
    """Read-only view of ``self.settings.allow_scalar_outputs``."""
    return self.settings.allow_scalar_outputs
@property
def allow_dynamic_output_shape_ops(self):
    """Read-only view of ``self.settings.allow_dynamic_output_shape_ops``."""
    return self.settings.allow_dynamic_output_shape_ops
@property
def assume_static_by_default(self):
    """Expose the ``assume_static_by_default`` option stored on ``self.settings``."""
    settings = self.settings
    return settings.assume_static_by_default
@property
def specialize_zero_one(self):
    """Expose the ``specialize_zero_one`` option stored on ``self.settings``."""
    settings = self.settings
    return settings.specialize_zero_one
@property
def duck_shape(self):
    """Expose the ``duck_shape`` option stored on ``self.settings``."""
    settings = self.settings
    return settings.duck_shape
@property
def prefer_deferred_runtime_asserts_over_guards(self):
    """Expose the ``prefer_deferred_runtime_asserts_over_guards`` option stored on ``self.settings``."""
    settings = self.settings
    return settings.prefer_deferred_runtime_asserts_over_guards
@property
def _allow_complex_guards_as_runtime_asserts(self):
return self.settings._allow_complex_guards_as_runtime_asserts
def check_equal(self, other: "ShapeEnv") -> None:
    """Compare the state of ``other`` against this ShapeEnv.

    The comparison itself is delegated to ``shape_env_check_state_equal``;
    this method only decides which fields are skipped and how the
    remaining ones are normalized before they are compared.
    """
    # Fields that do not influence the outcome of a
    # ShapeEnv.produce_guards call: debugging aids, translation
    # validation state and event-recording bookkeeping.
    skipped_fields = (
        "counter",
        "log",
        "var_to_stack",
        "fx_node_cache",
        "graph",
        "validator",
        "check_recorded_events",
        "should_record_events",
        "is_recording",
        "tracked_fakes",
        "events",
        "source_name_to_debug_name",
        "_prev_cache_key",
        "_version_counter",
        "dim_constraints",
    )

    def normalize(key: str, value: Any) -> Any:
        """Reduce a field's value to the part that should be compared.

        Some fields mix real state with debugging information (e.g. a
        ShapeGuard carries both the guard expression and the stack where
        it was added); only the state part is kept.
        """
        if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"):
            from copy import copy

            # Compare the next integer the itertools.count() would yield.
            # Note that the iterator is copied first so that the original
            # object is not advanced.
            return next(copy(value))
        if key == "guards":
            # List of ShapeGuard -> list of their expressions.
            return [guard.expr for guard in value]
        if key == "deferred_runtime_asserts":
            # Lists of RuntimeAsserts -> lists of their expressions.
            return {
                sym: [runtime_assert.expr for runtime_assert in asserts]
                for sym, asserts in value.items()
            }
        if key == "name_to_node":
            # Only the set of node names has to agree.
            return set(value.keys())
        if key in ["symbol_guard_counter", "pending_fresh_unbacked_symbols"]:
            # Excluded from the comparison.
            return None
        return value

    shape_env_check_state_equal(self, other, skipped_fields, normalize)
def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
if self.tracked_fakes is None:
return None
from torch._dynamo.variables.builder import TrackedFake
def maybe_transform_fake(fake: TrackedFake):
inner_fake = fake.fake \
if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) \
else FakeTensorMeta.from_fake(fake.fake)
# Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a
# FakeTensorMeta for two reasons:
# 1. this is all the information we need when recording ShapeEnvEvents.
# 2. it works even if each TrackedFake changes its metadata.
return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type]
return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
def _last_event_index(self) -> int:
return len(self.events) - 1
@contextmanager
def _recording(self):
self.is_recording = True
try:
yield
finally:
self.is_recording = False
@record_shapeenv_event()
def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr):
    """Install ``new_s`` as the replacement of ``orig_s``, tagged "eliminate_unbacked"."""
    reason = "eliminate_unbacked"
    self._set_replacement(orig_s, new_s, reason)
@record_shapeenv_event()
def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None:
    """Register a value for an unbacked symbol.

    Used only when propagate_real_tensors; the registered value can be
    used as a last resort to resolve hints.
    """
    sympified = sympy.sympify(v)
    self.unbacked_var_to_val[k] = sympified
# Unlike set_replacement, this records a shapeenv event
@record_shapeenv_event()
def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol):
# Rename the unbacked symbol orig_s to new_s: installs the replacement
# orig_s -> new_s and records the pair in self.unbacked_renamings.
# Both arguments must be sympy Symbols for which free_unbacked_symbols()
# returns something truthy.
assert isinstance(orig_s, sympy.Symbol), orig_s
assert isinstance(new_s, sympy.Symbol), new_s
assert free_unbacked_symbols(new_s), new_s
assert free_unbacked_symbols(orig_s), orig_s
# Nothing is recorded while fresh unbacked symbols are being ignored.
if self._ignore_fresh_unbacked_symbols_tls():
return
# Read any replacement orig_s already has *before* it is overwritten below.
dest = self.replacements.get(orig_s)
# An existing replacement must not itself mention unbacked symbols.
assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}"
self._set_replacement(orig_s, new_s, "rename_unbacked_to")
self.unbacked_renamings[orig_s] = new_s
# Carry the previous replacement of orig_s over to new_s.
if dest is not None:
self._set_replacement(new_s, dest, "rename_unbacked_to_dest")
@record_shapeenv_event()
def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None):
if min is None:
min = 0
if max is None:
max = sys.maxsize - 1
if max < min:
raise ValueError(
"Maximum value to constrain_as_size can't be less than the specified min value, "
"received min={min} and max={max}"
)
self.constrain_symbol_range(
a,
compiler_min=min,
compiler_max=max,
)
self.size_like.add(a)
@record_shapeenv_event()
def _constrain_range(self, a: sympy.Expr, min: int, max: int):
    """Constrain ``a`` to the range ``[min, max]``.

    A concrete ``sympy.Integer`` is only validated (``ValueRangeError`` is
    raised when it falls outside the range); a ``sympy.Symbol`` has its
    range constrained via ``constrain_symbol_range``. Other expressions
    are not supported.
    """
    if not isinstance(a, sympy.Integer):
        assert isinstance(a, sympy.Symbol), "constraining non-Symbols NYI"
        # TODO: Shouldn't we install a guard if the symbol is backed? Or is the
        # semantics that this is an "unchecked" assert (but is this actually
        # something useful? Might be better to restrict only for unbacked
        # SymInt).
        self.constrain_symbol_range(
            a,
            compiler_min=min,
            compiler_max=max,
        )
        return

    # Constant: nothing to record, just check it lies within the bounds.
    in_range = min <= int(a) <= max
    if not in_range:
        raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]")
@record_shapeenv_event()
def _constrain_unify(self, a, b):
    """
    Given two SymInts, constrain them so that they must be equal. NB:
    this will not work with SymInts that represent nontrivial expressions
    (yet!)
    """
    # TODO: this does not install a deferred runtime assert yet
    # TODO: Maybe dedupe this with _maybe_guard_rel?
    # Update Feb 2024: this is extra important to do, this doesn't handle
    # unbacked replacements properly nor does it generate deferred runtime
    # asserts
    a_is_symint = isinstance(a, SymInt)
    b_is_symint = isinstance(b, SymInt)

    # Two plain values: they simply have to agree already.
    if not a_is_symint and not b_is_symint:
        assert a == b
        return

    # Plain ``a``, symbolic ``b``: pin b's symbol to the constant.
    if not a_is_symint:
        assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
        assert b.node.shape_env is self
        self.replacements[b.node.expr] = sympy.Integer(a)
        return

    # TODO: Actually, we can support this as long as one of them is a symbol.
    # NB: We can't actually do "unification" as our operators are not
    # injective
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
    assert a.node.shape_env is self

    # Symbolic ``a``, plain ``b``: pin a's symbol to the constant.
    if not b_is_symint:
        self.replacements[a.node.expr] = sympy.Integer(b)
        return

    # Both symbolic: redirect b's symbol to whatever a's symbol resolves to.
    assert a.node.shape_env is b.node.shape_env
    assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
    resolved_a = self._find(a.node.expr)
    self.replacements[b.node.expr] = resolved_a
def _ignore_fresh_unbacked_symbols_tls(self):
    """Return ``TLS.ignore_fresh_unbacked_symbols``, or False when it has not been set."""
    try:
        return TLS.ignore_fresh_unbacked_symbols
    except AttributeError:
        return False
@record_shapeenv_event()
def _ignore_fresh_unbacked_symbols_enter(self):
    """Set the ``ignore_fresh_unbacked_symbols`` flag on ``TLS`` to True."""
    setattr(TLS, "ignore_fresh_unbacked_symbols", True)
@record_shapeenv_event()
def _ignore_fresh_unbacked_symbols_exit(self):
    """Set the ``ignore_fresh_unbacked_symbols`` flag on ``TLS`` back to False."""
    setattr(TLS, "ignore_fresh_unbacked_symbols", False)
@contextmanager
def ignore_fresh_unbacked_symbols(self):
    """
    Context manager indicating that the unbacked SymInts newly allocated
    inside the ``with`` body are being discarded.

    Calls ``_ignore_fresh_unbacked_symbols_enter`` on entry and
    ``_ignore_fresh_unbacked_symbols_exit`` on exit, even if the body raises.
    """
    self._ignore_fresh_unbacked_symbols_enter()
    try:
        yield
    finally:
        # Runs on both normal and exceptional exit.
        self._ignore_fresh_unbacked_symbols_exit()
@record_shapeenv_event()
def freeze(self):
    """Mark this ShapeEnv as frozen so that it stops accumulating guards.

    A frozen ShapeEnv ignores any further guards generated on it and only
    emits a warning, which may lead to accuracy problems. This method
    itself only sets ``self.frozen``.
    """
    self.frozen = True
@record_shapeenv_event()
def freeze_runtime_asserts(self):
    """Mark this ShapeEnv so that no more deferred runtime asserts are added.

    Installing a new runtime assert while frozen is an error: it would
    indicate a lowering violation, or something that is statically known
    to be True being re-checked in a way that is not clearly dischargeable.
    This method itself only sets ``self.runtime_asserts_frozen``.
    """
    self.runtime_asserts_frozen = True
def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
if not self._translation_validation_enabled:
return None
srcname = source.name()
if source not in self.source_to_symbol:
self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
return self.source_to_symbol[srcname]
def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
    """Forward ``(symbol, type)`` to ``self.validator.add_var``; no-op unless translation validation is enabled."""
    if not self._translation_validation_enabled:
        return
    self.validator.add_var(symbol, type)
def _add_target_expr(self, expr) -> None:
if self._translation_validation_enabled:
self.validator.add_target_expr(expr)
def _add_assertion(self, expr) -> None:
if self._translation_validation_enabled:
self.validator.add_assertion(expr)
def _check_translation_validate(self) -> None:
if self._translation_validation_enabled:
self.validator.validate()
@record_shapeenv_event()
def _create_fx_call_function(
    self,
    op: Callable,
    args: Tuple,
) -> Tuple[Optional[torch.fx.Node], bool]:
    """Return the FX ``call_function`` node for ``op(*args)``, creating it if needed.

    Returns a ``(node, fresh)`` pair. ``fresh`` is True only when a new
    node was added to ``self.graph`` by this call. ``node`` is None when
    nothing is cached for this call (e.g. translation validation is
    disabled) or when any argument is None.
    """
    # Nodes are cached under (op, args) to avoid duplicates.
    node_key = (op, args)

    # Nothing to create: either validation is off or the node already exists.
    if not self._translation_validation_enabled or node_key in self.fx_node_cache:
        return self.fx_node_cache.get(node_key, None), False

    # Presence of None in the arguments implies that we should ignore this operation.
    if any(a is None for a in args):
        # We check if we are not mixing SymNode that should not be ignored
        # (fx_node is not None) with those that should (fx_node is None).
        assert all(not isinstance(a, torch.fx.Node) for a in args)
        return None, False

    # If translation validation is enabled, all arguments must have its
    # own FX node.
    assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}"
    node = self.graph.call_function(op, args)
    self.fx_node_cache[node_key] = node
    self.name_to_node[node.name] = node
    return node, True
def _create_fx_placeholder_and_z3var(
    self,
    symbol: sympy.Symbol,
    type: Type,
) -> Optional[torch.fx.Node]:
    """Return the FX placeholder for ``symbol``, creating it (and its Z3 variable) on first use.

    Returns None when translation validation is disabled.
    """
    if not self._translation_validation_enabled:
        return None

    node_key = (self.graph.placeholder, (symbol,))

    # The symbol was added before: reuse its placeholder, since creating
    # it a second time generates invalid Python code.
    if node_key in self.fx_node_cache:
        return self.fx_node_cache[node_key]

    # Add a Z3 variable according to 'type'.
    self._add_z3var(symbol, type)
    # Create the FX placeholder out of a mangled name: parentheses are
    # dropped, every other non-alphanumeric character becomes '_'.
    without_parens = re.sub(r'[()]', '', symbol.name)
    mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', without_parens)
    node = self.graph.placeholder(mangled_name)
    self.fx_node_cache[node_key] = node
    self.name_to_node[node.name] = node
    # Attach the 'symbol' to the placeholder so that we can retrieve
    # the Z3 variable later.
    node.meta["symbol"] = symbol
    return node
def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
if self._translation_validation_enabled and node is not None:
self.name_to_node.pop(node.name)
self.graph.erase_node(node)
def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
    """Tag ``node.meta`` with the last event index and the current node.

    Only done when ``self.should_record_events`` is set.
    """
    from torch._dynamo.utils import get_current_node

    if not self.should_record_events:
        return
    meta = node.meta
    meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
    meta[CURRENT_NODE_KEY] = get_current_node()
def _suppress_guards_tls(self):
    """Return ``TLS.suppress_guards``, or False when it has not been set."""
    try:
        return TLS.suppress_guards
    except AttributeError:
        return False
@record_shapeenv_event()
def _suppress_guards_enter(self):
    """Set the ``suppress_guards`` flag on ``TLS`` to True."""
    setattr(TLS, "suppress_guards", True)
@record_shapeenv_event()
def _suppress_guards_exit(self):
    """Set the ``suppress_guards`` flag on ``TLS`` back to False."""
    setattr(TLS, "suppress_guards", False)
@contextmanager
def suppress_guards(self):
    """Context manager to ignore all guards generated inside.

    Calls ``_suppress_guards_enter`` on entry and ``_suppress_guards_exit``
    on exit, even if the body raises.
    """
    self._suppress_guards_enter()
    try:
        yield
    finally:
        # Runs on both normal and exceptional exit.
        self._suppress_guards_exit()
def _get_key(self):
"""
Defines the current "state" of the guards we've accumulated in this ShapeEnv.
Determines when we need to invalidate our cache
"""
return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts, len(self.unbacked_var_to_val))
def _update_version_counter(self):
# The shape environment is queried orders of magnitude more often than
# it is changed, so we summarise the cache key into a linearly
# increasing version counter which is cheaper to check in _lru_cache
# Only update version counter if the state actually changed
cur_key = self._get_key()
if self._prev_cache_key != cur_key:
self._prev_cache_key = cur_key
self._version_counter += 1
def _produce_dyn_sizes(
    self,
    ex_size: Sequence[int],
    source: Source,
    symbolic_context: SymbolicContext,
) -> List[sympy.Expr]:
    """Convert ``ex_size`` to a tuple and delegate to ``_produce_dyn_sizes_from_int_tuple``."""
    size_tuple = tuple(ex_size)
    return self._produce_dyn_sizes_from_int_tuple(size_tuple, source, symbolic_context)
def _produce_dyn_sizes_from_int_tuple(
    self,
    tensor_size: Tuple[int],
    source: Source,
    symbolic_context: SymbolicContext,
) -> List[sympy.Expr]:
    """Create one size expression per entry of ``tensor_size``.

    Each dimension ``i`` is passed to ``create_symbol`` together with a
    ``TensorPropertySource`` for that size and the matching entries of
    ``symbolic_context.dynamic_sizes`` / ``constraint_sizes``.
    """
    assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}"
    from torch._dynamo.source import TensorPropertySource, TensorProperty

    _assert_symbol_context(symbolic_context)
    dynamic_dims = symbolic_context.dynamic_sizes
    constraint_dims = symbolic_context.constraint_sizes
    return [
        self.create_symbol(
            dim_value,
            TensorPropertySource(source, TensorProperty.SIZE, dim),
            dynamic_dims[dim],
            constraint_dims[dim],
            symbolic_context=symbolic_context,
        )
        for dim, dim_value in enumerate(tensor_size)
    ]
def create_symbolic_sizes_strides_storage_offset(
    self,
    ex: torch.Tensor,
    source: Source,
    *,
    symbolic_context: Optional[SymbolicContext] = None,
):
    """
    Returns symbolic sizes, strides and storage offset for the given tensor.
    We try our best to express stride in terms of the sizes, so as to not
    introduce new symbolic variables.
    """
    # Every size/stride/offset first goes through
    # _maybe_specialize_sym_int_with_hint.
    specialize = self._maybe_specialize_sym_int_with_hint
    ex_size = tuple(specialize(sz) for sz in ex.size())
    ex_stride = tuple(specialize(sd) for sd in ex.stride())
    ex_storage_offset = specialize(ex.storage_offset())
    dim_is_dynamic = [_is_dim_dynamic(ex, i) for i in range(ex.dim())]
    return self._create_symbolic_sizes_strides_storage_offset(
        ex_size,
        ex_stride,
        ex_storage_offset,
        dim_is_dynamic,
        source,
        symbolic_context=symbolic_context,
    )
# Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
# We create symbols in shape_env using the backed hints behind SymInt.
# Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
# produce_guards will trigger specializations on the outer stuff
# Case 2: when the SymInt is unbacked, we will throw a data-dependent error in require_hint().
#
# It's probably good for now but it's important to note that this approach has implications for
# the original shape_env when checking guards in different order.
# Example:
# ---------
# Consider a function "opt_f" as shown below:
# @torch.compile()
# def opt_f(x: bool, y: Tensor):
# if x == True:
# return y + torch.randn([4])
# else:
# return y
# Depending on the sequence of calls, we might install two different sets of guards:
# 1. opt_f(False, y):
# - "x == False" (always works for any size y)
# 2. opt_f(True, y):
# - Triggers recompilation and results in guards like:
# - "x == True and y.size(0) == 4"
# - (or "y.size(0) == 4 and x == True")
# The order of checking the guards matters. In this specific example:
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
# we may have an unnecessary shape specialization for y.
def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int:
    """Return ``maybe_sym`` itself if it is not symbolic, otherwise its required hint.

    A symbolic input must belong to a ShapeEnv other than ``self``.
    """
    assert isinstance(maybe_sym, (int, torch.SymInt))
    if not is_symbolic(maybe_sym):
        return maybe_sym
    assert maybe_sym.node.shape_env is not self, \
        "expect the symbol is created from an shape env other than current one."
    return maybe_sym.node.require_hint()
@record_shapeenv_event()
def _create_symbolic_sizes_strides_storage_offset(
self,
ex_size: Sequence[int],
ex_stride: Sequence[int],
ex_storage_offset: int,
is_dim_dynamic: Sequence[bool],
source: Source,
*,
symbolic_context: Optional[SymbolicContext] = None,
):
# Build symbolic sizes, strides and storage offset from concrete example
# values. Returns (tuple of sizes, tuple of strides, storage offset), each
# element being the result of create_symintnode.
# When symbolic_context is None a StatelessSymbolicContext is built here.
dim = len(ex_size)
# Reimplement the legacy behavior
if symbolic_context is None:
constraint_dims = [None] * dim
dynamic_dims = []
for i in range(dim):
# NB: This is encapsulation breaking! Legacy behavior was
# bad.
if is_dim_dynamic[i]:
r = DimDynamic.DYNAMIC
elif self.assume_static_by_default:
r = DimDynamic.STATIC
else:
r = DimDynamic.DUCK
dynamic_dims.append(r)
# NOTE(review): the next line replaces the per-dimension list computed by
# the loop above with all-DUCK, so is_dim_dynamic and
# assume_static_by_default have no effect here - confirm this is intended.
dynamic_dims = [DimDynamic.DUCK] * dim
# symbolic_context is None - set one
symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims)
# We got a StatelessSymbolicContext
_assert_symbol_context(symbolic_context)
constraint_dims = symbolic_context.constraint_sizes
dynamic_dims = symbolic_context.dynamic_sizes
# TODO: make this configurable from outside symbolic_context; we made a symbolic_context
# decision here where if all sizes are static, we are going to
# specialize all of the inner strides/offset too. We don't have to
# do this, and arguably we should ALWAYS allow for dynamic offset,
# this is cheap.
# TODO: This should be DYNAMIC, using DUCK for BC
dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK
assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
assert len(constraint_dims) == dim
from torch._dynamo.source import TensorPropertySource, TensorProperty
size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context)
# stride[i] stays None until an expression has been bound for it.
stride: List[Optional[sympy.Expr]] = [None] * len(size)
# Example strides of 0 or 1 are bound to those constants directly.
for i, val in enumerate(ex_stride):
if val in (0, 1):
stride[i] = sympy.Integer(val)
# Repeat until every stride is bound. Each pass first tries to express
# unbound strides as size*stride of an already-bound dimension, then, if
# some are still unbound, allocates one fresh symbol.
while any(x is None for x in stride):
# Maps the example value ex_size[i] * ex_stride[i] to the matching symbolic
# product, for dimensions whose stride is bound and whose example stride is
# non-negative.
candidates = {
ex_size[i] * ex_stride[i]: size[i] * stride[i]
for i in range(len(size))
if stride[i] is not None and ex_stride[i] >= 0
}
# iterate over unbound strides in sorted order
def _nested_int_aware_sort(tup):
return (
# Order nested ints by their coefficients.
# 1 here to order nested ints after non-nested-ints.
(1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
else (0, *tup)
)
val_list = sorted(
[(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
key=_nested_int_aware_sort,
)
for _, i in val_list:
if stride[i] is None and ex_stride[i] in candidates:
stride[i] = candidates[ex_stride[i]]
# The newly bound dimension becomes a candidate for the remaining ones.
candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i]
if any(x is None for x in stride):
# bind the smallest unbound stride to a new variable
val, i = min(
[
(ex_stride[i], i)
for i in range(len(stride))
if stride[i] is None
], key=_nested_int_aware_sort
)
stride[i] = self.create_symbol(
val,
TensorPropertySource(source, TensorProperty.STRIDE, i),
dynamic_dim=dynamic_strides_offset,
constraint_dim=None,
symbolic_context=symbolic_context,
)
assert all(x is not None for x in stride)
# Wrap each size expression, using the example size as its hint.
sym_sizes = [
self.create_symintnode(
sym,
hint=hint,
source=TensorPropertySource(source, TensorProperty.SIZE, i),
)
for i, (sym, hint) in enumerate(zip(size, ex_size))
]
sym_stride = []
for i, stride_expr in enumerate(stride):
# NB: Don't duck size the stride; instead use the expression
# we computed
assert stride_expr is not None
sym_stride.append(self.create_symintnode(
stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i)))
# The storage offset gets its own create_symbol call, with the same
# dynamism as the strides.
sym_storage_offset = self.create_symintnode(
self.create_symbol(
ex_storage_offset,
TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
dynamic_dim=dynamic_strides_offset,
constraint_dim=None,
symbolic_context=symbolic_context
),
hint=ex_storage_offset,
source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET))
return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
[docs] @record_shapeenv_event()
def create_symintnode(
self,
sym: "sympy.Expr",
*,
hint: Optional[int],
source: Optional[Source] = None,
):
"""Create a SymInt value from a symbolic expression
If you know what the current hint value of the SymInt to be created
is, pass it into hint. Otherwise, pass None and we will make our best
guess
"""
source_name = source.name() if source else None
if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)
assert symbol is not None
# Create a new FX placeholder and Z3 variable for 'symbol'.
fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
# Add an equality assertion for the newly created symbol and 'sym'.
self._add_assertion(sympy.Eq(symbol, sym))
else:
fx_node = None
if isinstance(sym, sympy.Integer):
if hint is not None:
assert int(sym) == hint
out = int(sym)
else:
out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
return out
[docs] @record_shapeenv_event()
def create_symfloatnode(
self,
sym: "sympy.Expr",
*,
hint: Optional[int],
source: Optional[Source] = None,
):
"""Create a SymFloat value from a symbolic expression"""
source_name = source.name() if source else None
if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)
assert symbol is not None
# Create a new FX placeholder and Z3 variable for 'symbol'.
fx_node = self._create_fx_placeholder_and_z3var(symbol, float)
# Add an equality assertion for the newly created symbol and 'sym'.
self._add_assertion(sympy.Eq(symbol, sym))
else:
fx_node = None
if isinstance(sym, sympy.Float):
if hint is not None:
assert float(sym) == hint
out = float(sym)
else:
out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node))
return out
@record_shapeenv_event()
def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
    """Create a SymInt wrapping a new unspecified symbol

    The symbol comes from ``create_unspecified_symbol`` and is wrapped by
    ``create_symintnode`` with ``value`` as its hint.
    """
    unspecified_symbol = self.create_unspecified_symbol(
        value,
        source=source,
        dynamic_dim=dynamic_dim,
    )
    return self.create_symintnode(
        unspecified_symbol,
        hint=value,
        source=source,
    )
def create_symboolnode(self, sym: "sympy.Expr"):
    """Create a SymBool object from a sympy boolean expression"""
    # Only serialization uses this function, so it is not tracked for
    # validation.
    node = SymNode(sym, self, bool, None)
    return SymBool(node)
def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges):
    """Log the creation of an unbacked symbol together with its range and stack summary.

    Extended debug output is enabled when the symbol's name is listed in
    the comma-separated ``config.extended_debug_create_symbol``.
    """
    debug_symbols = config.extended_debug_create_symbol
    is_debug = debug_symbols is not None and str(symbol) in debug_symbols.split(',')
    fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
    log.info(
        "%s %s [%s, %s]%s (%s)%s",
        prefix,
        symbol,
        vr.lower,
        vr.upper,
        maybe_user_loc,
        format_frame(fsummary),
        maybe_extra_debug,
        stack_info=is_debug,
    )
@record_shapeenv_event()
def create_unbacked_symfloat(self):
    """Create a symbolic float without a hint value

    Allocates the next UNBACKED_FLOAT symbol, registers its creation stack
    and an unknown value range, and returns it wrapped in a SymFloat.
    """
    fresh: sympy.Symbol = make_symbol(SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter))
    self.counter["create_unbacked_symbol"] += 1
    if not self._ignore_fresh_unbacked_symbols_tls():
        self.pending_fresh_unbacked_symbols.append(fresh)
    self.var_to_stack[fresh] = CapturedTraceback.extract(skip=1)
    value_range = ValueRanges.unknown()
    self.var_to_range[fresh] = value_range
    assert value_range.is_float

    # Create a new FX placeholder and Z3 variable for the symbol.
    fx_node = self._create_fx_placeholder_and_z3var(fresh, float)

    self._log_create_unbacked_symbol("create_unbacked_symfloat", fresh, value_range)

    return SymFloat(SymNode(fresh, self, float, None, fx_node=fx_node))
@record_shapeenv_event()
def create_unbacked_symint(self):
    """Create a symbolic integer without a hint value

    Allocates the next UNBACKED_INT symbol, registers its creation stack
    and the default unspecified value range, and returns it wrapped in a
    SymInt.
    """
    fresh: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
    if not self._ignore_fresh_unbacked_symbols_tls():
        self.pending_fresh_unbacked_symbols.append(fresh)
    self.counter["create_unbacked_symbol"] += 1
    self.var_to_stack[fresh] = CapturedTraceback.extract(skip=1)
    value_range = self._default_unspecified_value_range()
    self.var_to_range[fresh] = value_range
    assert value_range.is_int

    # Create a new FX placeholder and Z3 variable for the symbol.
    fx_node = self._create_fx_placeholder_and_z3var(fresh, int)

    self._log_create_unbacked_symbol("create_unbacked_symint", fresh, value_range)

    return SymInt(SymNode(fresh, self, int, None, fx_node=fx_node))
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
    """Check if a sympy symbol matches the naming convention for unbacked symbols
    """
    is_unbacked_int = symbol_is_type(symbol, SymT.UNBACKED_INT)
    return is_unbacked_int
@record_shapeenv_event()
def create_unbacked_symbool(self):
    """Create a symbolic boolean without a hint value

    The boolean is backed by a fresh UNBACKED_INT symbol restricted to the
    range [0, 1]; the returned SymBool wraps ``Eq(symbol, 1)``.
    """
    fresh: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True)
    if not self._ignore_fresh_unbacked_symbols_tls():
        self.pending_fresh_unbacked_symbols.append(fresh)
    self.counter["create_unbacked_symbol"] += 1
    self.var_to_stack[fresh] = CapturedTraceback.extract(skip=1)
    value_range = ValueRanges(0, 1)
    self.var_to_range[fresh] = value_range
    assert value_range.is_int

    # Create a new FX placeholder and Z3 variable for the symbol.
    fx_node = self._create_fx_placeholder_and_z3var(fresh, bool)

    self._log_create_unbacked_symbol("create_unbacked_symbool", fresh, value_range)

    return SymBool(SymNode(sympy.Eq(fresh, 1), self, bool, None, fx_node=fx_node))
@record_shapeenv_event()
def create_unspecified_symbol(
    self,
    val: Union[int, SymInt, float, SymFloat],
    source: Source,
    dynamic_dim: DimDynamic = DimDynamic.DUCK,
    constraint_dim: DimConstraint = None,  # NB: includes None
) -> "sympy.Expr":
    """Create a symbol with an unspecified value

    Compared to standard symbols we do not assume the value is positive,
    nor do we specialize on zero or one values.
    """
    # 'positive' is None for unspecified symbols, since the sign of the
    # value cannot be assumed either way.
    # Zero/one specialization is turned off so that a new symbol is always
    # returned, whatever val is.
    unspecified = self.create_symbol(
        val,
        source,
        dynamic_dim,
        constraint_dim,
        positive=None,
        do_not_specialize_zero_one=True,
        symbolic_context=None,
    )
    return unspecified
[docs] @record_shapeenv_event()
def create_symbol(
self,
val: int,
source: Source,
dynamic_dim: DimDynamic = DimDynamic.DUCK,
constraint_dim: DimConstraint = None, # NB: includes None
positive: Optional[bool] = True,
do_not_specialize_zero_one: bool = False,
symbolic_context=None,
) -> "sympy.Expr":
"""Create a new symbol which is tracked by this ShapeEnv
"""
# Returns a sympy.Integer for STATIC dimensions, otherwise an entry of
# self.val_to_var (0/1 specialization or duck shaping) or a freshly
# allocated symbol. The result is also cached per (ShapeEnv id, source name)
# when symbolic_context is a StatefulSymbolicContext.
# check if constraint_dim is actually static integer
if isinstance(constraint_dim, StrictMinMaxConstraint) and constraint_dim.vr.lower == constraint_dim.vr.upper:
dynamic_dim = DimDynamic.STATIC
if constraint_dim.vr.lower != val:
raise ConstraintViolationError(
f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, "
f"for {source.name()}"
)
# Record the static decision back into the context, indexed by source.idx.
if symbolic_context:
symbolic_context.dynamic_sizes[source.idx] = dynamic_dim
symbolic_context.constraint_sizes[source.idx] = None
constraint_dim = None
# see note [Tensor Fakification and Symbol Caching]
source_name = source.name()
# Make sure this ShapeEnv has a cache dict in the stateful context.
if (isinstance(symbolic_context, StatefulSymbolicContext)
and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache):
symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {}
# Cache hit: return what was created earlier for this source name.
if (isinstance(symbolic_context, StatefulSymbolicContext)
and source_name
and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])):
return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name]
if do_not_specialize_zero_one:
specialize_zero_one = False
else:
specialize_zero_one = self.specialize_zero_one
assert isinstance(source, Source), f"{type(source)} {source}"
assert not (positive and val < 0), f"positive set for negative value: {val}"
# It's always sound to allocate a symbol as DYNAMIC. If the user
# constrained the symbol, force the symbolic_context to DYNAMIC, because our
# constraint code will do weird stuff if, e.g., it's duck shaped
if constraint_dim is not None:
dynamic_dim = DimDynamic.DYNAMIC
if dynamic_dim is DimDynamic.STATIC:
# Static dimensions are plain integer constants; no symbol is allocated.
out = sympy.Integer(val)
if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out
return out
elif dynamic_dim is DimDynamic.DUCK:
# duck_shape can be used to globally turn off duck shaping, even
# if it was requested
duck = self.duck_shape
elif dynamic_dim is DimDynamic.DYNAMIC:
duck = False
else:
raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")
if val in (0, 1) and specialize_zero_one:
r = self.val_to_var[val]
elif not duck or val not in self.val_to_var:
# If we're not duck shaping, we always create a new symbol
# Even if we're duck shaping, if we haven't seen this particular
# value before, we also create a new symbol
# NOTE(review): this test is `type(val) is int` while the branches below
# use isinstance(val, int); any val that is not exactly an int (including
# the nested-int case handled in the final else below) gets a SymT.FLOAT
# symbol with real=True - confirm this is intended.
if type(val) is int:
sympy_expr = make_symbol(SymT.SIZE, len(self.var_to_val), positive=positive, integer=True)
else:
sympy_expr = make_symbol(SymT.FLOAT, len(self.var_to_val), positive=positive, real=True)
# We always associate vars to vals
if isinstance(val, int):
self.var_to_val[sympy_expr] = sympy.Integer(val)
elif isinstance(val, float):
self.var_to_val[sympy_expr] = sympy.Float(val)
else:
# Only used for jagged layout nested tensors
self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())
# Do the appending later, because we always want to populate this
self.var_to_sources[sympy_expr] = []
# Create a Z3 variable for the new symbol.
# NOTE(review): `int` is passed here even when the symbol was created as
# SymT.FLOAT above - confirm.
self._add_z3var(sympy_expr, int)
if duck:
# Make sure to reuse this symbol for subsequent duck shaping
self.val_to_var[val] = sympy_expr
if isinstance(val, int):
if positive:
# Add assertions for the newly created symbols
self._add_assertion(sympy_expr > 1)
# Apply default range, which assumes not zero-one
self.var_to_range[sympy_expr] = self._default_value_range()
else:
self.var_to_range[sympy_expr] = self._default_unspecified_value_range()
# Small performance optimization: if we have a min-max constraint,
# we can proactively narrow to that range
if isinstance(constraint_dim, StrictMinMaxConstraint):
assert not duck
self.var_to_range[sympy_expr] &= constraint_dim.vr
vr = self.var_to_range[sympy_expr]
assert vr.is_int
# The example value itself must satisfy the resulting range.
if val not in vr:
raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")
range_str = f"[{vr.lower}, {vr.upper}]"
elif isinstance(val, float):
self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo)
range_str = f"[{vr.lower}, {vr.upper}]"
assert vr.is_float
else:
# Skip var_range logic for SingletonInt
# Only used for jagged layout nested tensors
range_str = ""
r = sympy_expr
# Extended debug output is enabled when this symbol's name is listed in the
# comma-separated config.extended_debug_create_symbol.
is_debug = (
config.extended_debug_create_symbol is not None and
str(sympy_expr) in config.extended_debug_create_symbol.split(',')
)
maybe_more_info = ""
if not is_debug:
maybe_more_info = (
", for more info run with "
f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}"'
)
fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
self.log.info(
"create_symbol %s = %s for %s %s%s (%s)%s%s",
sympy_expr, val, source.name(), range_str,
maybe_user_loc, format_frame(fsummary), maybe_more_info, maybe_extra_debug, stack_info=is_debug
)
self.counter["create_symbol"] += 1
else:
# This implements duck-shaping: input sizes that match are assigned
# the same symint
r = self.val_to_var[val]
self.log.debug("create_symbol %s duck sized %s", r, source.name())
# Bookkeeping below only applies when the result is an actual Symbol.
if isinstance(r, sympy.Symbol):
r_sources = self.var_to_sources[r]
r_sources.append(source)
if not source.is_ephemeral() and r_sources[0].is_ephemeral():
# prefer non-ephemeral source first since it may be guarded on later
r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]
# This ensures we get zeros in symbol_guard_counts, which makes
# some queries simpler (since we will accumulate mass on 0 this
# way)
self.symbol_guard_counter[r] = 0
# Remember the result for later calls with the same source name.
if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r
return r
def add_var_to_val(self, expr: sympy.Symbol, val: int):
    """Adds a new symbol to the symbolic environment.

    ``expr`` must not already be present in ``self.var_to_val``; ``val`` is
    stored as a ``sympy.Integer``.
    """
    log.debug("add_var_to_val %s %s", expr, val, stack_info=True)
    already_known = expr in self.var_to_val
    assert not already_known, f"{expr} already exists"
    self.var_to_val[expr] = sympy.Integer(val)
def _debug_name(self, source):
src_name = source.name()
return self.source_name_to_debug_name.get(src_name, src_name)
def _render_range_for_constraint_violation(self, source, c):
    """Render constraint ``c`` on ``source`` for a constraint-violation message.

    For a ``StrictMinMaxConstraint`` the text names the source and spells out
    whichever bounds are tighter than the default value range; a bound that is
    no tighter than the default is omitted. Any other constraint is rendered
    through its own ``render`` method.
    """
    if not isinstance(c, StrictMinMaxConstraint):
        return c.render(source)

    lower, upper = c.vr.lower, c.vr.upper
    default = self._default_value_range()
    # Bounds that do not narrow the default range carry no information.
    show_lower = not (lower <= default.lower)
    show_upper = not (upper >= default.upper)

    name = self._debug_name(source)
    c_render = f"{name} = {source.name()} in the specified range"
    if show_lower and show_upper:
        c_render += f" {lower} <= {name} <= {upper}"
    elif show_upper:
        c_render += f" {name} <= {upper}"
    elif show_lower:
        c_render += f" {lower} <= {name}"
    return c_render
def produce_guards(
    self,
    placeholders,
    sources,
    source_ref=lambda n: n.name(),
    *,
    guards: Optional[List[ShapeGuard]] = None,
    input_contexts: Optional[DimList[SymbolicContext]] = None,
    # Encodes user-specified input shape equations of the form s = s' and s = fn(s').
    # (See docs on EqualityConstraint for details of the encoding.)
    equalities_inputs: Optional[EqualityConstraint] = None,
    _simplified=False,
    _disable_forced_specializations=False,
    # Indicates if we should produce guards for known static values.
    ignore_static=True,
) -> List[str]:
    """
    Generates a list of guard strings which, when evaluated in a context that
    defines tensors for all the sources, return True or False depending
    on whether the guards in the list held. Primarily used by Dynamo,
    but this is also helpful for manual testing of guards (see
    evaluate_guards_for_args).

    For convenience in testing, a source is allowed to be a str,
    in which case we will assume it is a LocalSource.

    Args:
        placeholders: the inputs to guard on. Each entry is None (skipped),
            an int/SymInt, a float/SymFloat, or a tensor-like
            (torch.Tensor or FakeTensorMeta).
        sources: one source per placeholder; must have the same length as
            ``placeholders``.
        source_ref: renders a source into the text used inside the emitted
            guard strings. Defaults to ``source.name()``.
        guards: guards to issue instead of ``self.guards``. When given,
            translation validation is not run at the end.
        input_contexts: one symbolic context per placeholder, providing the
            per-dimension ``constraint_sizes``. Missing contexts for
            tensor-like placeholders are replaced with unconstrained ones.
        equalities_inputs: user-specified equalities between input
            dimensions; checked against the current values up front.
        _simplified: lets you omit duck sizing, equality and 0/1 guards.
            This is useful for testing when you don't care about the
            boilerplate guards, and it may be helpful for user output too
            (be careful though; some equality guards are nontrivial! It
            would be nice to get simplified output to print them too).
            It's private because it's not intended for normal use.
        _disable_forced_specializations: when True, a single-symbol guard
            does not record a violation against a StrictMinMaxConstraint.
        ignore_static: when True, ``source == constant`` guards are skipped
            for tensor property sources.

    Returns:
        The list of guard expressions, as strings.

    Raises:
        ConstraintViolationError: if an input equality does not hold for the
            current values, or if any non-warn-only constraint was violated
            while producing the guards.
    """
    self.log.info("produce_guards")

    # Check if we get to the same ShapeEnv state by replaying the recorded events.
    # This will create a new ShapeEnv instance, and call all recorded function
    # calls on this new instance. Finally, it will check whether this new instance
    # has equal state.
    #
    # It's important that we do it in the beginning of this function, since it modifies
    # self.dim_constraints through its execution. Changes that happen in this method
    # aren't interesting, since this is the function call we wish to reproduce at the
    # end. If we wish to simply reproduce ShapeEnv instances even after this call,
    # this method should also be recorded.
    if self.check_recorded_events:
        shape_env = replay_shape_env_events(self.events)
        self.check_equal(shape_env)

    assert len(placeholders) == len(sources), f"len({placeholders}) != len({sources})"
    Tensorlike = (torch.Tensor, FakeTensorMeta)

    def _create_no_constraints_context(t):
        return StatelessSymbolicContext(
            # Ignored; only the constraints part is relevant below.
            dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
            constraint_sizes=[None] * t.dim()
        )

    # Expand optional inputs, or verify invariants are upheld
    if input_contexts is None:
        input_contexts = [
            _create_no_constraints_context(t) if isinstance(t, Tensorlike)
            else None for t in placeholders
        ]
    else:
        assert len(input_contexts) == len(placeholders)
        for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
            if isinstance(t, Tensorlike):
                if context is None:
                    input_contexts[i] = _create_no_constraints_context(t)
            else:
                assert isinstance(t, (SymInt, int, SymFloat, float))
                assert not isinstance(context, list)

    # It took a lot of sweat to figure out the algorithm here.  Let's
    # explain how it works.
    #
    # The ShapeEnv lifecycle looks something like this:
    #
    # - For each input, you either generate a fresh Sympy symbol (s0) to
    #   represent its value (a binding site), or you reuse some
    #   preexisting symbol or expression, skipping the symbol allocation
    #   (e.g., duck sizing to a preexisting symbol, or expressing a
    #   stride as a multiplication of a separate stride and size.)
    #   Naively, you might expect to bind a fresh Sympy symbol for
    #   every input, but this is fairly wasteful as most of these
    #   symbols immediately simplify away, and if you don't eagerly
    #   specialize, e.g., 0/1 symbols, you end up with very complicated
    #   expressions that are not optimizable in practice.
    #
    # - You perform some compute on these symbols, occasionally
    #   introducing guards on boolean expressions on these symbols.
    #   In particular, whenever we guard on equality (_maybe_guard_rel),
    #   we can simplify shapes; e.g., when s0 == s1 * 2, we can now
    #   replace all occurrences of s0 with s1 * 2.  Sometimes, a
    #   boolean expression evaluation doesn't introduce a guard, as
    #   the guard is already entailed by the simplifications we have
    #   applied.
    #
    # - In the end, you have a bunch of replacements (saying how to
    #   simplify shapes) and a bunch of guards (all the equality guards
    #   are trivial, because they're covered by the replacements).
    #
    # From the ShapeEnv, we must generate a Python expression that, when
    # evaluated on a set of inputs, tells us whether or not these boolean
    # expressions would have evaluated in the same way.  However,
    # we cannot easily compute this, as we elide recording boolean
    # expressions when we think they are vacuously true.  Thus, we seek
    # an approximation: we must generate an expression that, if true, would
    # have produced an "equivalent" ShapeEnv, which would answer guard
    # expressions in the same way.
    #
    # Our notion of equivalence is a bit subtle.  For example, consider
    # the ShapeEnv created from an input of size (5, 4) versus (4, 4)
    # (no other guards.)  Duck sizing would generate (s0, s1) in the first
    # case but (s0, s0) in the second.  We do NOT assume that size
    # variables are disjoint; so in fact a graph that assumes the input
    # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not
    # vice versa.  However, consider an analogous case (1,) versus (2,).
    # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT
    # subsume the (1,) graph because we assume that any size variable
    # is NOT 0/1 (and make simplifications according to this; e.g., if
    # we queried s0 == 0, we would immediately return False without
    # returning a guard.)
    #
    # So, it is perhaps easier to flip things on their head: the guard
    # expressions we generate here say what simplifications are valid,
    # and what are not.  Below, we explain each of the guard expressions
    # we generate

    # TODO: Make this more efficient by binding all the size/stride/offsets
    # to locals before performing tests on them.

    from torch._dynamo.source import TensorPropertySource, TensorProperty

    # Actual codegen must be delayed as we don't necessarily know what
    # the symbol mapping is
    input_guards = []

    symbol_to_source = collections.defaultdict(list)
    symbol_to_constraints = collections.defaultdict(set)
    # Entries are (warn_only, debug_name, thunk producing the message).
    constraint_violations: List[Tuple[bool, str, Callable[[], str]]] = []

    def record_constraint_violation(warn_only, debug_name, msg, hint=None):
        # The message is rendered lazily: the hint may need the final
        # symbol_to_source mapping, which is still being built at this point.
        constraint_violations.append(
            (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg)
        )

    def is_dim(src):
        return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE

    if equalities_inputs:
        source_index = {}
        for i, src in enumerate(sources):
            source_index[src.name()] = i

        def get_expression(tensor_dim_src):
            fake = placeholders[source_index[tensor_dim_src.base.name()]]
            symint = fake.shape[tensor_dim_src.idx]
            if isinstance(symint, torch.SymInt):
                return symint.node.expr
            else:
                assert type(symint) is int, f"Expected int, got {type(symint)}"
                return symint

        def concrete_value(expr):
            # get_expression() hands back a plain int for static dimensions,
            # and ints have no xreplace(); render those as-is so that we raise
            # the intended ConstraintViolationError rather than AttributeError.
            return expr if isinstance(expr, int) else expr.xreplace(self.var_to_val)

        for src1, src2 in equalities_inputs.source_pairs:
            expr1, expr2 = get_expression(src1), get_expression(src2)
            # Check whether given input shape values satisfy a specified equation s = s'.
            # - Raise when the equation was violated by the given input shape values.
            # - Otherwise issue a guard to constrain them.
            concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
            if not concrete_val:
                raise ConstraintViolationError(
                    f"{src1.name()} = {concrete_value(expr1)}"
                    " is not equal to "
                    f"{src2.name()} = {concrete_value(expr2)}"
                )

        for src, root, fn in equalities_inputs.derived_equalities:
            expr1 = get_expression(src)
            # recall that root is either a phantom symbol or an input source
            expr2, debug_name = (
                (root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol)
                else (get_expression(root), self._debug_name(root))
            )
            expr2_ = fn(expr2)
            # Check whether given input shape values satisfy a specified equation s = fn(s').
            # - Raise when the equation was violated by the given input shape values.
            # - Otherwise issue a guard to constrain them.
            concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
            if not concrete_val:
                raise ConstraintViolationError(
                    f"Expected input {src.name()} to be equal to "
                    f"{fn(sympy.Symbol(debug_name))}, "
                    f"where {debug_name} = {concrete_value(expr2)}, "
                    f"but got {concrete_value(expr1)}"
                )

        for phantom_symbol in equalities_inputs.phantom_symbols:
            # we created additional phantom symbols that are not input shape dimensions
            symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol])

    # How do we know what the value of s0 is?  Fresh variables can only be
    # bound by inputs, so there MUST be some other input which binds the
    # variable.  If there is no such input, this is an error in our
    # system.  We record where all symbols come from, to help you diagnose
    # why those symbols didn't occur.
    #
    # In fact, generally speaking it is only possible for the "outermost"
    # user of a ShapeEnv to evaluate the guards, because some inputs may
    # not be available to inner levels.  For example, Dynamo can guard on
    # tensors that never actually become graph arguments (they are
    # pruned).  In this case, only Dynamo knows about these arguments.
    def track_symint(source, val, constraint=None):
        log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint)
        assert not isinstance(val, SymInt) or is_symbolic(val)

        if isinstance(val, SymInt) and val.node.maybe_as_int() is not None:
            val = val.node.maybe_as_int()

        if isinstance(val, SymInt):
            s = val.node.expr
            if isinstance(s, sympy.Symbol):
                symbol_to_source[s].append(source)
                if constraint is not None:
                    symbol_to_constraints[s].add(constraint)
            else:
                constraint_violated = False
                if isinstance(constraint, StrictMinMaxConstraint):
                    # try inferring the ranges of the expr s
                    sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols}
                    if any(vr is None for vr in sym_vrs.values()):
                        # some of the free symbols in s don't have ranges
                        constraint_violated = True
                elif isinstance(constraint, RelaxedUnspecConstraint):
                    if s.is_number:
                        i = int(s)
                        # Don't complain about 0/1 specialization, we
                        # expect to have to compile in this case anyway
                        if i not in (0, 1):
                            constraint_violated = True
                if constraint_violated:
                    def hint(s):
                        sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s)
                        return f"{sexpr}."

                    var_with_range = self._render_range_for_constraint_violation(source, constraint)
                    msg = (
                        f"Not all values of {var_with_range} are valid because "
                        f"{self._debug_name(source)} was inferred to be equal to "
                    )
                    record_constraint_violation(
                        constraint.warn_only,
                        self._debug_name(source),
                        msg,
                        hint=functools.partial(hint, s),
                    )

            input_guards.append((source, s))
        else:
            s = sympy.Integer(val)
            input_guards.append((source, s))
            constraint_violated = False
            if isinstance(constraint, StrictMinMaxConstraint):
                if not (s == constraint.vr.lower == constraint.vr.upper):  # allow static constraints
                    constraint_violated = True
            elif isinstance(constraint, RelaxedUnspecConstraint):
                # Don't complain about 0/1 specialization, we
                # expect to have to compile in this case anyway
                if val not in (0, 1):
                    constraint_violated = True
            if constraint_violated:
                var_with_range = self._render_range_for_constraint_violation(source, constraint)
                msg = (
                    f"Not all values of {var_with_range} are valid because "
                    f"{self._debug_name(source)} was inferred to be a constant ({val})."
                )
                record_constraint_violation(constraint.warn_only, self._debug_name(source), msg)

    def track_symfloat(source, val):
        log.debug("track_symfloat %s %s", LazyString(source.name), val)
        assert not isinstance(val, SymFloat) or is_symbolic(val)

        if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None:
            val = val.node.maybe_as_float()

        if isinstance(val, SymFloat):
            s = val.node.expr
            if isinstance(s, sympy.Symbol):
                symbol_to_source[s].append(source)
            input_guards.append((source, s))
        else:
            s = sympy.Float(val)
            input_guards.append((source, s))

    for t, source, context in zip(placeholders, sources, input_contexts):
        if isinstance(source, str):
            from torch._dynamo.source import LocalSource
            source = LocalSource(source)
        assert isinstance(source, Source)
        if t is None:
            continue
        if isinstance(t, (SymInt, int)):
            track_symint(source, t)
            continue
        elif isinstance(t, (SymFloat, float)):
            track_symfloat(source, t)
            continue
        assert isinstance(t, Tensorlike)
        if is_traceable_wrapper_subclass(t):
            from torch._dynamo.source import AttrSource

            assert isinstance(context, SubclassSymbolicContext)

            # For subclasses, we need to track symints on BOTH the outer
            # and inner tensors.
            sources_tensors_constraints = [
                (source, t, context.constraint_sizes)
            ]
            attrs, _ = t.__tensor_flatten__()
            for attr in attrs:
                inner_t = getattr(t, attr)
                inner_context = context.inner_contexts[attr]
                sources_tensors_constraints.append((
                    AttrSource(source, attr),
                    inner_t,
                    inner_context.constraint_sizes
                ))
        else:
            sources_tensors_constraints = [(source, t, context.constraint_sizes)]

        for src, curr_t, constraint in sources_tensors_constraints:
            # Sizes are tracked for every tensor; strides and the storage
            # offset are only tracked for non-sparse tensors.
            for i, ss in enumerate(curr_t.size()):
                property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
                track_symint(property_source, ss, constraint[i])
            if not is_sparse_any(curr_t):
                for i, ss in enumerate(curr_t.stride()):
                    track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss)
                track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())

    # 1. Every input must equal the final simplified symbolic expression
    #    stored on the placeholder.  Given a placeholder (s0*2, s1),
    #    if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
    #    This does a lot of work: it covers duck sizing and equality guards.
    exprs = []
    self.dim_constraints = DimConstraints(
        symbol_to_source,
        self.var_to_val,
        set(symbol_to_constraints.keys()),
        self.source_name_to_debug_name,
        self._allow_complex_guards_as_runtime_asserts,
    )

    if not _simplified:
        for source, expr in input_guards:
            if self._translation_validation_enabled:
                # Ignore sources that were not turned into SymInts.
                srcname = source.name()
                if srcname in self.source_to_symbol:
                    self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr))

            # Small optimization
            if (
                isinstance(expr, sympy.Symbol) and
                symbol_to_source.get(expr) and
                source == symbol_to_source[expr][0]
            ):
                continue

            # This logic excludes static values found on tensors from guarding, because
            # dynamo's check_tensor_fn does that (see guards.cpp).
            # However, for non tensor sources, we still need to guard here.
            if ignore_static and isinstance(source, TensorPropertySource):
                if expr.is_number:
                    self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}")
                    continue

            if is_dim(source):
                self.dim_constraints.add_equality(source, expr)

            sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
            exprs.append(f"{source_ref(source)} == {sexpr}")
            if (
                is_dim(source)
                and equalities_inputs
                and len(expr.free_symbols) == 1
            ):
                symbol = next(iter(expr.free_symbols))
                if (
                    isinstance(expr, sympy.Symbol) and
                    expr in symbol_to_constraints and
                    not equalities_inputs.is_equal(source, symbol_to_source[expr][0])
                ):
                    msg = (
                        f"The values of {self._debug_name(source)} = {source.name()} and "
                        f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} "
                        "must always be equal."
                    )
                    record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)

                if (
                    not isinstance(expr, sympy.Symbol) and
                    symbol in symbol_to_constraints and
                    not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.xreplace({symbol: x}))
                ):
                    src = symbol_to_source[symbol][0]
                    msg = (
                        f"The values of {self._debug_name(source)} = {source.name()} must always be related to "
                        f"the values of {self._debug_name(src)} = {src.name()} by "
                        f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}."
                    )
                    record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)

            # NB: Not necessary to report constraint violations here:
            # constraints are guaranteed to be on symbols (we've already
            # caught constants and non-atomic expressions), so we only
            # have relational constraints, but we don't support those
            # at the moment

    # 2. Every guard must evaluate to True (but remember many guards
    #    like s0 == s1*2 become trivial due to simplification)
    issued = set()

    def issue_guard(guard: ShapeGuard) -> None:
        expr = self.simplify(guard.expr)

        # Avoid re-issuing the same guard.
        if expr in issued:
            return

        issued.add(expr)

        try:
            is_trivial = False
            if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]):
                is_trivial = self.dim_constraints.add(expr)
            guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
            exprs.append(guard_expr)
            self._add_target_expr(expr)
            # A non-relational constraint on a single sizevar can violate
            # a constraint
            if not is_trivial and len(expr.free_symbols) == 1:
                symbol = next(iter(expr.free_symbols))
                source = symbol_to_source[symbol][0]
                constraints = symbol_to_constraints[symbol]
                for c in constraints:
                    if isinstance(c, StrictMinMaxConstraint):
                        if not _disable_forced_specializations:
                            var_with_range = self._render_range_for_constraint_violation(source, c)
                            msg = (
                                f"Not all values of {var_with_range} "
                                f"satisfy the generated guard {guard_expr}."
                            )
                            record_constraint_violation(c.warn_only, self._debug_name(source), msg)
                    elif isinstance(c, RelaxedUnspecConstraint):
                        # This is fine, we allow guards here as long as it
                        # didn't constrain it to one value  (we don't
                        # actually know this; this depends on our
                        # ValueRanges reasoning capability)
                        pass
                    else:
                        raise AssertionError(f"unrecognized constraint {c}")
        except Exception:
            self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
            raise

    # First, issue all guards.
    # This removes all the checks that follow from bounds
    # We could simply emit those and also the bounds 2 <= size when necessary
    for guard in (guards if guards is not None else self.guards):
        if self._maybe_evaluate_static(guard.expr, axioms=()) is not None:
            continue
        issue_guard(guard)

    # 3. Every symbol must be within its value range (this handles 0/1
    #    specialization too).
    for symbol, symbol_sources in symbol_to_source.items():
        r = self.var_to_range.get(symbol)
        if r is None:
            # No recorded range for this symbol: nothing to emit.
            continue

        assert symbol_sources

        bounds = []
        if r.lower != -sympy.oo:
            if any(is_dim(source) for source in symbol_sources):
                self.dim_constraints.add(sympy.Ge(symbol, r.lower))
            # Only print lower bound in simplified mode if it is not the
            # default
            if not _simplified or r.lower != self._default_value_range().lower:
                bounds.append(str(r.lower))
        bounds.append(source_ref(symbol_sources[0]))
        # NB: This looks like an off-by-one error but it's not: the
        # upper bound may be sys.maxsize - 1 because we intentionally
        # exclude sys.maxsize from our bounds to deal with direct
        # == INT_MAX guards, but it's still dumb to actually test it.
        # Note that you can be off by a pretty large constant and it
        # won't matter because sizes in practice will be nowhere near
        # the 64-bit limit.
        if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
            if any(is_dim(source) for source in symbol_sources):
                self.dim_constraints.add(sympy.Le(symbol, r.upper))
            # nontrivial upper bound is always interesting
            bounds.append(str(r.upper))
        if len(bounds) > 1:
            exprs.append(" <= ".join(bounds))

            # Check constraints
            constraints = symbol_to_constraints[symbol]
            for c in constraints:
                if isinstance(c, StrictMinMaxConstraint):
                    # NB: By default, we have a restrictive range
                    # 2 <= s0 <= sys.maxsize - 1.  But export users generally
                    # expect to be able to specify nice ranges like [0, oo]
                    if not (c.vr & self._default_value_range()).issubset(r):
                        source = symbol_sources[0]

                        expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper))
                        guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
                        var_with_range = self._render_range_for_constraint_violation(source, c)
                        msg = (
                            f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}"
                        )
                        record_constraint_violation(
                            c.warn_only,
                            self._debug_name(source),
                            msg,
                        )
        # We NaN specialize, which means similar to 0/1 specialization we
        # should assume that the float is NOT nan.  This is load bearing
        # if you have something like an equality guard, nan will play
        # merry hell with the reasoning.
        if symbol_is_type(symbol, SymT.FLOAT):
            exprs.append(f"not __math_isnan({source_ref(symbol_sources[0])})")

    if constraint_violations:
        warn_msgs = []
        error_msgs = []
        debug_names = set()
        for warn_only, debug_name, msg in constraint_violations:
            if warn_only:
                msg = f" {len(warn_msgs) + 1}. {msg()}"
                warn_msgs.append(msg)
            else:
                msg = f" - {msg()}"
                error_msgs.append(msg)
                debug_names.add(debug_name)
        if len(error_msgs) > 0:
            debug_names = ', '.join(sorted(debug_names))
            err = '\n'.join(error_msgs)
            raise ConstraintViolationError(
                f"Constraints violated ({debug_names})! "
                'For more information, run with TORCH_LOGS="+dynamic".\n'
                f"{err}"
            )
        elif len(warn_msgs) > 0:
            log.debug("%s Warning only constraints violated", len(warn_msgs))

    signpost_event(
        "dynamic",
        "produce_guards",
        {
            **self.co_fields,
            **self.counter,
            "num_guards": len(exprs),
            "free_symbols": sum(1 for v in symbol_to_source.values() if v),
            # The keys are meaningless from an aggregate perspective, so
            # don't include them.  Biggest first.
            "symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True),
        },
    )

    if self._translation_validation_enabled:
        from torch.fx.experimental.validator import PopulateValidator

        # Add all deferred runtime assertions; these are not technically
        # handled by produce_guards but we need to put them in the target
        # set
        for ras in self.deferred_runtime_asserts.values():
            for ra in ras:
                self._add_target_expr(ra.expr)

        # Add value range bound guards for all symbols with no trivial bounds.
        # Reason: '_maybe_evaluate_static' may eliminate guards based on the
        # refined value ranges.
        for sym, vr in self.var_to_range.items():
            if vr.lower != -sympy.oo:
                self._add_target_expr(sympy.Le(vr.lower, sym))
            if vr.upper != sympy.oo:
                self._add_target_expr(sympy.Le(sym, vr.upper))

        # Before validating, populate the input of the validator with the
        # built FX graph.
        with fx_traceback.preserve_node_meta():
            PopulateValidator(self.graph, self.validator).run()

    # Only run translation validation when we are not passing custom guards
    if guards is None:
        self._check_translation_validate()
    return exprs
def produce_guards_expression(
    self,
    placeholders,
    *,
    guards: Optional[List[ShapeGuard]] = None,
    ignore_static=True
):
    """
    Counterpart of evaluate_guards_expression(). Produces the guards for the
    given placeholders, naming the i-th placeholder ``t{i}``, and joins them
    with ``and`` into a single expression string.

    Returns None when no guards were produced.
    """
    from torch._dynamo.source import LocalSource

    local_sources = [LocalSource(f"t{i}") for i in range(len(placeholders))]
    clauses = self.produce_guards(
        placeholders,
        local_sources,
        guards=guards,
        ignore_static=ignore_static,
    )
    if not clauses:
        return None
    return " and ".join(clauses)
def evaluate_guards_expression(self, code, args):
    """
    Counterpart of produce_guards_expression(). Evaluates ``code`` with the
    i-th entry of ``args`` available as ``L["t{i}"]``.

    NOTE: this uses eval(); ``code`` is expected to be an expression generated
    by produce_guards_expression, not arbitrary external input.
    """
    names = (f"t{i}" for i in range(len(args)))
    bound_locals = dict(zip(names, args))
    return eval(code, SYMPY_INTERP, {"L": bound_locals})
def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
    """Generate the guard expression for ``placeholders`` and evaluate it on ``args``.

    Returns True when no guard expression was produced.
    """
    code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
    if not code:
        return True
    return self.evaluate_guards_expression(code, args)
def get_pruned_guards(self, symints):
    """
    Return the guards from ``self.guards`` whose free symbols all come from
    the given symints.

    Only symints whose expression is a bare ``sympy.Symbol`` contribute a symbol.
    """
    allowed = {s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)}
    return [
        guard
        for guard in self.guards
        if all(sym in allowed for sym in guard.expr.free_symbols)
    ]
def bind_symbols(self, placeholders, args):
    """
    Given a paired list of placeholders (fake tensors with
    symbolic sizes) and concrete arguments (regular tensors
    with real sizes), returns a dictionary mapping each
    symbol to its real value.  So for example, if you
    have a placeholder with size (s0, s1), binding
    (2, 4) to it will give you {s0: 2, s1: 4}.  This is
    not guaranteed to bind ALL symbols in the ShapeEnv;
    we can't bind a symbol if it doesn't occur in any placeholder,
    and symbols that already have replacements won't get bindings.

    Only expressions that are a bare symbol, or the negation of one, are
    bound; a repeated symbol is asserted to agree with its earlier binding.

    This is a little duplicative with evaluate_guards but
    it's different enough that it seemed cleanest to make
    another copy.  This assumes the guards are already checked,
    though if it's cheap we'll check for shenanigans
    """
    bindings: Dict[sympy.Symbol, int] = {}

    def bind_symint(arg, val):
        if not isinstance(val, SymInt):
            return
        s = val.node.expr
        if isinstance(s, sympy.Symbol):
            sym, concrete = s, arg
        elif isinstance(-s, sympy.Symbol):
            # The placeholder holds -sym, so sym itself takes the negated value.
            sym, concrete = -s, -arg
        else:
            return
        if sym in bindings:
            assert bindings[sym] == concrete, f"{bindings[sym]} != {concrete}"
        else:
            bindings[sym] = concrete

    for placeholder, arg in zip(placeholders, args):
        if placeholder is None:
            continue
        if isinstance(placeholder, SymInt):
            bind_symint(arg, placeholder)
            continue
        assert isinstance(placeholder, torch.Tensor)
        for dim, size in enumerate(placeholder.size()):
            bind_symint(arg.size(dim), size)
        for dim, stride in enumerate(placeholder.stride()):
            bind_symint(arg.stride(dim), stride)
        bind_symint(arg.storage_offset(), placeholder.storage_offset())

    return bindings
def get_nontrivial_guards(self):
    """Returns the simplified guard expressions that aren't statically known (i.e. not trivial)"""
    nontrivial = []
    for guard in self.guards:
        if self._maybe_evaluate_static(guard.expr, axioms=()) is not None:
            continue
        nontrivial.append(self.simplify(guard.expr))
    return nontrivial
def format_guards(self, verbose=False):
    """Format this shape env's guard expressions, one per line, appending each guard's traceback when verbose"""
    rendered = []
    for guard in self.guards:
        tb = guard.stack
        trace = ""
        if verbose:
            frames = ''.join(' ' + l for l in tb.format())
            trace = f"\n Guarded at:\n{frames}"
        rendered.append(f" - {guard.expr}{trace}")
    return '\n'.join(rendered)
def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges:
    """Given a sympy expression, computes a ValueRanges bound for what values it can be

    With ``size_oblivious`` set, every size-like free symbol that has a
    recorded range is treated as lying in ``[2, sys.maxsize - 1]``.
    """
    ranges = {}
    for sym in expr.free_symbols:
        ranges[sym] = self.var_to_range.get(sym, None)
    if size_oblivious:
        # Clamp values of size-like variables
        for sym in self.size_like & ranges.keys():
            if ranges[sym] is None:
                continue
            ranges[sym] = ValueRanges(2, sys.maxsize - 1)
            assert ranges[sym].is_int
    # Inside the method body this name resolves to the module-level bound_sympy.
    return bound_sympy(expr, ranges)
@_lru_cache
def get_axioms(self, symbols: Optional[Tuple["sympy.Symbol"]] = None) -> Tuple["sympy.Expr"]:
    """
    Return the expressions of all guards followed by the expressions of the
    relevant deferred runtime asserts.

    If symbols is None, every deferred runtime assert is included. Otherwise
    only the asserts registered under the given symbols are included, and
    symbols present in ``self.var_to_val`` are skipped.
    """
    deferred = self.deferred_runtime_asserts
    if symbols is None:
        assert_groups = deferred.values()
    else:
        assert_groups = [deferred.get(s, ()) for s in symbols if s not in self.var_to_val]

    guard_exprs = [g.expr for g in self.guards]
    assert_exprs = [ra.expr for group in assert_groups for ra in group]
    return tuple(guard_exprs + assert_exprs)
@_lru_cache
def get_implications(self,
                     e: "sympy.Expr",
                     compute_hint: bool) -> Tuple[Tuple["sympy.Expr", 'sympy.logic.boolalg.BooleanAtom']]:
    """ Given an expression assumed to hold, returns (predicate, truth value) pairs that follow from it """
    implied = {}

    def mark_true(pred):
        # The predicate holds, hence its negation does not.
        implied[canonicalize_bool_expr(pred)] = sympy.true
        implied[canonicalize_bool_expr(sympy.Not(pred))] = sympy.false

    def record(pred):
        mark_true(pred)
        if not isinstance(pred, sympy.Rel):
            return
        rel = type(pred)
        if isinstance(pred, (sympy.Eq, sympy.Ne)):
            # multiplying by -1 ensures that equality is commutative
            mirrored = rel(-pred.lhs, -pred.rhs)
        else:
            # multiplying by -1 changes the direction of the inequality
            mirrored = rel(-pred.rhs, -pred.lhs)
        mark_true(mirrored)

    if compute_hint:
        e = canonicalize_bool_expr(e.xreplace(self.var_to_val))
    record(e)

    # Other relational expressions this expression implies
    consequences = []
    if isinstance(e, sympy.Eq):
        consequences = [sympy.Le(e.lhs, e.rhs), sympy.Ge(e.lhs, e.rhs)]
    elif isinstance(e, sympy.Lt):
        consequences = [sympy.Le(e.lhs, e.rhs), sympy.Ne(e.lhs, e.rhs)]
        if e.lhs.is_integer and e.rhs.is_integer:
            consequences.append(sympy.Le(e.lhs, e.rhs - 1))
    elif isinstance(e, sympy.Le):
        consequences = [sympy.Lt(e.lhs, e.rhs + 1)]
    for consequence in consequences:
        record(consequence)

    return tuple(implied.items())
@_lru_cache
def _maybe_evaluate_static(
    self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False,
    expect_rational=True, size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None,
    var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None
) -> "Optional[sympy.Expr]":
    """
    Tries to evaluate expr without introducing guards

    If unbacked_only == True, then we only do substitutions on
    unbacked SymInts (leaving regular hinted integers alone).  This could
    result in an expression that still contains backed SymInts, which you
    could then potentially guard on.

    Use compute_hint == True if you are trying to compute a non-binding
    hint for the particular hint values of backed SymInts, e.g., if
    s0 happens to be 3 this run, compute_hint will substitute s0 with 3.

    If size_oblivious == True, size-like symbols are treated as being at
    least 2 whenever that is consistent with their upper bound.

    axioms, when given, replaces the axioms that would otherwise be fetched
    with get_axioms() for the free symbols of expr; var_to_range, when given,
    replaces self.var_to_range as the source of value ranges.  expect_rational
    is accepted but not read by this function.

    Returns the statically determined value if there is one.  Otherwise
    returns the partially simplified expression when unbacked_only is set,
    and None in all remaining cases (including a RecursionError inside
    sympy's xreplace).
    """
    # axioms with compute hint NYE
    assert not compute_hint or not axioms

    if var_to_range is None:
        var_ranges = self.var_to_range
    else:
        var_ranges = dict(var_to_range)

    expr = self.simplify(expr)

    if compute_hint:
        expr = expr.xreplace(self.var_to_val)

    expr = canonicalize_bool_expr(expr)

    # Pattern matching
    symbols = tuple(expr.free_symbols)
    if axioms is None:
        axioms = self.get_axioms(symbols)
    subst = {}
    for e in axioms:
        subst.update(dict(self.get_implications(e, compute_hint=compute_hint)))

    expr = expr.xreplace(subst)

    # Simplify making use of value range lower bound
    new_shape_env = {}
    new_range_env = {}
    for idx, k in enumerate(symbols):
        if isinstance(self.var_to_val.get(k, None), SingletonInt):
            # Skip var_ranges logic for SingletonInt which is only used
            # for jagged layout NestedTensors today
            continue
        try:
            vr = var_ranges[k]
        except KeyError:
            log.warning("%s is not in var_ranges, defaulting to unknown range.", k)
            vr = self._default_unspecified_value_range()
        if size_oblivious and k in self.size_like:
            lower = max(2, vr.lower)
            # This is a bit dodgy: what this means is that there was a
            # size-like unbacked symbol whose upper bound < 2.  This
            # causes... problems.
            if lower <= vr.upper:
                vr = ValueRanges(lower, vr.upper)
        else:
            lower = vr.lower
        # Don't do anything if we don't have a nontrivial lower bound
        # Also don't do anything if we asked only to simplify unbacked
        # SymInt
        if (
            lower < (-sys.maxsize - 1) // 2 or
            (unbacked_only and k in self.var_to_val) or
            not vr.is_int
        ):
            new_range_env[k] = vr
            continue
        # The goal is to take our symbols which have various lower bounds
        # and reallocate them into new symbols which are exactly positive;
        # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in
        # [1, inf], where s0 = ess0 + 1.  This gives the most information
        # to sympy for subsequent simplifications.
        #
        # Positive means >= 1
        # Positive - 1 means >= 0
        # Positive + lower - 1 means >= lower
        # The new symbol 's' is "too low", so when we substitute it in
        # we have to increase it by offset (and conversely, the new
        # variables have to have their value range bounds adjusted as
        # well)
        s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True)

        # Note:
        #   Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
        #   Sympy might give unexpected results when comparing an integer with a non-integer
        #   Therefore, we cast offset to int here.
        #   For example:
        #       shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
        #       expr = sympy.Eq(shape_0 - 1/3, 4)
        #       expr.xreplace({}) # False
        offset = int(lower - 1)
        new_shape_env[k] = s + offset
        new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)

    try:
        new_expr = expr.xreplace(new_shape_env)
    except RecursionError:
        log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
        self.counter["sympy_recursion_error"] += 1
        return None

    new_expr = safe_expand(new_expr)
    if new_expr.is_number:
        return new_expr

    # NB: We deliberately do NOT rewrite FloorDiv(a, b) atoms into
    # sympy.floor(a / b) here.  The replacement with division leaves us with
    # rationals when the numerator is an addition, e.g., sympy will happily
    # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2.  Needless complication!
    # TODO: when unbacked_only, can sometimes early return even when there
    # are still free symbols

    # Check if the range can solve it statically
    out = bound_sympy(new_expr, new_range_env)
    if out.is_singleton():
        return out.lower

    return new_expr if unbacked_only else None
@_lru_cache
def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
    """
    Apply the recorded symbol replacements to ``expr``.

    Every free symbol of ``expr`` is mapped to whatever ``self._find``
    returns for it, the mapping is applied with ``xreplace`` and the result
    is passed through ``safe_expand``.
    """
    substitutions = {}
    for sym in expr.free_symbols:
        substitutions[sym] = self._find(cast(sympy.Symbol, sym))
    substituted = expr.xreplace(substitutions)
    return safe_expand(substituted)
@_lru_cache
def _update_divisible(self):
    """
    Rebuild ``self.divisible`` without the entries that became constant.

    Each tracked expression is run through ``self.replace``; an entry is
    kept (unchanged) only when its replaced form is not a number.  The
    version counter is bumped afterwards.
    """
    still_symbolic = {
        tracked
        for tracked in self.divisible
        if not self.replace(tracked).is_number
    }
    self.divisible = still_symbolic
    self._update_version_counter()
@_lru_cache
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
    """
    Simplify ``expr`` using the known replacements and divisibility facts.

    After applying ``self.replace``, two FloorDiv passes run:

    1. ``FloorDiv(a, FloorDiv(a, b))`` is rewritten to ``b`` when both
       ``Mod(a, FloorDiv(a, b))`` and ``Mod(a, b)`` are recorded in
       ``self.divisible``.
    2. Every ``FloorDiv(a, b)`` whose ``Mod(a, b)`` is recorded in
       ``self.divisible`` is rewritten to ``CleanDiv(a, b)``; the rewrite is
       kept only if it introduces no new ``Pow`` atoms and no new
       non-integer rationals.
    """
    def known_divisible(numerator, denominator):
        # The division is exact when its (replaced) Mod is a recorded fact.
        return self.replace(Mod(numerator, denominator)) in self.divisible

    def non_integer_rationals(e):
        return e.atoms(sympy.Rational).difference(e.atoms(sympy.Integer))

    expr = self.replace(expr)

    # TODO it would seem that this pass is not necessary given the
    # below replacement of // with /, but for nested FloorDivs
    # the non-recursive replacement doesn't work, and
    # recursive makes it hard to look up divisibility,
    # because existing divisibility info has FloorDiv in it, not /
    # for now just do a separate pass to catch common nested case
    if expr.has(FloorDiv):
        self._update_divisible()
        nested_rewrites = {}
        for outer in expr.atoms(FloorDiv):
            numer, denom = outer.args
            if not isinstance(denom, FloorDiv):
                continue
            inner_numer, inner_denom = denom.args
            if (
                known_divisible(numer, denom)
                and numer == inner_numer
                and known_divisible(inner_numer, inner_denom)
            ):
                nested_rewrites[outer] = inner_denom
        expr = safe_expand(expr.xreplace(nested_rewrites))

    if not expr.has(FloorDiv):
        return expr

    pows_before = expr.atoms(sympy.Pow)
    rationals_before = non_integer_rationals(expr)
    exact_rewrites = {}
    for floor_div in expr.atoms(FloorDiv):
        numer, denom = floor_div.args
        if known_divisible(numer, denom):
            exact_rewrites[floor_div] = CleanDiv(numer, denom)
    candidate = safe_expand(expr.xreplace(exact_rewrites))

    # divisions simplified away
    if (
        candidate.atoms(sympy.Pow).issubset(pows_before)
        and non_integer_rationals(candidate).issubset(rationals_before)
    ):
        return candidate
    return expr
@lru_cache(256)
def size_hint(self, expr: "sympy.Expr", *, allow_none=False):
    """
    Return a size hint for ``expr`` based on the underlying shapes we had.

    No guard is introduced, so only use this where the calling code stays
    valid for arbitrary shapes (e.g. optimization decisions).

    Args:
        expr: expression to compute a hint for.
        allow_none: when true, return ``None`` instead of raising if no
            hint can be determined.

    Returns:
        The hinted value, or ``None`` (for a ``SingletonInt`` result, or
        when ``allow_none`` is set and no hint is available).

    Raises:
        The error built by ``self._make_data_dependent_error`` when no hint
        can be produced and ``allow_none`` is false.
    """
    hinted = safe_expand(expr).xreplace(self.var_to_val)
    if hinted.is_number:
        return hinted

    from torch.utils._sympy.singleton_int import SingletonInt

    if isinstance(hinted, SingletonInt):
        return None

    static_value = self._maybe_evaluate_static(hinted, compute_hint=True)
    if static_value is not None:
        return static_value
    if allow_none:
        return None

    if self.unbacked_var_to_val:
        unsound_expr = hinted.xreplace(self.unbacked_var_to_val)
        if not unsound_expr.free_symbols:
            log.warning("propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr)
            trace_structured(
                "propagate_real_tensors",
                metadata_fn=lambda: {
                    "expr": repr(expr),
                    "result": repr(unsound_expr),
                    "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
                },
            )
            # The substituted value is only recorded as a runtime assert,
            # not proven.
            self.defer_runtime_assert(
                sympy.Eq(hinted, unsound_expr),
                f"propagate_real_tensors: {hinted} == {unsound_expr}"
            )
            return unsound_expr

    raise self._make_data_dependent_error(hinted, expr)
# NB: keep in sync with size_hint
@lru_cache(256)
def has_hint(self, expr: "sympy.Expr"):
    """
    Report whether a hint is available for ``expr``.

    True when ``expr`` becomes a number after substituting
    ``self.var_to_val``, or when ``self._maybe_evaluate_static`` can
    evaluate the substituted expression.
    """
    hinted = safe_expand(expr).xreplace(self.var_to_val)
    is_concrete = hinted.is_number
    if is_concrete:
        return is_concrete
    return self._maybe_evaluate_static(hinted) is not None
def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None):
    """
    Build (but do not raise) a ``GuardOnDataDependentSymNode`` for ``expr``.

    Args:
        expr: the expression shown first in the message; its free symbols
            are looked up in ``self.var_to_stack`` and logged at debug
            level.
        unhinted_expr: shown in the message as the "unhinted" form.
        size_oblivious_result: if not ``None``, an extra paragraph is added
            saying that ``guard_size_oblivious`` would have evaluated the
            expression to this value.

    Returns:
        The exception instance; the caller is expected to raise it.
    """
    # TODO: in a Dynamo context, having user code, and having the
    # name of the local, will be much better
    size_like_symbols = []
    for sym in expr.free_symbols:
        allocation_trace = ''.join(self.var_to_stack[sym].format())
        self.log.debug("Data dependent variable '%s' allocated at:\n%s", sym, allocation_trace)
        if sym in self.size_like:
            size_like_symbols.append(sym)

    if size_oblivious_result is None:
        size_oblivious_result_msg = ""
    else:
        size_oblivious_result_msg = (
            f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n"
            "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n"
        )

    # The user-location element of the summary is not used in this message.
    fsummary, _maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True)

    msg = (
        "Could not extract specialized integer from data-dependent expression"
        if expr.is_integer
        else "Could not guard on data-dependent expression"
    )
    size_like_listing = ', '.join(map(str, size_like_symbols)) or 'none'
    culprit = ''.join(traceback.StackSummary.from_list([fsummary]).format())
    symbol_names = ','.join(map(str, expr.free_symbols))

    # TODO: Help text about how to use our runtime tests to fix this
    # problem
    return GuardOnDataDependentSymNode(
        f"{msg} {expr} (unhinted: {unhinted_expr}). "
        f"(Size-like symbols: {size_like_listing})\n\n"
        f"{size_oblivious_result_msg}"
        "Potential framework code culprit (scroll up for full backtrace):\n"
        f"{culprit}\n"
        'For more information, run with TORCH_LOGS="dynamic"\n'
        "For extended logs when we create symbols, also add "
        f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{symbol_names}\"\n"
        "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
        "For more debugging help, see "
        "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" +
        maybe_extra_debug
    )
def _update_var_to_range(self, symbol, vr):
lower, upper = vr.lower, vr.upper
# If we have a size-like unbacked SymInt, refuse to refine the range to be
# less than two. This is because when we intersect this range
# with [2, inf] for size oblivious tests, the range would be
# unsatisfiable. In other words, once you have a size-like
# unbacked SymInt, we can never learn that it is exactly zero or one,
# because we would now give inconsistent results for all size
# oblivous tests!
if upper < 2 and symbol in self.size_like:
upper = 2
# Updates the range and the guards corresponding to each bound of the symbol.
if symbol not in self.var_to_range:
r = ValueRanges(lower, upper)
self.log.debug("_update_var_to_range %s = %s (new)", symbol, r)
self.var_to_range[symbol] = r
else:
old = self.var_to_range[symbol]
new = old & ValueRanges(lower, upper)
if new != old:
self.var_to_range[symbol] = new
self.log.debug("_update_var_to_range %s = %s (update)", symbol, new)
def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
    """
    Adds or updates a replacement for a symbol.
    Use this instead of `self.replacements[a] = tgt`.

    Args:
        a: the symbol being replaced (asserted to be a ``sympy.Symbol``).
        tgt: the expression ``a`` is known to be equal to.
        msg: short reason string, only used in log output and in the
            structured specialization trace.

    The replacement is silently skipped (after possibly having refined
    value ranges) when it is already recorded, when complex guards are
    deferred to runtime asserts and ``tgt`` is not a supported equivalence,
    or when the range of ``tgt`` is not contained in the range of ``a``.
    """
    # Nothing to do if this exact replacement is already recorded.
    if tgt == self.replacements.get(a, None):
        return
    # Precondition: a == tgt
    assert isinstance(a, sympy.Symbol)
    if self._allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt):
        return # continuing leads to placeholder shapes having complex expressions that we can't resolve
    # Handles nested tensor symbolic variables which don't have
    # var_to_range bounds
    tgt_bound = None
    if a in self.var_to_range:
        src_bound = self.var_to_range[a]
        # If you have x in [2, maxint], then 2*x in [4, 2*maxint].
        # But we don't really care that the max bound says we can
        # go beyond the maximum integer size, because we aren't
        # using bigints anyway. Arguably, ValueRanges should know
        # to do this truncation automatically (to avoid doing
        # bigint compute in range analysis), but right now it doesn't
        # so we need to get rid of some unnecessary precision.
        int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
        def issubset(x, y):
            # Integer ranges are clamped to int_range before comparing;
            # anything else is compared as-is.
            if x.is_int and y.is_int:
                return (x & int_range).issubset(y & int_range)
            else:
                return x.issubset(y)
        # First, refine the value range of a based on the computed value range
        # of tgt. This is always OK to do, even if we decide not to do the
        # substitution in the end. This might be a no-op, if a already has
        # a tighter bound
        tgt_bound = self.bound_sympy(tgt)
        self._update_var_to_range(a, tgt_bound)
        # Next, check if we can update the range of free symbols in tgt
        # based on the range in a. But only do it if:
        # - the source bound non-trivially improves over what we get out of
        # the existing bounds.
        # - the replacement is univariate and we can invert the tgt expression
        if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1:
            b = next(iter(tgt.free_symbols))
            # Try to invert the equality
            r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
            if r is not None:
                self.log.debug("set_replacement: solve for %s in %s == %s gives %s", b, a, tgt, r)
                # The solution here can be non-integral, for example, if
                # we have s0 = 2*s1, then s1 = s0/2. What we would like
                # to do is calculate the bounds in arbitrary precision,
                # and then requantize the bound to integers when we are
                # done.
                rat_b_bound = self.bound_sympy(r[1])
                b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper))
                self._update_var_to_range(b, b_bound)
                # Recompute the bound of tgt now that b has been narrowed.
                tgt_bound = self.bound_sympy(tgt)
                assert issubset(tgt_bound, src_bound)
        # TODO: Should we propagate size-like-ness?
        #
        # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
        # to become size-like.
        #
        # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
        # propagate in this case, because what if u0 == 0, then u1 is negative
        # and clearly isn't a size. So, at minimum, any f(x) whose value
        # range isn't [0, inf] given x in [0, inf] cannot propagate
        # size-like-ness. But there are many situations where you could
        # imagine u1 is going to be size-like and actually you just didn't
        # have a refined enough value range on u0. Since even innocuous
        # looking arithmetic operations can destroy size-like-ness, it's
        # best to not propagate it at all and force the user to annotate it
        # as necessary.
        #
        # Compromise: we preserve size-like-ness only for exact equality
        # and nothing else.
        if a in self.size_like and isinstance(tgt, sympy.Symbol):
            self.size_like.add(tgt)
        elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
            self.size_like.add(a)
        # Now, decide if we will do the substitution.
        #
        # - If the source has a non-trivial range, only substitute if
        # we preserve this range. Note that we may have propagated
        # the src_range to free variables in tgt when tgt is univariate
        # and we could find an inverse, which helps us achieve this.
        # This ensures we never "forget" about user defined ranges,
        # even if they end up being defined on composite formulas
        # like s0 + s1.
        #
        # - If the variable is unbacked, only substitute if the substitution
        # would preserve the bounds also under size-like-ness conditions.
        if not issubset(tgt_bound, src_bound):
            self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
            return
        elif a in self.size_like:
            tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
            src_bound_so = self.bound_sympy(a, size_oblivious=True)
            if not issubset(tgt_bound_so, src_bound_so):
                self.log.debug("skipped set_replacement %s = %s (%s) "
                               "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
                return
    if isinstance(tgt, (sympy.Integer, sympy.Float)):
        # specializing to a constant, which is likely unexpected (unless
        # you specified dynamic=True)
        user_tb = TracingContext.extract_stack()
        trace_structured(
            "symbolic_shape_specialization",
            metadata_fn=lambda: {
                "symbol": repr(a),
                "sources": [s.name() for s in self.var_to_sources.get(a, [])],
                "value": repr(tgt),
                "reason": msg,
                "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
                "user_stack": structured.from_traceback(user_tb) if user_tb else None,
            }
        )
        if config.print_specializations:
            self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt)
            self.log.debug("SPECIALIZATION", stack_info=True)
    # tgt_bound is None here when `a` had no entry in var_to_range.
    log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
    self.replacements[a] = tgt
    self._update_version_counter()
    # When specializing 'a == tgt', the equality should be also conveyed to
    # Z3, in case an expression uses 'a'.
    self._add_target_expr(sympy.Eq(a, tgt))
def _add_divisible(self, expr: "sympy.Expr"):
    """Record ``expr`` in ``self.divisible`` and bump the version counter."""
    self.divisible.update((expr,))
    self._update_version_counter()
@_lru_cache
@record_shapeenv_event()
def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
    """
    Resolve ``a`` to the expression that currently represents it, in the
    manner of a DSU "find".

    Transitive, non-identity replacements are followed as well, e.g. with

        a: b + c
        c: d

    the entry for ``a`` is rewritten in terms of ``b`` and ``d``.  A symbol
    with no recorded replacement is returned unchanged.
    """
    if a not in self.replacements:
        return a

    target = self.replacements[a]
    resolved = {}
    for sym in target.free_symbols:
        resolved[sym] = self._find(sym)

    # The entry is looked up again here rather than reusing ``target``.
    rewritten, changed = self.replacements[a]._xreplace(resolved)
    if changed:
        self._set_replacement(a, rewritten, "find")
    return self.replacements[a]
@lru_cache(256)
def _maybe_guard_rel(self, expr: "sympy.Rel") -> None:
    """
    The relational guard is guarded to be true. Use this information to
    simplify shapes (i.e. a == b or a % 5 == 0)

    Ranges are refined for every relation except ``Ne``; replacements and
    divisibility facts are only derived from equalities.
    """
    assert isinstance(expr, sympy.Rel)
    # A good example of what goes wrong if you don't do this is
    # python test/functorch/test_aotdispatch.py -k
    # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
    if isinstance(expr, sympy.Ne):
        return
    free = list(expr.free_symbols)
    assert len(free) > 0, f"The expression should not be static by this point: {expr}"
    # In case of really gnarly expression, we don't blow up
    if len(free) > 5:
        return
    # Prioritize unbacked symints for solving by ordering them last.
    # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
    # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
    # Prefer to simplify out symbols with ephemeral sources.
    def _smart_symbol_sort(x):
        has_only_ephemeral_sources = (
            x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
        )
        # Symbols without a hint (or with a falsy hint) sort as sys.maxsize.
        size = self.size_hint(x, allow_none=True) or sys.maxsize
        name = x.name
        # 1 puts ephemeral sourced symbols first when sorting in reverse
        return (1 if has_only_ephemeral_sources else 0, size, name)
    free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined]
    lhs = expr.lhs
    rhs = expr.rhs
    self._refine_ranges(expr)
    # The rest of this stuff is for equality only
    if not isinstance(expr, sympy.Eq):
        return
    if not expr.has(Mod):
        try:
            # Give up (via NotImplementedError) on any FloorDiv whose divisor is not 1.
            floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
            if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms):
                raise NotImplementedError
            # Never replace unbacked symbols with other unbacked symbols.
            # This is error prone because you can cause references to
            # unbacked symbols to time travel backwards. E.g.,
            #
            # u1 = x.item()
            # ... use of u1 ...
            # u2 = y.item()
            # u3 = z.item()
            # torch._check(u1 == u2 + u3)
            #
            # If you replace u1 with u2 + u3, then the use of u1 now
            # references u2 and u3 prior to them actually being bound at
            # runtime. It's pretty inconvenient to setup control
            # dependencies for substitutions, so ban it entirely.
            def trivial_solve(lhs, rhs):
                if isinstance(lhs, sympy.Symbol):
                    if free_unbacked_symbols(lhs) and not free_unbacked_symbols(rhs):
                        return True
                    if symbol_is_type(lhs, SymT.FLOAT):
                        return True
                    # TODO: Maybe trivial solutions for int should also be
                    # done?
                return False
            # short-circuit when no solving is needed
            if trivial_solve(lhs, rhs):
                self._set_replacement(lhs, self._find(rhs), "trivial_lhs")
            elif trivial_solve(rhs, lhs):
                self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
            else:
                # Solve for the highest-priority symbol according to the sort above.
                r = try_solve(expr, free[0], floordiv_inequality=False)
                if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])):
                    new_var = self._find(r[1])
                    # Only substitute when the solution has no unbacked symbols.
                    ok = len(free_unbacked_symbols(new_var)) == 0
                    if ok:
                        self._set_replacement(cast(sympy.Symbol, free[0]), new_var, "solve")
        except NotImplementedError:
            pass
    if expr.has(Mod):
        # Only one Mod atom of the expression is considered.
        mod_expr = next(iter(expr.atoms(Mod)))
        try:
            r = try_solve(expr, mod_expr, floordiv_inequality=False)
            if r is not None and r[1] == 0:
                self._add_divisible(mod_expr)
                # This is a little bit of extra logic to make things like
                # torch.empty(i0, q).view(c, -1, q) work out
                p, q = mod_expr.args
                if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2:
                    c, i0 = p.args
                    # Given Mod(c * i0, q) == 0
                    if (
                        isinstance(c, sympy.Number) and
                        isinstance(i0, sympy.Symbol) and
                        self.is_unbacked_symint(i0)
                    ):
                        # We have Mod(i0, q / c) == 0, which means we can
                        # rewrite i0 as (q / gcd(q, c)) * i1
                        d = q / sympy.gcd(q, c) # TODO: CleanDiv?
                        i1 = self.create_unbacked_symint().node.expr
                        # Propagate the value ranges. It doesn't really
                        # matter if we use truediv or floordiv, because we
                        # have established divisibility.
                        self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv(
                            self.var_to_range[i0], ValueRanges.wrap(d)
                        ))
                        # Propagate size-like-ness
                        if i0 in self.size_like:
                            self.size_like.add(i1)
                        self._set_replacement(i0, d * i1, "divisibility")
        except NotImplementedError:
            pass
    return
# See: Note - On 0/1 specialization
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
# as a sentinel sometimes.  Your sizevar isn't going to be
# anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges:
    """
    Return the default range ``[lower, sys.maxsize - 1]``, where ``lower``
    is 2 when ``self.specialize_zero_one`` is set and 0 otherwise.
    """
    if self.specialize_zero_one:
        smallest = 2
    else:
        smallest = 0
    largest = sys.maxsize - 1
    return ValueRanges(smallest, largest)
def _default_unspecified_value_range(self) -> ValueRanges:
    """Return the range ``[-sys.maxsize - 1, sys.maxsize]``."""
    lowest = -sys.maxsize - 1
    highest = sys.maxsize
    return ValueRanges(lowest, highest)
@_lru_cache
def _simplify_floor_div(self, expr):
    """
    Simplify ``expr`` after evaluating ``Mod(base, divisor) == 0`` for each
    of its FloorDiv atoms.

    The atoms are visited in reverse of the order ``expr.atoms(FloorDiv)``
    yields them; the result is ``self.simplify(expr)``.
    """
    # we expect floor_divs to be exact,
    # and thus add the guards for the exact floordivs,
    # even if tracing doesn't require them otherwise
    for floor_div in reversed(tuple(expr.atoms(FloorDiv))):
        numerator, denominator = floor_div.args
        remainder = Mod(numerator, denominator)
        # add necessary mod guards
        self.evaluate_expr(sympy.Eq(remainder, 0))
    return self.simplify(expr)
# We're about to add a guard/runtime assert, check if the ShapeEnv is frozen
# and if so issue a warning
def _check_frozen(self, expr, concrete_val):
    """
    Do nothing unless ``self.frozen`` is set; in that case count the
    ignored guard, emit a signpost event and log a warning for
    ``expr == concrete_val``.
    """
    if not self.frozen:
        return

    self.counter["ignored_backward_guard"] += 1
    payload = {
        **self.co_fields,
        "ignored_guard": f"{expr} == {concrete_val}",
        # no version = original state (this signpost is expected)
        # version 2 = dynamic backwards is eagerly compiled
        "version": 2,
    }
    signpost_event("dynamic", "evaluate_expr_frozen", payload)
    log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True)
def _get_stack_summary(self, is_debug: bool = False):
    """
    Collect location information for log and error messages.

    Args:
        is_debug: when true, also assemble the extended debug text (user
            stack, and either a C++ stack trace or a hint on how to enable
            one).

    Returns:
        A tuple ``(fsummary, maybe_user_loc, maybe_extra_debug)``:
        ``fsummary`` is a ``traceback.FrameSummary`` for the innermost frame
        whose file is not in ``uninteresting_files()`` (``None`` if there is
        none); ``maybe_user_loc`` is ``" at <frame>"`` for the last entry of
        the ``TracingContext`` stack, or ``""``; ``maybe_extra_debug`` is
        the extended debug text, or ``""``.
    """
    fsummary = None
    frame = inspect.currentframe()
    try:
        # Walk outwards past every frame that lives in an uninteresting file.
        while frame is not None and frame.f_code.co_filename in uninteresting_files():
            frame = frame.f_back
        if frame is not None:
            code = frame.f_code
            fsummary = traceback.FrameSummary(
                code.co_filename,
                frame.f_lineno,
                code.co_name,
            )
    finally:
        del frame

    # NB: this stack is truncated, but it's fine because the main
    # stack_info will give you the rest of the info you need
    user_tb = TracingContext.extract_stack()
    maybe_user_loc = " at " + format_frame(user_tb[-1]) if user_tb else ""

    maybe_extra_debug = ""
    if is_debug:
        if user_tb:
            maybe_extra_debug = (
                '\nUser Stack (most recent call last):\n' +
                ' (snipped, see stack below for prefix)\n' +
                ''.join(traceback.format_list(user_tb))
            )
        if config.extended_debug_cpp:
            cpp_stack = CapturedTraceback.extract(cpp=True)
            maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format())
        else:
            maybe_extra_debug += (
                "\nFor C++ stack trace, run with "
                "TORCHDYNAMO_EXTENDED_DEBUG_CPP=1"
            )
    return fsummary, maybe_user_loc, maybe_extra_debug
def _log_guard(self, prefix: str, g, forcing_spec: bool):
    """
    Log, at INFO level, that guard ``g`` was added.

    Args:
        prefix: leading tag of the log line.
        g: the guard; its ``str()`` form is what gets logged.
        forcing_spec: when true, ``" (forcing_spec)"`` is appended to the
            prefix.

    Nothing is computed when INFO logging is disabled.  Extended debug
    output is included when the guard's string form equals
    ``config.extended_debug_guard_added``.
    """
    if not self.log.isEnabledFor(logging.INFO):
        return

    guard_text = str(g)
    requested_guard = config.extended_debug_guard_added
    is_debug = requested_guard is not None and guard_text == requested_guard
    fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)

    if is_debug:
        maybe_more_info = ""
    else:
        maybe_more_info = (
            ", for more info run with "
            f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{guard_text}"'
        )

    label = f"{prefix} (forcing_spec)" if forcing_spec else prefix
    self.log.info(
        "%s %s [guard added]%s (%s)%s%s",
        label,
        guard_text,
        maybe_user_loc,
        format_frame(fsummary),
        maybe_more_info,
        maybe_extra_debug,
        stack_info=is_debug,
    )
@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
                  expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False):
    """
    Given an expression, evaluates it, adding guards if necessary

    Args:
        orig_expr: the expression to evaluate.
        hint: optional concrete value of ``orig_expr``.  When given it is
            sympified and used as the concrete value instead of calling
            ``self.size_hint``.
        fx_node: FX node corresponding to the expression; only used when
            translation validation is enabled.
        expect_rational: forwarded to ``self._maybe_evaluate_static``.
        size_oblivious: forwarded to ``self._maybe_evaluate_static``; also
            disables creation of the translation-validation FX node.
        forcing_spec: set on the recursive call that forces specialization
            of a symbol, to avoid infinite recursion.

    Returns:
        ``orig_expr`` itself if it is already a number, the statically known
        value if there is one, and otherwise the concrete value the
        expression was guarded (or runtime-asserted) on.

    Raises:
        GuardOnDataDependentSymNode: (built by
            ``self._make_data_dependent_error``) when the expression still
            depends on symbols without a value and no unsound fallback value
            is available.
    """
    # TODO: split conjunctions and evaluate them separately

    # Don't track this one
    @functools.lru_cache(None)
    def compute_concrete_val():
        if hint is None:
            return self.size_hint(orig_expr)
        else:
            return sympy.sympify(hint)

    # Check if:
    #   1. 'translation_validation' is set
    #   2. the corresponding 'fx_node' is not 'None'
    #   3. the guard should not be suppressed
    #
    # If all of the above check, we create an FX node representing the
    # actual expression to be guarded.
    node = None
    fresh = False
    if (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
        and not size_oblivious
    ):
        concrete_val = compute_concrete_val()
        if concrete_val is sympy.true:
            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
        elif concrete_val is sympy.false:
            neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
            node, fresh = self._create_fx_call_function(torch._assert, (neg,))
        else:
            eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
            node, fresh = self._create_fx_call_function(torch._assert, (eql,))

        assert node is not None
        # If this is a fresh node, we have to remember the event index that
        # corresponds to this assertion node.
        # Reason: so that, given an assertion node, we can replay the ShapeEnv
        # events until the point where this assertion node was freshly created.
        if fresh:
            self._add_fx_node_metadata(node)

    # After creating the FX node corresponding to orig_expr, we must make sure that
    # no error will be raised until the end of this function.
    #
    # Reason: the translation validation may become invalid otherwise.
    #
    # If an error is raised before the end of this function, we remove the FX node
    # inserted, and re-raise the error.
    guard = None

    try:
        if orig_expr.is_number:
            self.log.debug("eval %s [trivial]", orig_expr)
            if hint is not None:
                assert orig_expr == hint, f"{orig_expr} != {hint}"
            return orig_expr

        expr = orig_expr

        static_expr = self._maybe_evaluate_static(expr,
                                                  expect_rational=expect_rational,
                                                  size_oblivious=size_oblivious)
        if static_expr is not None:
            self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
            if hint is not None:
                assert static_expr == hint, f"{static_expr} != {hint}"
            return static_expr

        transmute_into_runtime_assert = False

        concrete_val = None
        if not (expr.free_symbols <= self.var_to_val.keys()):
            # TODO: dedupe this with _maybe_evaluate_static
            # Attempt to eliminate the unbacked SymInt
            new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
            if new_expr is None:
                # _maybe_evaluate_static gives up with None (e.g. after a
                # sympy RecursionError).  Fall back to the unsimplified
                # expression so that we take the regular data-dependent
                # path below instead of failing on `None.free_symbols`.
                new_expr = expr
            if not (new_expr.free_symbols <= self.var_to_val.keys()):
                size_oblivious_result = None
                if not size_oblivious:
                    size_oblivious_result = self._maybe_evaluate_static(
                        expr,
                        expect_rational=expect_rational,
                        size_oblivious=True
                    )

                # Last ditch
                if (
                    self.unbacked_var_to_val and
                    not (unsound_result := orig_expr.xreplace(self.unbacked_var_to_val)).free_symbols
                ):
                    log.warning("propagate_real_tensors evaluate_expr(%s) -> %s", orig_expr, unsound_result)
                    trace_structured(
                        "propagate_real_tensors",
                        metadata_fn=lambda: {
                            "expr": repr(orig_expr),
                            "result": repr(unsound_result),
                            "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()),
                        },
                    )
                    transmute_into_runtime_assert = True
                    concrete_val = unsound_result
                else:
                    raise self._make_data_dependent_error(
                        expr.xreplace(self.var_to_val),
                        expr,
                        size_oblivious_result=size_oblivious_result
                    )
            else:
                expr = new_expr

        if concrete_val is None:
            concrete_val = compute_concrete_val()
        self._check_frozen(expr, concrete_val)

        if (
            config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
            and isinstance(hint, bool)
            and isinstance(expr, (sympy.Eq, sympy.Ne))
        ):
            expr = sympy.Not(expr)

        # Turn this into a boolean expression, no longer need to consult
        # concrete_val
        if concrete_val is sympy.true:
            g = expr
        elif concrete_val is sympy.false:
            g = sympy.Not(expr)
        else:
            g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]

        if transmute_into_runtime_assert:
            self.defer_runtime_assert(
                g,
                f"propagate_real_tensors: {orig_expr} == {unsound_result}"
            )
            return concrete_val

        if not self._suppress_guards_tls():
            if isinstance(g, sympy.Rel):
                # TODO: If we successfully eliminate a symbol via equality, it
                # is not actually necessary to save a guard for the equality,
                # as we will implicitly generate a guard when we match that
                # input against the symbol.  Probably the easiest way to
                # implement this is to have maybe_guard_rel return a bool
                # saying if it "subsumed" the guard (and therefore the guard
                # is no longer necessary)
                self._maybe_guard_rel(g)

            if not self._allow_complex_guards_as_runtime_asserts:
                # at this point, we've evaluated the concrete expr value, and have
                # flipped/negated the guard if necessary. Now we know what to guard
                # or defer to runtime assert on.
                stack = CapturedTraceback.extract(skip=1)
                guard = ShapeGuard(g, stack)
                self.guards.append(guard)
            else:
                # it's fine to defer simple guards here without checking,
                # the _maybe_guard_rel() call above will set replacements if possible,
                # and so the result here will be statically known
                self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")

    except Exception:
        if fresh:
            self._remove_fx_node(node)
        raise
    else:
        if not self._suppress_guards_tls():
            if guard is not None:  # we might have deferred this to runtime assert
                self._log_guard("eval", g, forcing_spec=forcing_spec)

                for s in g.free_symbols:
                    self.symbol_guard_counter[s] += 1
                    # Forcing_spec to avoid infinite recursion
                    if (
                        not forcing_spec and
                        config.symbol_guard_limit_before_specialize is not None and
                        self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize
                    ):
                        # Force specialization
                        self.log.info(
                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
                            config.symbol_guard_limit_before_specialize,
                            s
                        )
                        self.evaluate_expr(s, forcing_spec=True)
        else:
            self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)

    return concrete_val
def cleanup(self):
    """
    Break reference cycles.

    This destroys the stacks.  If you really want to keep them, we
    just need some way to break references on code objects.
    """
    for shape_guard in self.guards:
        shape_guard.stack.cleanup()

    for symbol_stack in self.var_to_stack.values():
        symbol_stack.cleanup()

    # Flatten the per-key lists of deferred runtime asserts.
    pending_asserts = itertools.chain.from_iterable(self.deferred_runtime_asserts.values())
    for runtime_assert in pending_asserts:
        runtime_assert.stack.cleanup()
@record_shapeenv_event(save_tracked_fakes=True)
def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
    """Create an assert that is checked at runtime

    Args:
        orig_expr (sympy.Expr): Boolean expression to assert is true
        msg (str): Message to display on assertion failure
        fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding
            to the expression, if applicable

    Returns:
        The statically known value of the expression if there is one, the
        result of ``self.evaluate_expr`` if the assert was turned into a
        normal guard, and ``True`` otherwise.
    """
    expr = orig_expr

    # TODO: split conjunctions and evaluate them separately

    static_expr = self._maybe_evaluate_static(expr)
    if static_expr is not None:
        self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr)
        return static_expr

    # Attempt to eliminate the unbacked SymInt
    new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
    # _maybe_evaluate_static gives up with None (e.g. after a sympy
    # RecursionError); in that case we cannot prove that only backed symbols
    # remain, so fall through to the runtime assert path.
    if (
        not self.prefer_deferred_runtime_asserts_over_guards
        and new_expr is not None
        and new_expr.free_symbols <= self.var_to_val.keys()
    ):
        # Do a normal guard
        return self.evaluate_expr(new_expr, fx_node=fx_node)
    # NB: Don't use new_expr as expr; it could contain gunk like shape0
    # which we don't want to guard on

    # If you're here because of this assert, read Note [Backwards runtime asserts]
    # in torch/_inductor/graph.py
    assert not self.runtime_asserts_frozen, expr

    # OK, we're definitely doing a runtime assert now
    if (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
    ):
        node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
        assert node is not None
        if fresh:
            self._add_fx_node_metadata(node)

    self._check_frozen(expr, sympy.true)

    if not self._suppress_guards_tls():
        # eliminate symbols on equality tests / refine ranges
        if isinstance(expr, sympy.Rel):
            self._maybe_guard_rel(expr)

        # canonicalise to remove equations that are trivially equal
        orig_expr = expr
        expr = canonicalize_bool_expr(expr)
        stack = CapturedTraceback.extract(skip=1)
        ra = RuntimeAssert(expr, msg, stack)
        # TODO: Do this in a way that is less janky than int(s.name[1:])
        cands = sorted((s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), key=lambda s: int(s.name[1:]))
        # Is None when prefer_deferred_runtime_asserts_over_guards=True
        # and the guard in question has no unbacked SymInts in front
        ix = cands[-1] if cands else None
        self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
        self.num_deferred_runtime_asserts += 1
        self._update_version_counter()
        self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
    else:
        self._log_guard("runtime_assert [guard suppressed]", orig_expr, forcing_spec=False)

    return True
# Refines the ranges of the variables present in 'guard'.
#
# This function tries to refine the range of the variables inside
# 'guard' by reasoning about it. Specifically, when 'guard' is a
# 'sympy.Relational' operation.
#
# It does mainly 3 things:
# 1. Tries to isolate a variable in the left-hand side
# 2. Compute the value range of the right-hand side
# 3. Update the value range of the variable, if better
def _refine_ranges(self, expr: sympy.Expr) -> None:
    """
    Tighten ``self.var_to_range`` for the free symbols of ``expr``.

    ``expr`` is simplified first; each free symbol is then isolated with
    ``try_solve`` and, if that works, its range is narrowed using the value
    range of the other side of the relation.
    """
    expr = self.simplify(expr)
    for symbol in expr.free_symbols:
        assert isinstance(symbol, sympy.Symbol)
        if isinstance(self.var_to_val.get(symbol, None), SingletonInt):
            # Skip var_to_range logic for SingletonInt which is only used
            # for jagged layout NestedTensors today
            continue
        r = try_solve(expr, symbol)
        if r is None or not (symbol.is_integer and r[1].is_integer):
            # Range refinement only supports integer symbols for now.
            # There are lots of SymPy bugs when it comes to comparing
            # reals and integers, so we skip that for now.
            continue
        # r_expr is the solved relation, rhs the side opposite to `symbol`.
        r_expr, rhs = r
        vr = self.var_to_range[symbol]
        lower, upper = vr.lower, vr.upper
        rhs_vr = bound_sympy(rhs, self.var_to_range)
        # Let's suppose that we have a preexisting range for x [0, 100].
        # Now, we issue a guard x > y, where the range for y is [50, 150].
        # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen,
        # refining x to [51, 100], since x must be greater than y, but the lowest
        # y could be is 50.
        #
        # sympy.Eq may update both lower and upper bounds.
        # sympy.G{t,e} may update the lower bound, only.
        # sympy.L{t,e} may update the upper bound, only.
        if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)):
            # Strictly greater relations allow us to refine a bit more, since
            # x > y implies that the lower bound for x is: y + 1.
            lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
        if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)):
            # Symmetrically, x < y implies that the upper bound for x is: y - 1.
            upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))
        # Do nothing if the new value range is no better than what we already have.
        if vr == ValueRanges(lower, upper):
            continue
        # Updates the range and the guards corresponding to each bound of the symbol.
        self._update_var_to_range(symbol, ValueRanges(lower, upper))
        # Clears the cache, since this update can change the result.
        self._maybe_evaluate_static.cache_clear()
@lru_cache(maxsize=None)
@record_shapeenv_event()
def constrain_symbol_range(self, s: sympy.Symbol, compiler_min: int, compiler_max: int):
    """
    Constrain symbol ``s`` to ``[compiler_min, compiler_max]`` via
    ``self._update_var_to_range``, logging the resulting range when it
    differs from the previous one (an unknown range if ``s`` had none).
    """
    requested = ValueRanges(compiler_min, compiler_max)
    previous = self.var_to_range.get(s, ValueRanges.unknown())
    self._update_var_to_range(s, requested)
    current = self.var_to_range[s]
    if current != previous:
        log.info("constrain_symbol_range %s [%s, %s]", s, current.lower, current.upper)
def _is_int(expr):
    """Tell whether ``expr`` is a ``SymInt`` whose underlying expression is a number."""
    if not isinstance(expr, SymInt):
        return False
    return expr.node.expr.is_number
# WARNING: This is legacy, DO NOT USE
def _is_dim_dynamic(t, d):
    """Tell whether ``t`` has a ``_dynamo_dynamic_indices`` attribute containing ``d``."""
    if not hasattr(t, "_dynamo_dynamic_indices"):
        return False
    return d in t._dynamo_dynamic_indices
class PropagateUnbackedSymInts(torch.fx.Interpreter):
    """Interpreter that calls ``rebind_unbacked`` on the result of every node it runs."""

    def run_node(self, n: torch.fx.Node):
        """
        Run an FX node, propagating unbacked Symbol bindings to the new fake tensor
        """
        from torch._guards import detect_fake_mode

        outcome = super().run_node(n)
        shape_env = detect_fake_mode().shape_env
        rebind_unbacked(shape_env, n, outcome)
        return outcome