# Source code for torch.fx.experimental.symbolic_shapes
# mypy: ignore-errors

"""
``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
our symbolic shapes reasoning system that is used heavily in torch.compile.  Although
this is not generally considered public API, when writing framework code in PyTorch
as well as extensions to PyTorch (e.g., in custom operator implementations), you may
need to make use of these APIs to set up dynamic shapes support appropriately.
"""

import builtins
import collections
import functools
import inspect
import itertools
import logging
import math
import operator
import re
import sys
import threading
import traceback
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from typing import (
    Any,
    cast,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    TYPE_CHECKING
)
from typing_extensions import TypeAlias

import torch
import torch.fx
import torch.fx.traceback as fx_traceback
from torch.fx.experimental import _config as config

from torch.fx.experimental.recording import (
    FakeTensorMeta,
    ShapeEnvEvent,
    record_shapeenv_event,
    replay_shape_env_events,
    shape_env_check_state_equal
)
from torch.fx.experimental.sym_node import SymNode, SymTypes

# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import SymBool, SymFloat, SymInt
from torch._guards import ShapeGuard, Source, TracingContext
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event
from torch._subclasses.meta_utils import is_sparse_any
from torch._logging import LazyString

if TYPE_CHECKING:
    # Only needed for string annotations; avoided at runtime.
    from torch._dynamo.source import TensorPropertySource

# Plain aliases of List, used in annotations to signal intent
# (a list of inputs / a list with one entry per dimension).
InputList = List
DimList = List

log = logging.getLogger(__name__)


class GuardOnDataDependentSymNode(RuntimeError):
    # NOTE(review): by its name, this is presumably raised when a guard is
    # attempted on a data-dependent SymNode; it is raised outside this chunk.
    pass


import sympy
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE

aten = torch._ops.ops.aten  # type: ignore[has-type]

__all__ = [
    "has_symbolic_sizes_strides",
    "create_contiguous",
    "ShapeEnv",
    "is_concrete_int",
    "guard_int",
    "guard_float",
    "guard_scalar",
    "canonicalize_bool_expr",
    "hint_int",
    "SYMPY_INTERP",
    "free_symbols",
    "is_symbol_binding_fx_node",
    "is_concrete_bool",
    "is_nested_int",
    "SHAPEENV_EVENT_KEY",
    "CURRENT_NODE_KEY",
    "has_free_symbols",
    "sym_eq",
    "SymbolicContext",
    "StatelessSymbolicContext",
    "StatefulSymbolicContext",
    "SubclassSymbolicContext",
    "statically_known_true",
    "guard_size_oblivious",
]

# FX node metadata keys for symbolic shape FX graph.
SHAPEENV_EVENT_KEY = "shapeenv_event"
CURRENT_NODE_KEY = "current_node"

# These are modules that contain generic code for interacting with ShapeEnv
# which are unlikely to identify a particular interesting guard statement
@lru_cache(None)
def uninteresting_files() -> Set[str]:
    """Return the source file paths of modules with generic ShapeEnv plumbing.

    A stack frame inside one of these files is unlikely to identify the user
    code responsible for a particular guard. The result is memoized.
    """
    import torch._inductor.sizevars
    import torch._library.abstract_impl
    import torch._subclasses.meta_utils
    import torch._subclasses.fake_tensor
    mods = [
        sys.modules[__name__],
        torch.fx.experimental.recording,
        torch.fx.experimental.sym_node,
        torch.fx.interpreter,
        torch,
        torch._inductor.sizevars,
        torch._library.abstract_impl,
        torch._subclasses.meta_utils,
        torch._subclasses.fake_tensor,
    ]
    return {inspect.getfile(m) for m in mods}

# We don't bother with the metaclass as all of the dispatching logic happens
# entirely from Python
#
# Didn't bother with ancestors for now, unlikely to have multiple modes for
# symints right now


class ConstraintViolationError(RuntimeError):
    pass


def has_symbolic_sizes_strides(elem) -> bool:
    """Return ``elem._has_symbolic_sizes_strides``."""
    return elem._has_symbolic_sizes_strides


# An integer that may be symbolic.
Int = Union[torch.SymInt, int]


def create_contiguous(shape: Sequence[Int]) -> List[Int]:
    """Return the contiguous (row-major) strides for ``shape``.

    The innermost dimension has stride 1. Each outer dimension's stride is the
    product of all sizes to its right, e.g. ``[2, 3, 4] -> [12, 4, 1]``.
    Sizes may be symbolic; the products are then symbolic too.

    Args:
        shape: sizes of each dimension, outermost first.

    Returns:
        A list with one stride per dimension (empty for a 0-d shape).
    """
    # A 0-d shape has no dimensions and therefore no strides.
    if not shape:
        return []
    # Walk outward from the innermost dim. The stride of dim i is
    # shape[i + 1] * stride[i + 1], so we multiply by the *inner* sizes,
    # shape[1:]. (The previous shape[:-1] used the wrong neighbour, giving
    # [6, 3, 1] for [2, 3, 4].)
    strides: List[Int] = [1]
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
    """
    Return the concrete hint for ``a``.

    For a ``SymInt`` the hint is the underlying real value observed at runtime.
    When no hint is available (e.g. for data-dependent shapes), ``fallback`` is
    used if it is not None; otherwise an error is raised. A plain ``int`` is
    returned unchanged.
    """
    if not isinstance(a, torch.SymInt):
        assert type(a) is int, a
        return a
    return a.node.require_hint(fallback)
def is_concrete_int(a: Union[int, SymInt]) -> bool:
    r"""
    Check whether ``a`` holds a concrete integer value.

    A plain ``int`` is always concrete. A ``SymInt`` is concrete when its
    underlying sympy expression is an integer literal.

    Args:
        a (SymInt or int): Object to test
    """
    assert isinstance(a, (SymInt, int))
    return isinstance(a, int) or isinstance(a.node.expr, sympy.core.numbers.Integer)
# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.# So make sure only type checker evaluates this alias.# Xref: https://www.internalfb.com/diff/D53324783SympyBoolean:TypeAlias="sympy.logic.boolalg.Boolean"
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
    """
    Guard on a symbolic boolean expression in a size-oblivious way.

    Use this when an ordinary guard would land on a data-dependent value whose
    concrete value is unknown at compile time. Guarding this way may diverge
    from how regular PyTorch semantics would evaluate the expression. For more
    information, see https://github.com/pytorch/pytorch/pull/118579

    Plain Python bools are returned unchanged.
    """
    if not isinstance(expr, torch.SymBool):
        assert isinstance(expr, bool)
        return expr
    return expr.node.guard_size_oblivious("", 0)
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
    r"""
    Canonicalize a boolean expression.

    Each relation is rewritten as ``Lt``/``Le`` (``Eq``/``Ne`` keep their type).
    All terms are first moved to the left-hand side, and then the constant
    terms are split off onto the right-hand side. ``And``/``Or``/``Not`` are
    first converted to CNF, and their sub-expressions are canonicalized
    recursively. Any other expression is returned unchanged.

    nb. sympy.Rel.canonical is not good enough
    https://github.com/sympy/sympy/issues/25924

    Args:
        expr (sympy.Expr): Expression to canonicalize
    """
    supported = (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)
    if not isinstance(expr, supported):
        return expr

    is_compound = isinstance(expr, (sympy.And, sympy.Or, sympy.Not))
    normalized = sympy.logic.boolalg.to_cnf(expr) if is_compound else expr
    return _canonicalize_bool_expr_impl(normalized)
def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
    """
    Canonicalize an ``And``/``Or`` (recursively) or a single relation.

    After canonicalization, no Ge/Gt relations remain: they are rewritten as
    Le/Lt respectively, with the operands swapped.
    """
    if isinstance(expr, (sympy.And, sympy.Or)):
        return type(expr)(*map(canonicalize_bool_expr, expr.args))

    flip = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
    if isinstance(expr, tuple(flip.keys())):
        # a > b  <=>  b - a < 0   (likewise a >= b  <=>  b - a <= 0)
        rel = flip[type(expr)]
        diff = expr.rhs - expr.lhs
    else:
        assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
        rel = type(expr)
        diff = expr.lhs - expr.rhs

    if not isinstance(diff, sympy.Add):
        return rel(diff, 0)

    # Keep symbolic terms on the left and move numeric terms to the right.
    numeric_terms = [term for term in diff.args if term.is_number]
    symbolic_terms = [term for term in diff.args if not term.is_number]
    return rel(sympy.Add(*symbolic_terms), -sympy.Add(*numeric_terms))
def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
    r"""
    Check whether ``a`` holds a concrete boolean value.

    A plain ``bool`` is always concrete. A ``SymBool`` is concrete when its
    underlying sympy expression is the literal ``true`` or ``false``.

    Args:
        a (SymBool or bool): Object to test
    """
    assert isinstance(a, (SymBool, bool))
    if isinstance(a, bool):
        return True
    literal_types = (
        sympy.logic.boolalg.BooleanTrue,
        sympy.logic.boolalg.BooleanFalse,
    )
    return isinstance(a.node.expr, literal_types)
def is_nested_int(s):
    """Return whether ``s`` is a SymInt whose node reports being a nested int."""
    if not isinstance(s, torch.SymInt):
        return False
    return s.node.is_nested_int()


def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
    """
    Yield every sympy expression reachable from ``val``.

    Symbolic scalars (``SymTypes``) yield their expression only when symbolic.
    Raw sympy objects are yielded as is. Sparse tensors contribute their sizes;
    other tensors contribute sizes, strides and storage offset. Tuples and lists
    are walked recursively. Python scalars and None yield nothing. Anything
    else raises AssertionError.
    """
    if isinstance(val, SymTypes):
        # This also applies to the jagged layout NestedTensor case, as
        # nested ints are not symbolic.
        if is_symbolic(val):
            yield val.node.expr
        return
    if isinstance(val, sympy.Basic):
        yield val
        return
    if isinstance(val, (int, float, bool)):
        return
    if is_sparse_any(val):
        yield from _iterate_exprs(val.size())
        return
    if isinstance(val, torch.Tensor):
        yield from _iterate_exprs(val.size())
        yield from _iterate_exprs(val.stride())
        yield from _iterate_exprs(val.storage_offset())
        return
    if isinstance(val, (tuple, list)):
        for item in val:
            yield from _iterate_exprs(item)
        return
    if val is None:
        return
    raise AssertionError(f"cannot extract sympy expressions from {val}{type(val)}")


def free_symbols(val: Union[SymInt, torch.Tensor]) -> Set[sympy.Symbol]:
    """Return the set of free sympy symbols appearing anywhere in ``val``."""
    if val is None:
        return set()
    exprs = _iterate_exprs(val)
    # set.union needs a receiver: take the first expression's free symbols and
    # union in the rest of the (same) generator. No expressions -> empty set.
    for first in exprs:
        return first.free_symbols.union(*(e.free_symbols for e in exprs))
    return set()
def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
    """Faster version of bool(free_symbols(val))

    Stops at the first reachable expression that is not a plain number.
    """
    return any(not expr.is_number for expr in _iterate_exprs(val))
# Like free_symbols, but filtered to only report unbacked symbols
def free_unbacked_symbols(x):
    """Return the free symbols of ``x`` whose names start with ``u`` or ``f``."""
    # NB: keep synced with is_unbacked_symint
    unbacked_prefixes = ("u", "f")
    return {sym for sym in free_symbols(x) if sym.name.startswith(unbacked_prefixes)}


# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
    """
    If ``node`` is a placeholder whose ``meta["val"]`` is a SymInt backed by a
    bare sympy Symbol, return that Symbol; otherwise return None.
    """
    if node.op != "placeholder" or "val" not in node.meta:
        return None
    val = node.meta["val"]
    if isinstance(val, torch.SymInt) and isinstance(val.node.expr, sympy.Symbol):
        return val.node.expr
    return None


def find_symbol_binding_fx_nodes(graph):
    """Map each Symbol bound by a placeholder of ``graph`` to that placeholder node."""
    bindings = {}
    for node in graph.nodes:
        if is_symbol_binding_fx_node(node):
            bindings[node.meta["val"].node.expr] = node
    return bindings
def definitely_true(a):
    """
    Return True only if we can tell that ``a`` is True, possibly introducing
    a guard in the process.

    If ``a`` has no hint (e.g. it depends on some unbacked SymInt), this
    returns False, even though some value of that SymInt might make the
    expression True.

    When is it appropriate to use definitely_true? First, if you can use a
    higher level combinator like parallel_or/parallel_and, prefer those
    instead; they are definitely safe (modulo short-circuiting). Second, it can
    be used if the program would behave equivalently if definitely_true always
    returned False (parallel_or/parallel_and are examples of this pattern,
    modulo short-circuiting). Finally, it can even be OK if the program wouldn't
    behave equivalently, so long as the change is semantics preserving. It can
    be semantics preserving if the program errors in more cases than it did
    previously (but otherwise behaves identically), or if it changes some
    quantity in a way that doesn't matter (e.g., strides often fall in this
    bucket.)
    """
    if not isinstance(a, SymBool):
        return bool(a)
    return guard_bool(a) if a.node.has_hint() else False
def definitely_false(a):
    """
    Return True only if we can tell that ``a`` is False, possibly introducing
    a guard in the process.

    If ``a`` has no hint (e.g. it depends on some unbacked SymInt), this
    returns False, even though some value of that SymInt might make the
    expression False.

    See definitely_true for more usage guidance.
    """
    if not isinstance(a, SymBool):
        return not bool(a)
    if not a.node.has_hint():
        return False
    return not guard_bool(a)
def statically_known_true(x: Union[bool, SymBool]) -> bool:
    """Return True if ``x`` can be simplified to a constant and that constant is true.

    .. note::
        This function doesn't introduce new guards, so the expression may end
        up evaluating to true at runtime even if this function returns False.

    Args:
        x (bool, SymBool): The expression to try statically evaluating
    """
    if not isinstance(x, SymBool):
        assert isinstance(x, bool)
        return x

    expr = x.node.expr
    shape_env = x.node.shape_env
    try:
        simplified = shape_env._maybe_evaluate_static(expr)
        if simplified is not None:
            return bool(simplified)
    except Exception:
        # Best effort: failing to simplify just means "not known to be true".
        log.debug("Could not simplify %s", expr)
    return False
def parallel_or(*args):
    """
    Evaluate the logical OR of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely True.

    The checks run from cheapest to most committal: statically known truth
    first, then definitely_true, and finally a plain ``any``.
    """
    for probe in (statically_known_true, definitely_true):
        if any(probe(arg) for arg in args):
            return True
    return any(args)
def parallel_and(*args):
    """
    Evaluate the logical AND of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely False.
    """
    known_false = (
        any(statically_known_true(torch.sym_not(arg)) for arg in args)
        or any(definitely_false(arg) for arg in args)
    )
    if known_false:
        return False
    return all(args)
def sym_eq(x, y):
    """
    Like ==, but lists/tuples are compared element-wise and recursively.

    The element results are combined with ``&`` (``operator.and_``) rather
    than short-circuiting, so no guard is forced on symbolic results.
    Sequences of different lengths compare unequal. Mixing a list with a tuple,
    or passing anything other than ints/SymInts at the leaves, raises
    AssertionError.
    """
    both_tuples = isinstance(x, tuple) and isinstance(y, tuple)
    both_lists = isinstance(x, list) and isinstance(y, list)
    if both_tuples or both_lists:
        if len(x) != len(y):
            return False
        return functools.reduce(operator.and_, map(sym_eq, x, y), True)

    scalar_types = (int, torch.SymInt)
    if isinstance(x, scalar_types) and isinstance(y, scalar_types):
        return x == y
    raise AssertionError(f"unexpected sym_eq between {type(x)}{type(y)}")
def guard_scalar(a):
    """Guard on a (possibly symbolic) bool/int/float and return its concrete value."""
    # bool is checked before int because bool is a subclass of int.
    if isinstance(a, (SymBool, bool)):
        return guard_bool(a)
    elif isinstance(a, (SymInt, int)):
        return guard_int(a)
    elif isinstance(a, (SymFloat, float)):
        return guard_float(a)
    else:
        raise AssertionError(f"unrecognized scalar {a}")


@record_shapeenv_event()
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
    """
    Intersect the range recorded for ``s`` in ``shape_env.var_to_range`` with
    ``[compiler_min, compiler_max]``. The change is logged when the range
    actually narrows.
    """
    upd_vr = ValueRanges(compiler_min, compiler_max)
    old_vr = shape_env.var_to_range.get(s, ValueRanges.unknown())
    new_vr = shape_env.var_to_range[s] = old_vr & upd_vr
    if new_vr != old_vr:
        log.info("_constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper)


def _advise_is_size(a):
    """
    Don't use this directly; use torch._check_is_size instead.

    This is a softer version of _constrain_range_for_size (with min=0,
    max=Inf).  Instead of forcibly constraining a variable (and erroring if we
    failed to constrain it), it will simply advise us that a size is
    constrained in some way.  We will always defer a runtime assert for this
    constraint if we cannot prove it at compile-time, but we only
    *sometimes* learn useful extra information at compile-time with this
    information.  This is in contrast to constrain_range_for_size, where if
    you don't call that on a fresh unbacked symint, chances are we will choke.

    TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
    code.  Right now this is only really used in code with AOTAutograd trace
    through, so it is not a big problem that this isn't supported, but in
    principle all of this code should be Dynamo'able too.

    TODO: I didn't support min/max because I didn't have a use case where this
    actually helped.  In principle we can support it, it just makes the
    implementation below more complicated.
    """

    # This must always succeed, because the sole allowed caller _check_is_size
    # was responsible for expect_true'ing this
    assert a >= 0

    # NB: it's important not to constrain range for size for *hinted* SymInts,
    # because it is not only unsound, it will immediately trip our asserts
    # that hints have to be consistent with static analysis!  If you somehow
    # have an unbounded SymInt that later constrains to 1, this will be
    # inconsistent with the range
    if (
        isinstance(a, SymInt)
        and isinstance(a.node, SymNode)
        and not a.node.has_hint()
        and isinstance(a.node.expr, sympy.Symbol)
    ):
        _constrain_range_for_size(a)


@record_shapeenv_event()
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
    """
    This function is NOT INTENDED to be used by itself.

    Constrain the symbol behind SymInt ``a`` to ``[min, max]`` (defaulting to
    ``[0, oo]``) and record the symbol in its ShapeEnv's ``size_like`` set.

    Raises:
        ValueError: if ``a`` is a SymFloat/SymBool, or if ``max < min``.
    """

    if isinstance(a, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat/SymBool is nyi")

    assert isinstance(a, SymInt), "can only constrain range for SymInt"
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

    if min is None:
        min = 0
    if max is None:
        max = sympy.oo

    if max < min:
        # The second literal must be an f-string, or the message shows the
        # placeholders "{min}"/"{max}" instead of the offending values.
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    _constrain_symbol_range(
        a.node.shape_env,
        a.node.expr,
        compiler_min=min,
        compiler_max=max,
    )
    a.node.shape_env.size_like.add(a.node.expr)


# inclusive both ways
@record_shapeenv_event()
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
    """
    Applies a constraint that the passed in SymInt must lie between min-max
    inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning
    that it can be used on unbacked SymInts).  If min/max are None, we assume
    that the dimension is unbounded in that direction.  Repeated application
    of constrain_range intersects the ranges.  This is a fairly low level API
    that doesn't have a lot of safety guarantees (TODO: provide higher level
    APIs).

    Currently, we use this API in the following circumstance: when we allocate
    an unbacked SymInt, denoting an integer quantity which is data dependent,
    we ordinarily do not know anything about what values it may take.  This
    means that any sort of guard on it will immediately fail.  However, in
    many cases, we know something about the unbacked SymInt: for example, we
    know that nonzero(x).size(0) must be >= 0.  We use constrain_range to
    narrow the possible range, declaring that negative symbols are impossible.
    This lets us definitely answer True to queries like 'nnz >= 0', even if
    we don't know what the actual (hinted) value of 'nnz' is.  In fact, we
    actually use constrain_range to unsoundly discharge common guards: for an
    unbacked SymInt produced by nonzero, we will also assume that it is not
    equal to 0/1 (even though these are perfectly possible values at runtime),
    because we generally expect graphs that are valid for N=2 to also be valid
    for N=1.

    Raises:
        ValueError: if ``max < min``, or if a plain int ``a`` is out of range.
        ValueRangeError: if ``a`` is a SymInt whose expression is an integer
            literal outside the range.
    """
    if min is None:
        min = -sympy.oo
    if max is None:
        max = sympy.oo

    if max < min:
        # The second literal must be an f-string, or the message shows the
        # placeholders "{min}"/"{max}" instead of the offending values.
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    if isinstance(a, int):
        if not (min <= a <= max):
            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
        return

    if isinstance(a.node.expr, sympy.Integer):
        if not (min <= int(a.node.expr) <= max):
            raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]")
        return
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

    # TODO: Shouldn't we install a guard if the symbol is backed?  Or is the
    # semantics that this is an "unchecked" assert (but it this actually
    # something useful?  Might be better to restrict only for unbacked
    # SymInt).
    _constrain_symbol_range(
        a.node.shape_env,
        a.node.expr,
        compiler_min=min,
        compiler_max=max,
    )
@record_shapeenv_event()
def constrain_unify(a, b):
    """
    Given two SymInts, constrain them so that they must be equal.  NB:
    this will not work with SymInts that represent nontrivial expressions
    (yet!)

    Either argument may also be a plain int. Equality is recorded by writing
    an entry into the owning ShapeEnv's ``replacements`` mapping; when both
    arguments are plain ints, they are only asserted equal.
    """
    # TODO: this does not install a deferred runtime assert yet

    # TODO: Maybe dedupe this with _maybe_guard_rel?
    if not isinstance(a, SymInt):
        if not isinstance(b, SymInt):
            # Both concrete: nothing to record, they simply must agree.
            assert a == b
        else:
            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
            shape_env = b.node.shape_env
            # Replace b's symbol by the concrete value of a.
            shape_env.replacements[b.node.expr] = sympy.Integer(a)
    else:
        # TODO: Actually, we can support this as long as one of them is a symbol.
        # NB: We can't actually do "unification" as our operators are not
        # injective
        assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
        shape_env = a.node.shape_env
        if not isinstance(b, SymInt):
            # Replace a's symbol by the concrete value of b.
            shape_env.replacements[a.node.expr] = sympy.Integer(b)
        else:
            assert a.node.shape_env is b.node.shape_env
            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
            # NOTE(review): _find presumably resolves a's symbol through the
            # existing replacements, so b is mapped to a's current
            # representative — confirm in ShapeEnv._find.
            new_var = shape_env._find(a.node.expr)
            shape_env.replacements[b.node.expr] = new_var
# Assume that a boolean is true for the purposes of subsequent symbolic
# reasoning.  This will keep track of corresponding runtime checks to verify
# that the result is upheld: either as a regular guard, or as a special set
# of asserts which are triggered when an unbacked SymInt is allocated.
#
# DO NOT use this function for these cases:
#
#  - This is inappropriate for "branching" conditions (where both
#    true and false result in valid programs).  We will always assume
#    the condition evaluates true, and so it will never be possible
#    to trace the false condition when you use it.  For true branching
#    on unbacked SymInts, you must use torch.cond; if you incorrectly
#    use expect_true in this case, you will make the false branch
#    unreachable (as we will simply assume that only the true branch
#    is ever exercised).
#
#  - This is inappropriate for situations where you know some other system
#    invariant guarantees that this property holds, since you don't
#    really need to insert a runtime check in that case.  Use something
#    like constrain_range in that case.
#
# This API has a hitch.  To avoid having to reimplement error reporting
# capabilities, this function CAN return False.  The invariant is that
# the surrounding code must raise an error when this function returns
# False.  This is quite low level, so we recommend using other functions
# like check() which enforce this in a more intuitive way.
#
# By the way, this name is a nod to the __builtin_expect macro,
# which is used similarly (but unlike __builtin_expect, you MUST fail
# in the unlikely branch.)  (I think expect is a good name; in recent
# versions of C++, this is replaced with [[likely]], which is weaker
# and not accurate for this function!)
def expect_true(a, skip: int = 0):
    """Assume ``a`` holds (see the comment above); plain bools are returned as is.

    ``skip`` selects how many extra caller frames to step over when choosing
    the file/line recorded for the assumption.
    """
    if not isinstance(a, SymBool):
        assert type(a) is bool, a
        return a
    # TODO: check perf implications of this
    frame = inspect.currentframe()
    hops = skip + 1
    while hops > 0:  # steps at least one frame for any skip >= 0
        frame = frame.f_back
        hops -= 1
    return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno)


def guard_bool(a):
    """Guard on ``a`` and return its concrete bool value."""
    if not isinstance(a, SymBool):
        assert type(a) is bool, a
        return a
    return a.node.guard_bool("", 0)  # NB: uses Python backtrace


def guard_int(a):
    """Guard on ``a`` and return its concrete int value."""
    if not isinstance(a, SymInt):
        assert type(a) is int, a
        return a
    return a.node.guard_int("", 0)  # NB: uses Python backtrace


def guard_float(a):
    """Guard on ``a`` and return its concrete float value."""
    if not isinstance(a, SymFloat):
        assert isinstance(a, float), a
        return a
    return a.node.guard_float("", 0)  # NB: uses Python backtrace


# Given a GraphModule, return all the FakeTensors for all the placeholders
def fx_placeholder_vals(gm):
    placeholders = (node for node in gm.graph.nodes if node.op == "placeholder")
    return [node.meta['val'] for node in placeholders]


def fx_placeholder_targets(gm):
    placeholders = (node for node in gm.graph.nodes if node.op == "placeholder")
    return [node.target for node in placeholders]


# Given a GraphModule and arguments to run it with, evaluate that the guards
# for its associated ShapeEnv are satisfied by the passed arguments.  This
# WILL check for duck sizing.
def eval_guards(gm, *args, ignore_static=True):
    placeholder_vals = fx_placeholder_vals(gm)
    return gm.shape_env.evaluate_guards_for_args(placeholder_vals, args, ignore_static=ignore_static)


def bind_symbols(gm, *args):
    placeholder_vals = fx_placeholder_vals(gm)
    return gm.shape_env.bind_symbols(placeholder_vals, args)


def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges):
    """
    Assert that each endpoint of ``bound`` is Boolean, not finite, or exactly
    representable via rational arithmetic.

    The only exception is the rare case where the user calls ``sqrt(s0)``.
    sqrt is turned into sympy.Pow, so we just match on that (this matches more
    things, but still).
    """
    for endpoint in (bound.lower, bound.upper):
        assert (
            endpoint.is_rational
            or endpoint.is_Boolean
            or not endpoint.is_finite
            or expr.has(sympy.Pow)
        ), (bound, expr)
class DimDynamic(Enum):
    """
    Controls how to perform symbol allocation for a dimension.  It is always
    sound to default this to DYNAMIC, but the policies DUCK and STATIC can
    result in better trace-time and compile-time performance, as they reduce
    the number of allocated symbols and generally make your graph more static.

    NB: If we notice you've applied a constraint to the dimension, we will
    force it to DYNAMIC for simplicity.

    DimDynamic is controlled by a variety of higher level UX features.
    Currently:

    - In eager mode, the default policy is DUCK.
        - The default is changed to STATIC with assume_static_by_default.
        - An individual dim is marked DYNAMIC if you mark_dynamic_dim.
    - In export mode, the default policy is STATIC.
        - An individual dim is marked DYNAMIC if you mention it as dynamic_dim
          in the constraints kwarg.
    """
    # Treat the dimension symbolically
    DYNAMIC = 0
    # Treat the dimension symbolically, but if its hint matches another
    # dynamic dimension, unify the two symbols ("duck sizing")
    DUCK = 1
    # Treat the dimension statically based on its hint
    STATIC = 2
# NB: These constraints affect both clients and backends: given some# constraint C, the client must pass inputs that satisfy the constraint,# while a backend must not introduce guards BEYOND this constraint.# For clarity, we document the implications on both sides for both the client# and the backend.## NB: These constraints are on a *single* dimension. In principle, we could# also have multi-dimension constraints, but our guess is that this is not# actually useful and so we are not supporting it right now.## NB: Strict constraints are typically only suitable for export, as in eager# a backend like inductor may validly introduce extra, discretionary guards# to improve performance of code. A StrictMinMaxConstraint would be brittle# under future optimizations performed by inductor; we don't guarantee# eager code with StrictMinMaxConstraint will keep working in the future!@dataclass(frozen=True)classConstraint:warn_only:bool
@dataclass(frozen=True)
class StrictMinMaxConstraint(Constraint):
    """
    For clients: the size at this dimension must be within 'vr' (which
    specifies a lower and upper bound, inclusive-inclusive) AND it
    must be non-negative and should not be 0 or 1 (but see NB below).

    For backends: there must not be any guards on this dimension which
    are not implied by the given lower and upper bound.  Regardless of
    the lower bound, the backend can assume the size is non-negative
    and that it is not 0 or 1.

    An unbounded StrictMinMaxConstraint can be thought of as a strict version
    of "RelaxedUnspecConstraint".

    NB: Export will often unsoundly assume that a graph works for 0/1, even
    though at trace time we assumed size is not 0 or 1.  The idea is that
    if we produce a graph that works for a range of values, it will be OK
    for N=0/1 too.
    """
    vr: ValueRanges

    def render(self, source: Source):
        """Format the constraint as ``<lower> <= <source name> <= <upper>``."""
        # TODO: better printing for -oo and oo
        return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}"
@dataclass(frozen=True)
class RelaxedUnspecConstraint(Constraint):
    """
    For clients: no explicit constraint; constraint is whatever is implicitly
    inferred by guards from tracing.

    For backends: there must exist at least TWO possible values for the
    size at this dimension which satisfy the guards for this dimension.

    In other words, this constraint helps us distinguish between "we don't
    care if this dimension specializes or not" versus "this dimension must be
    unspecialized."  However, this constraint doesn't say very much about what
    specialization is permitted; for example, if we guard on a size being
    even, this would still be acceptable under an unspec constraint.  This
    makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler
    may add constraints to otherwise dynamic dimensions; we can't assert that
    there are NO guards as this is brittle because compilers should be able to
    add extra constraints.  If you want to assert that there are no guards,
    use StrictMinMaxConstraint with an unbounded ValueRanges.
    """
    def render(self, source: Source):
        """Format the constraint as ``RelaxedUnspecConstraint(<source name>)``."""
        return f"RelaxedUnspecConstraint({source.name()})"
# NB: None here indicates the client constraint is whatever is implicitly# inferred by guards from tracing, and that a backend can add whatever guards# it wants (including fully specializing the value).DimConstraint=Union[StrictMinMaxConstraint,RelaxedUnspecConstraint,None]
@dataclass(frozen=True)
class EqualityConstraint(Constraint):
    """
    Represent and decide various kinds of equality constraints between input sources.

    A "source pair" is a pair of input sources for dynamic dimensions that
    are specified equal. We represent `source_pairs` in a union-find forest
    so that we can efficiently check whether two such sources are transitively equal.

    A "derived equality" relates an input source to an expression over a root.
    The root can be another input source, corresponding to some dynamic dimension,
    or a phantom symbol that does not directly represent any dynamic dimension. We
    represent `derived_equalities` involving input sources in a transitively-closed map
    so that we can efficiently check whether an input source is transitively equal to
    a given expression over another input source.
    (NOTE: In contrast, it is easy to decide whether an input source is transitively equal
    to a given expression over a phantom symbol; such expressions are already in canonical
    form and so the problem reduces to symbolic expression equality.)
    """
    source_pairs: List[Tuple[Source, Source]]
    derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]]
    phantom_symbols: List[sympy.Symbol]

    def __post_init__(self):
        """Pre-processing to answer queries `is_equal` and `is_derived` below.

        Example: Suppose we are given:
          source_pairs [a = b, b = c]
          derived_equalities [d = c + 1, e = d - 1]
        We first construct a union find with source_pairs. Each pair adds an
        edge from the first root to the second:
          _parents = {a: b, b: c}   (so a, b and c all have root c)
        Then we compute canonical symbolic expressions, recursively applying derived_equalities
        until we bottom out:
          _defs = {d: c + 1, e: (c + 1) - 1 aka c}
        """
        # self._parents is a map from input sources to input sources where, conceptually,
        # these are directed edges in a union-find forest
        # (set via object.__setattr__ because the dataclass is frozen)
        _parents: Dict[Source, Source] = {}
        object.__setattr__(self, "_parents", _parents)
        # self._defs is a map from input sources to "canonical" symbolic expressions,
        # i.e., unary expressions with symbols that correspond to regular Dims (i.e.,
        # not derived Dims)
        _defs: Dict[Source, sympy.Expr] = {}
        object.__setattr__(self, "_defs", _defs)

        for source1, source2 in self.source_pairs:
            # preprocess into a union-find forest
            self._union(self._find(source1), self._find(source2))
        for source, root, fn in self.derived_equalities:
            # preprocess into a transitively-closed map
            # NOTE(avik): we reuse the union-find forest for canonicalizing input sources
            if isinstance(root, sympy.Symbol):
                self._defs[self._find(source)] = fn(root)
            else:
                self._defs[self._find(source)] = fn(self._rewrite(root))

    def _find(self, source):
        # chase edges to find the root of this equivalence class
        if source in self._parents:
            return self._find(self._parents[source])
        else:
            return source

    def _union(self, root1, root2):
        # merge two equivalence classes by adding an edge from one root to the other
        if root1 != root2:
            self._parents[root1] = root2

    def _rewrite(self, src):
        # always represent the given source by the root of its equivalence class
        src = self._find(src)
        if src in self._defs:
            # simply look up the definition if it exists
            # NOTE(avik): This works because definitions are always transitively-closed;
            # otherwise we would have to do recursive rewriting.
            return self._defs[src]
        else:
            # otherwise, create a symbol representing the source
            return sympy.Symbol(src.name())

    def is_equal(self, source1, source2):
        """Return whether two input sources are (transitively) equal."""
        return (
            # check whether source1 and source2 have the same root
            self._find(source1) == self._find(source2) or
            # check whether source1 is derived equal to source2
            self.is_derived(source1, source2, lambda x: x)
        )

    def is_derived(self, src, symbol_src, fn):
        """Return whether ``src`` canonically equals ``fn`` applied to ``symbol_src``."""
        # check whether both src and symbol_src have the same definition
        return self._rewrite(src) == fn(self._rewrite(symbol_src))
def _assert_symbol_context(symbolic_context):
    """Assert that ``symbolic_context`` is an instance of a concrete SymbolicContext subclass."""
    is_context = isinstance(symbolic_context, SymbolicContext)
    assert is_context, "Invalid symbolic_context object"
    is_bare_base = type(symbolic_context) is SymbolicContext
    assert not is_bare_base, "Illegal usage of symbolic_context ABC"
@dataclass(frozen=True)
class SymbolicContext:
    """
    Data structure specifying how we should create symbols in
    ``create_symbolic_sizes_strides_storage_offset``; e.g., should
    they be static or dynamic.

    This is an abstract base class because we are probably going to add
    another version of this that says "use exactly these SymInts, don't
    allocate fresh symbols."

    Instantiating this base class directly is rejected by
    ``_assert_symbol_context``; use one of its subclasses.
    """
    pass
@dataclass(frozen=True)
class StatelessSymbolicContext(SymbolicContext):
    """
    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` based on
    the per-dimension ``DimDynamic`` policy and ``DimConstraint``.

    Using this context causes fresh symbols to be allocated.
    """
    dynamic_sizes: DimList[DimDynamic]
    constraint_sizes: DimList[DimConstraint] = None
    # If the tensor is a view, this should be populated for the base. It contains
    # information on how to allocate symbols when recursively fakeifying the base
    # during view fake-ification.
    view_base_context: Optional[SymbolicContext] = None
    # TODO: add storage offset and stride symbolic_context

    def __post_init__(self):
        # An explicitly provided constraint list is kept as is.
        if self.constraint_sizes is not None:
            return
        # Otherwise default to "no constraint" (None) for every dimension.
        # The dataclass is frozen, hence object.__setattr__.
        unconstrained = [None] * len(self.dynamic_sizes)
        object.__setattr__(self, 'constraint_sizes', unconstrained)
# note [Tensor Fakification and Symbol Caching]
#
# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
# We do this because certain classes of operations, namely metadata mutations,
# change tensor size, stride, etc. This means that the fake tensor state at the end
# of a dynamo trace differs from the fake tensor state at the beginning of the trace.
# Backends like aot_autograd need a fresh fake tensor to correctly track metadata
# mutation, view relationships, etc.
#
# When we create a new fake mode, we also lose the memoization that comes with it.
# Rather than transfer the memoization cache, we transfer the shape env instead.
# This comes with nuance, because dynamo is selective in how it makes symbolic shapes:
# due to automatic dynamic and constraints, the policy for which dims are dynamic is
# nuanced and varies across recompilations.
#
# To preserve the symbolic decisions made during dynamo tensor fakification, we pass
# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on
# the TracingContext. Its lifecycle should match the lifecycle of the original
# dynamo-tracked tensor, and it is safe to reuse it as many times as necessary to
# create a fake tensor. Fake tensors created with new fake modes should produce
# exactly the same symbols as the original, provided the same shape_env is used.
# TODO(voz): Shape env validation
@dataclass(frozen=True)
class StatefulSymbolicContext(StatelessSymbolicContext):
    """
    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
    will reuse a stored symbol, and a cache miss will write to this cache.

    This behaves like StatelessSymbolicContext, except the cache supersedes the
    other values - dynamic_sizes and constraint_sizes will not be read if we cache
    hit.

    It is the cache owner's responsibility to maintain the lifecycle of the cache
    w/r/t different shape_envs, clearing, etc.
    """
    tensor_source: Source = None
    # Why is this keyed on int first?
    # That integer is actually the id of the shape_env. This cache short-circuits symbol
    # creation, and we must store it per shape env. Tracing invariants are a single
    # shape env per tracing context, and every new frame gets a new shape_env. So where
    # would we have multiple shape envs? The answer lies in recording. When we are
    # replaying, replay_shape_env_events is invoked, and creates a new shape_env.
    # Replaying events against this new shape_env will cause it to fail with unknown
    # symbols, as the symbols cached here will skip creation, and never get recorded
    # in var_to_val, etc.
    # TODO(voz): consider a weakref to the shape_env here
    shape_env_to_source_to_symbol_cache: Optional[
        Dict[int, Dict["TensorPropertySource", "sympy.Expr"]]
    ] = None

    def __post_init__(self):
        # Chain to the parent so that inherited defaults (constraint_sizes) are
        # filled in, exactly as SubclassSymbolicContext already does for us.
        super().__post_init__()
        # The None default is annoying, but required because of dataclass limitations
        assert self.tensor_source is not None
        # Only substitute a cache when none was supplied. An explicitly passed
        # (possibly still empty) dict is kept so its owner sees the symbols written
        # into it. The instance is frozen, hence object.__setattr__.
        if self.shape_env_to_source_to_symbol_cache is None:
            object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {})
@dataclass(frozen=True)
class SubclassSymbolicContext(StatefulSymbolicContext):
    """
    The correct symbolic context for a given inner tensor of a traceable tensor subclass
    may differ from that of the outer symbolic context. This structure allows for this
    flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
    """
    inner_contexts: Optional[Dict[str, SymbolicContext]] = None

    def __post_init__(self):
        super().__post_init__()
        if self.inner_contexts is None:
            # The dataclass is frozen: a plain ``self.inner_contexts = {}`` raises
            # dataclasses.FrozenInstanceError. Bypass the generated __setattr__,
            # the same way the parent contexts fill in their defaults.
            object.__setattr__(self, "inner_contexts", {})
def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
    """Return True if ``val`` is a Sym* value whose node is symbolic.

    Plain Python ``int``/``float``/``bool`` values are never symbolic.
    """
    if isinstance(val, (int, float, bool)):
        return False
    return val.node.is_symbolic()


IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)


@lru_cache(256)
def safe_expand(r):
    """Return ``sympy.expand(r)`` if ``r`` has an ``expand`` attribute, else ``r``.

    If sympy recurses too deeply while expanding, a warning is logged and ``r``
    is returned unexpanded.
    """
    if hasattr(r, 'expand'):
        try:
            return sympy.expand(r)
        except RecursionError:
            log.warning("RecursionError in sympy.expand(%s)", r)
            return r
    else:
        return r


def error():
    """Unconditionally raise AssertionError (for code paths that must not be hit)."""
    raise AssertionError("shouldn't be hit")


# TODO: Deduplicate this with torch/_prims_common/__init__.py
def eval_is_non_overlapping_and_dense(sizes, strides):
    """Guard on the non-overlapping-and-dense check and return it as 0/1."""
    return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides)))


def _eval_is_non_overlapping_and_dense(sizes, strides):
    """Return whether some permutation of the dims makes the layout contiguous."""
    dim = len(sizes)

    # Short-circuits for tensors of rank one, which are
    # non-overlapping and "dense" if their stride is one
    # or it is a 0/1 element tensor
    if dim == 1:
        return strides[0] == 1 or sizes[0] < 2

    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
    # Sorts (length, stride) pairs by stride
    lengths_and_strides = sorted(
        zip(sizes, strides), key=operator.itemgetter(1)
    )

    # Unlike the C++ code, we don't move the 0/1 size dimensions to the
    # end. So we have to keep going for this code.
    expected_stride = 1
    for length, stride in lengths_and_strides:

        if length == 1:
            continue

        if stride != expected_stride:
            return False

        expected_stride *= length

    return True


def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
    """Convert a SymBool to a SymInt (1 if true, 0 otherwise) without guarding.

    The result is a Piecewise expression over the bool's expression, hinted with
    the bool's required hint.
    """
    int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True))
    return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()))


# Mapping from sympy function/relation names to the Python callables that
# implement them.
SYMPY_INTERP = {
    'Abs': operator.abs,
    'Eq': operator.eq,
    'Ne': operator.ne,
    'Gt': operator.gt,
    'Lt': operator.lt,
    'Le': operator.le,
    'Ge': operator.ge,
    'Min': min,
    'Max': max,
    'Mod': operator.mod,
    'FloorDiv': operator.floordiv,
    'TrueDiv': operator.truediv,
    'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
    'floor': math.floor,
    'ceiling': math.ceil,
    'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
    'Round': builtins.round,
    'RoundDecimal': builtins.round,
}


def _lru_cache(fn, maxsize=None):
    """
    Wrapper around lru_cache that clears when new info about shapes has been
    updated.

    Use lru_cache if the output is always the same, regardless of the
    constraints we know now (i.e. evaluate_expr)

    Use _lru_cache otherwise.

    Also note that this depends on _update_version_counter being called on the
    shape environment whenever the constraints are updated, otherwise the cache
    will not be cleared.
    """
    fn_cache = lru_cache(maxsize)(fn)
    prior_version = 0

    if config.validate_shape_env_version_key:
        prior_key = None

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            nonlocal prior_version, prior_key
            if prior_key is None:
                prior_key = self._get_key()

            if prior_version != self._version_counter:
                fn_cache.cache_clear()
                prior_version = self._version_counter
                prior_key = self._get_key()
            else:
                # Debug check: the key must not change while the version is stable.
                assert prior_key == self._get_key(), \
                    "ShapeEnv cache key changed without version being updated!"

            return fn_cache(self, *args, **kwargs)

    else:
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            nonlocal prior_version
            if prior_version != self._version_counter:
                fn_cache.cache_clear()
                prior_version = self._version_counter

            return fn_cache(self, *args, **kwargs)

    wrapper.cache_clear = fn_cache.cache_clear
    wrapper.cache_info = fn_cache.cache_info  # type: ignore[attr-defined]
    return wrapper


# This is pretty similar to ShapeGuard but it also comes with a message,
# and is exclusively used for things that MUST be true (unlike guards,
# which can evaluate False, in which case you just choose not to use
# a particular specialization)
@dataclass(frozen=True)
class RuntimeAssert:
    expr: sympy.Expr
    msg: str = field(repr=False)
    stack: str = field(repr=False)


class ShapeGuardPrinter(StrPrinter):
    """StrPrinter that renders symbols via their first source (through ``source_ref``)
    and boolean connectives as Python ``not`` / ``and`` / ``or``.
    """
    def __init__(
        self,
        symbol_to_source,
        source_ref,
        var_to_sources,
    ):
        super().__init__()
        self.symbol_to_source = symbol_to_source
        self.source_ref = source_ref
        self.var_to_sources = var_to_sources

    def _print_Not(self, expr):
        return 'not %s' % (self.parenthesize(expr.args[0], PRECEDENCE["Not"]))

    def _print_And(self, expr):
        return self.stringify(expr.args, " and ", PRECEDENCE["And"])

    def _print_Or(self, expr):
        return self.stringify(expr.args, " or ", PRECEDENCE["Or"])

    def _print_Symbol(self, expr) -> str:
        assert isinstance(expr, sympy.Symbol), str(type(expr))

        def repr_symbol_to_source():
            return repr({
                symbol: [s.name() for s in sources]
                for symbol, sources in self.symbol_to_source.items()
            })

        # Use .get() when building the message: a symbol unknown to
        # var_to_sources as well would otherwise raise KeyError here and hide
        # the diagnostic assertion below.
        assert self.symbol_to_source.get(expr), (
            f"{expr} (could be from {[s.name() for s in self.var_to_sources.get(expr, [])]}) "
            f"not in {repr_symbol_to_source()}. If this assert is failing, it could be "
            "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
        )
        return self.source_ref(self.symbol_to_source[expr][0])


class LoggingShapeGuardPrinter(ShapeGuardPrinter):
    """ShapeGuardPrinter that prints each symbol as its source's ``name()``."""
    def __init__(self, var_to_sources):
        super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)


class DynamicDimConstraintPrinter(StrPrinter):
    """
    Printer for dynamic dim constraints.
    - Instead of t.size()[d] it prints dynamic_dim(t, d)
    - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc.

    We use this to suggest code for specifying dynamic dim constraints.
    """
    def __init__(self, symbol_to_source, source_name_to_debug_name):
        super().__init__()
        self.symbol_to_source = symbol_to_source
        self.source_name_to_debug_name = source_name_to_debug_name

    def print_source(self, source) -> str:
        if self.source_name_to_debug_name:
            return source.name()
        return f"dynamic_dim({source.base.name()}, {source.idx})"

    def _print_Symbol(self, expr) -> str:
        assert isinstance(expr, sympy.Symbol), str(type(expr))
        assert self.symbol_to_source.get(expr), (
            f"Unknown symbol {expr} created by constraints solver"
        )
        return self.print_source(self.symbol_to_source[expr][0])

    def _print_Relational(self, expr):
        # Operands and operator must be space-separated ("lhs <= rhs"): the
        # constraint post-processing parses these strings with patterns such as
        # r"( == | <= | >= | < | > )" and r"2 <= dynamic_dim(.+)".
        return '{} {} {}'.format(
            self.parenthesize(expr.lhs, precedence(expr)),
            expr.rel_op,
            self.parenthesize(expr.rhs, precedence(expr))
        )
class DimConstraints:
    """
    Custom solver for a system of constraints on symbolic dimensions.
    Solutions are "static" values or simplified "dynamic" constraints.

    The methods suggest this flow (the callers are outside this view):
    constraints are fed in with ``add`` / ``add_equality``; ``solve`` turns them
    into specializations (``_static_results``) and range / congruence / equality
    constraints (``_dynamic_results``), both stored as printed strings; then
    ``remove_redundant_dynamic_results`` and ``prettify_results`` post-process
    those strings into a user-facing message.
    """

    def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_debug_name):
        """
        Args:
            symbol_to_source: maps each symbol to a list of sources; entry ``[0]`` is
                used when printing. Shared with (and mutated through) the printer.
            var_to_val: maps each symbol to its concrete hint value.
            marked_dynamic: symbols the user marked dynamic (only used for
                membership tests in ``forced_specializations``).
            source_name_to_debug_name: maps source names to user-facing names; when
                truthy, printing and ``prettify_results`` switch to debug names.
        """
        # We try to solve systems of inequalities with 1 free variable.
        self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
        # Among them, we prioritize solving for a free variable that has equalities.
        # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
        # and removing a symbol from the former => removing it from the latter.
        self._symbols_with_equalities: Set[sympy.Symbol] = set()
        # A solution of a free variable with equalities becomes a substitution.
        # We use these substitutions to simplify other constraints.
        # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
        self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}

        # In general, constraints may have // and % operations.
        # Of course, // can be expressed in terms of / and %.
        # Our inequality solver can handle / but not %. So we need to transform them away.
        # We do so by using the values of variables as hints to evaluate %.
        # For soundness we record additional congruence guards and solve them separately.
        self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val
        # Maps each free symbol to the congruence expressions (each meant to be == 0)
        # recorded for it by rewrite_with_congruences.
        self._congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)

        # We do not try to (directly) solve inequalities with > 1 free variables.
        # NOTE: free variables in these inequalities cannot also be in _substitutions.
        self._multivariate_inequalities: Set[sympy.Expr] = set()

        # We park external equalities between free variables here.
        self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []

        # Solutions come in two forms:
        # - (static) specializations
        # - (dynamic) inequalities / congruences
        self._static_results: Set[str] = set()
        self._dynamic_results: Set[str] = set()

        # printer for solutions
        self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name)

        # inconsistencies found on substituting with concrete values / static solutions
        self._inconsistencies: List[str] = []

        # symbols that are marked dynamic
        self._marked_dynamic = marked_dynamic

    def rewrite_with_congruences(self, s, expr):
        """
        Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
        This leaves rational operators (in particular of the form b / d) that our inequality solver can handle.
        We solve the added congruences separately (using our congruence solver, see below).

        Args:
            s: the single free symbol of ``expr``; congruences are recorded under it.
            expr: the sympy expression to rewrite.

        Returns:
            ``expr`` with every Mod / FloorDiv replaced, using the hint values in
            ``self._var_to_val``. Side effect: non-trivial congruences are added to
            ``self._congruences[s]``.
        """
        def mod_handler(*args):
            # Suppose that we have an expression of the form b % d with free variable s.
            # Using the value of s as a "hint," we can evaluate b % d to a value k.
            # Then we can rewrite b % d to k while adding the guard b % d == k.

            # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF
            # the original expression always evaluates to a constant value (i.e., it does not vary with s).
            # In other words,
            # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with
            #   the original expression;
            # - while it may be possible to find solutions of s with the original expression that are not
            #   solutions with the rewritten expression, in that case the original expression cannot evaluate
            #   to the same value for all solutions of s.
            #
            # Should we be worried about this incompleteness? No, because of the following reasons:
            # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech
            #    (i.e., "don't let perfect be the enemy of the good").
            # 2. We already have a tradition of using hints to add guards in the compiler for making progress.
            # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards
            #    we generate (or simplify to) seem to be of the form b % d == k where k is a constant.
            #
            # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2.
            # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
            # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
            base, divisor = args
            # Rewrite nested Mod / FloorDiv inside the operands first.
            base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor)
            mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val)
            congruence = (base - mod_reduced) % divisor
            if congruence != 0:
                self._congruences[s].add(congruence)
            return mod_reduced

        def floor_div_handler(*args):
            # Suppose that we have an expression of the form b // d with free variable s.
            # Using the value of s, we can evaluate b % d to a value k.
            # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k.

            # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
            # and eliminating b % d as above.
            base, divisor = args
            base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor)
            mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val)
            congruence = (base - mod_reduced) % divisor
            if congruence != 0:
                self._congruences[s].add(congruence)
            return (base - mod_reduced) / divisor

        if expr.has(Mod):
            expr = expr.replace(Mod, mod_handler)
        if expr.has(FloorDiv):
            expr = expr.replace(FloorDiv, floor_div_handler)
        return expr

    def add(self, expr) -> bool:
        """Add an expression to the set of constraints.

        Return whether the expression is a trivial constraint (i.e., an obvious tautology).

        ``Ne`` constraints are dropped. Multivariate constraints are parked for
        later; univariate ones are rewritten with congruences and recorded per
        symbol. Constraints that evaluate to false under the hints are recorded as
        inconsistencies (raised later by ``_raise_inconsistencies``), not raised here.
        """
        if expr == sympy.true:
            return True
        orig_expr = expr
        orig_reduced = orig_expr.subs(self._var_to_val)
        # TODO(avik): https://github.com/pytorch/pytorch/issues/101093
        # It is possible that `expr` will fail the consistency check because of
        # precision errors. Specifically, on substituting its free symbols with
        # their concrete values, we might end up comparing floats. Until we have
        # a fix for this issue, we delay raising such failures. See solve().
        if orig_reduced == sympy.false:
            self._inconsistencies.append(f"{orig_expr} is inconsistent!")
        if isinstance(expr, sympy.Ne):
            # we're not going to do anything useful with these, so drop them
            return False
        free_symbols = expr.free_symbols
        assert free_symbols, f"Did not expect constraint with no free variables: {expr}"
        if len(free_symbols) > 1:
            # multivariate: record and move on
            self._multivariate_inequalities.add(expr)
        else:
            # univariate: can solve these immediately
            s = next(iter(free_symbols))
            # eliminate // and % (see documentation of `rewrite_with_congruences` above)
            old_n_congruences = len(self._congruences[s])
            expr = self.rewrite_with_congruences(s, expr)
            new_n_congruences = len(self._congruences[s])
            if expr == sympy.true:
                # Only trivial if the rewrite did not introduce new congruences.
                return old_n_congruences == new_n_congruences
            reduced = expr.subs(self._var_to_val)
            if reduced == sympy.false:
                self._inconsistencies.append(
                    f"{expr}, obtained by rewriting {orig_expr} with congruences, "
                    "is inconsistent!"
                )
            if isinstance(expr, sympy.Eq):
                # special status for symbols that have equalities (see `solve` below)
                self._symbols_with_equalities.add(s)
            self._univariate_inequalities[s].add(expr)
        return False

    def add_equality(self, source, expr):
        """Add an equality constraint ``source == expr``.

        A numeric ``expr`` is recorded immediately as a static result; otherwise
        the pair is parked in ``_symbolic_equivalences`` for ``solve``.
        """
        if expr.is_number:
            # specialization, right here
            self._static_results.add(f"{source.name()} == {expr}")
        else:
            # these will resolve to either specializations or dynamic equality constraints
            self._symbolic_equivalences.append((source, expr))

    def _reduce_congruences(self):
        """Return a dict mapping each symbol to a reduced set of congruences.

        Congruences whose linear solution has integer modulus and remainder are
        merged with ``sympy.ntheory.modular.solve_congruence`` into a single
        ``(s - r) % m``; the rest are kept unless they are implied by that merged
        solution (checked with ``sympy.checksol``).
        """
        reduced_congruences = {}
        for s, congruences in self._congruences.items():
            remainder_modulus_pairs = []
            congruences_to_check = set()
            for congruence in congruences:
                base, divisor = congruence.args
                # We are given a congruence of the form base % divisor == 0 with a free variable s. So:
                # - we transform this into an equation of the form base = divisor * tmp;
                # - we solve this equation for s to get a linear solution with free variable tmp.
                tmp = sympy.Symbol("tmp", integer=True)
                symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
                # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
                # for how to interpret the results.
                if s == symbol:
                    # This means the solution is of the form s = modulus*tmp + remainder.
                    modulus, remainder = sympy.polys.polytools.div(solution, tmp)
                    if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer):
                        # Normalize the remainder: for a positive modulus, % yields
                        # 0 <= remainder < modulus.
                        remainder = remainder % modulus
                        remainder_modulus_pairs.append((remainder, modulus))
                        continue
                # This means that we did not get a unique solution to the equation.
                # No problem, we will check it.
                congruences_to_check.add(congruence)
            # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i).
            # The solution will be a congruence of the form s = r mod m.
            # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT.
            if remainder_modulus_pairs:
                remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs)
                reduced_congruences[s] = {(s - remainder) % modulus}
                substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder}
                reduced_congruences[s].update(
                    congruence for congruence in congruences_to_check
                    if not sympy.checksol(congruence, substitution)
                )
            else:
                reduced_congruences[s] = congruences_to_check

        return reduced_congruences

    def _raise_inconsistencies(self):
        """Raise ValueError listing (and clearing) any recorded inconsistencies."""
        if self._inconsistencies:
            msg = "\n".join(self._inconsistencies)
            self._inconsistencies.clear()
            raise ValueError(f"The following inconsistencies were found:\n{msg}")

    def _force_specialization(self, s):
        """Specialize symbol ``s`` to its hint value (static result + substitution)."""
        val = self._var_to_val[s]
        self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
        self._substitutions[s] = val

    def _specialize_divisor_symbols(self):
        """Specialize every symbol that occurs in the divisor of a FloorDiv / Mod
        inside a multivariate inequality, then re-add the simplified inequalities
        and drop the specialized symbols from the univariate / congruence tables.
        """
        for expr in self._multivariate_inequalities:
            for atom in expr.atoms(FloorDiv, Mod):
                _, divisor = atom.args
                for s in divisor.free_symbols:
                    self._force_specialization(s)

        multivariate_inequalities = self._multivariate_inequalities
        self._multivariate_inequalities = set()
        for expr in multivariate_inequalities:
            self.add(expr.subs(self._substitutions))
        self._raise_inconsistencies()
        # NOTE(review): these comprehensions replace the defaultdicts with plain
        # dicts, so a later `add` of a constraint on a new symbol would raise
        # KeyError. Within `solve`, `add` is not called after this point.
        self._univariate_inequalities = {
            s: exprs
            for s, exprs in self._univariate_inequalities.items()
            if s not in self._substitutions
        }
        self._congruences = {
            s: congruences
            for s, congruences in self._congruences.items()
            if s not in self._substitutions
        }

    def solve(self, disable_congruences=True, disable_equivalences=True):
        """Solve the system of constraint equations to find simplified constraints

        Args:
            disable_congruences: if True, a symbol with a congruence that cannot be
                expressed via supported Dim ops (see ``_is_supported_congruence``)
                is force-specialized instead of being left unreported.
            disable_equivalences: if True, every free symbol of an unsupported
                symbolic equivalence (see ``_is_supported_equivalence``) is
                force-specialized and dynamic results mentioning it are dropped.

        Raises:
            ValueError: via ``_raise_inconsistencies`` if any constraint is
                inconsistent with the hint values.
        """
        self._raise_inconsistencies()
        # as long as there are symbols with equalities, solve for them
        # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
        while self._symbols_with_equalities:
            s = self._symbols_with_equalities.pop()
            exprs = self._univariate_inequalities.pop(s)
            solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
            if isinstance(solution, sympy.And):
                # Prefer the equality among the conjuncts, if there is one.
                solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution)
            assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
            symbol, val = solution.args
            assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
            # because this is univariate, the solution is a specialization
            self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
            # add this as a substitution to simplify other constraints
            self._substitutions[s] = val

            # simplify multivariate inequalities: some of them will now become univariate!
            multivariate_inequalities = self._multivariate_inequalities
            self._multivariate_inequalities = set()
            for expr in multivariate_inequalities:
                self.add(expr.subs(s, self._substitutions[s]))
            self._raise_inconsistencies()

        self._specialize_divisor_symbols()

        # solve linear congruences
        # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
        reduced_congruences = self._reduce_congruences()
        for s, congruences in reduced_congruences.items():
            for congruence in congruences:
                # any congruence that cannot be checked becomes a dynamic constraint as well
                if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
                    if self._is_supported_congruence(congruence):
                        base, divisor = congruence.args
                        # Introduce a fresh integer symbol named after the debug name
                        # of s, register a source for it so the printer can print it,
                        # and express s in terms of it.
                        tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
                        tmp = sympy.Symbol(tmp_name, integer=True)
                        from torch._dynamo.source import ConstantSource
                        self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
                        r = try_solve(sympy.Eq(base, divisor * tmp), s)
                        self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
                    elif disable_congruences:
                        self._force_specialization(s)
                        self._univariate_inequalities.pop(s, None)

        # remaining symbols have only pure inequalities (no equalities)
        for s, exprs in self._univariate_inequalities.items():
            try:
                solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
                # because this is univariate, the solution is a dynamic (range) constraint
                if isinstance(solution, sympy.Or):
                    # Keep the disjunct that holds for the hint values.
                    solution = next(iter(arg for arg in solution.args if arg.subs(self._var_to_val)))
                if isinstance(solution, sympy.And):
                    for arg in solution.args:
                        self._dynamic_results.add(self._dcp.doprint(arg))
                else:
                    self._dynamic_results.add(self._dcp.doprint(solution))
            except (NotImplementedError, AssertionError) as e:
                # Fall back to reporting the unreduced inequalities as-is.
                log.warning("Failed to reduce inequalities: %s", e)
                for expr in exprs:
                    self._dynamic_results.add(self._dcp.doprint(expr))

        # simplify symbolic equivalences: some of them will now become specializations!
        symbolic_equivalences = self._symbolic_equivalences
        self._symbolic_equivalences = []
        for source, expr in symbolic_equivalences:
            if disable_equivalences and not self._is_supported_equivalence(expr):
                for s in expr.free_symbols:
                    self._force_specialization(s)
                    # Drop dynamic results that mention the now-specialized symbol.
                    sexpr = self._dcp._print_Symbol(s)
                    self._dynamic_results = {r for r in self._dynamic_results if sexpr not in r}
            self.add_equality(source, expr.subs(self._substitutions))

        # remaining symbolic equivalences become dynamic equality constraints
        for source, expr in self._symbolic_equivalences:
            self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")

    @classmethod
    def _is_supported_equivalence(cls, expr):
        """Return whether ``expr`` is a (nested) binary Add/Mul of one symbol and integers."""
        # Currently supported Dim ops are linear expressions with integer coefficients.
        # So check that expr only contains +, *, ints, and a single occurrence of a symbol.
        # (See also documentation of dynamic_shapes._DerivedDim.)
        if isinstance(expr, (sympy.Add, sympy.Mul)):
            # NOTE(review): this unpack requires exactly two args; an Add/Mul with
            # three or more args raises ValueError here.
            lhs, rhs = expr.args
            return (
                (cls._is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or
                (isinstance(lhs, sympy.Integer) and cls._is_supported_equivalence(rhs))
            )
        return isinstance(expr, sympy.Symbol)

    @classmethod
    def _is_supported_congruence(cls, congruence):
        """Return whether ``congruence`` has the form ``(x + a) % b`` or ``x % b``
        with symbol ``x`` and integer constants ``a``, ``b``.
        """
        base, divisor = congruence.args
        # Congruences that can be currently expressed with supported Dim ops are
        # of the form (x + a) % b == 0, where x is a Dim and a and b are constants.
        # This allows us to derive x as b*y - a for some Dim y.
        # (See also documentation of dynamic_shapes._DerivedDim.)
        if isinstance(base, sympy.Add):
            lhs, rhs = base.args
            cond = (
                (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or
                (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol))
            )
        else:
            cond = isinstance(base, sympy.Symbol)
        cond = cond and isinstance(divisor, sympy.Integer)
        return cond

    def forced_specializations(self):
        """Returns a dictionary of the names of symbols to their specialized value

        Only substituted symbols that are in ``marked_dynamic`` are included. Keys
        are ``"<debug name> = <source name>"`` when debug names are available,
        otherwise just the source name.
        """
        def debug_name(src):
            name = src.name()
            if self._dcp.source_name_to_debug_name:
                return f"{self._dcp.source_name_to_debug_name[name]} = {name}"
            else:
                return name

        return {
            debug_name(self._dcp.symbol_to_source[s][0]): val
            for s, val in self._substitutions.items()
            if s in self._marked_dynamic
        }

    def remove_redundant_dynamic_results(self):
        """Remove constraints of the form 2 <= dynamic_dim(...) as 2 is the default
        lower bound.

        A rewritten ``dynamic_dim(...)`` is kept only if it is not a substring of
        some other remaining constraint.
        """
        candidates_for_removal = []
        dynamic_results = set()
        for dc in self._dynamic_results:
            # Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
            # There is no change in behavior since 2 is the default lower bound.
            dc_ = re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
            if dc != dc_:
                candidates_for_removal.append(dc_)
            else:
                dynamic_results.add(dc_)
        for dc in candidates_for_removal:
            # remove dynamic_dim(t, 0) as a constraint when dynamic_dim(t, 0) also
            # appears as part of another constraint
            found = False
            for other_dc in dynamic_results:
                if dc in other_dc:
                    found = True
            if not found:
                dynamic_results.add(dc)
        self._dynamic_results = dynamic_results

    def prettify_results(
        self,
        original_signature: inspect.Signature,
        constraint_violation_error=None,
        forced_specializations=None,
    ):
        """Format a message for constraint violation errors

        With debug names available, returns a "Suggested fixes" listing built
        from ``Dim(...)`` declarations. Otherwise returns code snippets
        (``specializations`` / ``specify_constraints``) whose signature is
        ``original_signature`` without its return annotation, grouped by argument.

        Args:
            original_signature: signature of the traced function.
            constraint_violation_error: optional exception; debug names listed in
                its "Constraints violated (...)" message are collected.
            forced_specializations: optional mapping as returned by
                ``forced_specializations()``; listed at the top of the message.
        """
        if self._dcp.source_name_to_debug_name:
            def transform(s):
                # Replace every source name in s with its debug name.
                for k, v in self._dcp.source_name_to_debug_name.items():
                    s = s.replace(k, v)
                return s

            # Maps debug name -> {"min"/"max"/"eq": bound or expression}.
            results = defaultdict(dict)

            def flip(op):
                # Mirror a comparison operator (for "digit op name" forms).
                if op == "<=":
                    return ">="
                if op == ">=":
                    return "<="
                if op == "<":
                    return ">"
                if op == ">":
                    return "<"
                assert op == "=="
                return op

            def relation_with_digit(expr, op, digit):
                # Record an integer bound for expr; strict bounds become inclusive.
                if op == "<=":
                    results[expr]["max"] = digit
                elif op == "<":
                    results[expr]["max"] = digit - 1
                elif op == ">=":
                    results[expr]["min"] = digit
                elif op == ">":
                    results[expr]["min"] = digit + 1
                else:
                    assert op == "=="
                    results[expr]["eq"] = digit

            for s in self._static_results.union(self._dynamic_results):
                t = transform(s)
                if t == s:
                    # No debug name was substituted; skip.
                    continue
                left, op, right = re.split(r"( == | <= | >= | < | > )", t)
                op = op.strip()
                if op == "==" and left == right:
                    continue
                if right.isdigit():
                    relation_with_digit(left, op, int(right))
                elif left.isdigit():
                    relation_with_digit(right, flip(op), int(left))
                else:
                    assert op == "=="
                    results[left]["eq"] = sympy.sympify(right)

            buf = ""
            debug_names = set()
            if forced_specializations:
                debug_names.update(k.split(" = ")[0] for k in forced_specializations.keys())
                buf += (
                    f"Specializations unexpectedly required ({', '.join(debug_names)})! "
                    "For more information, run with TORCH_LOGS=\"+dynamic\".\n"
                )
                for s, val in forced_specializations.items():
                    buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n"

            dims = []
            others = []
            match = None
            if constraint_violation_error:
                match = re.search(r"Constraints violated \((.*)\)", constraint_violation_error.args[0])
            if match is not None:
                debug_names.update(match.expand(r'\1').split(', '))

            for k, c in sorted(results.items()):
                # if k not in debug_names:
                #     continue
                if "eq" in c:
                    other = c["eq"]
                    if isinstance(other, int):
                        others.append(f"{k} = None # {other}")
                    elif self._is_supported_equivalence(other):
                        s = next(iter(other.free_symbols))
                        if s not in results:
                            # other == modulus * s + remainder; translate k's bounds
                            # into bounds on s (default bounds 2 and sys.maxsize - 1).
                            modulus, remainder = sympy.polys.polytools.div(other, s)
                            c_min = c.get("min", 2)
                            min_ = math.ceil((c_min - remainder) / modulus)
                            c_max = c.get("max", sys.maxsize - 1)
                            max_ = math.floor((c_max - remainder) / modulus)
                            dims.append(f"{s} = Dim('{s}', min={min_}, max={max_}) # {c_min} <= {other} <= {c_max}")
                        others.append(f"{k} = {other}")
                else:
                    min_ = c.get("min", None)
                    if min_ == 2:
                        # 2 is the default lower bound, so omit it.
                        min_ = None
                    max_ = c.get("max", None)
                    if min_ is not None and max_ is not None:
                        dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
                    elif min_ is not None:
                        dims.append(f"{k} = Dim('{k}', min={min_})")
                    elif max_ is not None:
                        dims.append(f"{k} = Dim('{k}', max={max_})")
                    else:
                        dims.append(f"{k} = Dim('{k}')")

            buf += "\nSuggested fixes:\n "
            buf += "\n ".join(dims + others)

            return buf

        # Note: Model inputs are wrapped as LocalSource in dynamo.
        # LocalSource.name() wraps the name with L[""]. We use regular
        # expression to do the replacement to avoid traversing up
        # the source hierarchy manually.
        def extract_and_rewrite_local(dc):
            # Returns (arg name, dc with L['...'] unwrapped), or None if dc has no L['...'].
            match = re.search(r"L\['(.+?)'\]", dc)
            if match is None:
                return
            arg = match.expand(r'\1')
            dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
            return arg, dc

        def group(results, args_index):
            # Group constraints by the signature argument they mention, ordered by
            # parameter position; constraints within a group are sorted.
            groups = defaultdict(list)
            for dc in results:
                local = extract_and_rewrite_local(dc)
                if local is None:
                    # This can happen, e.g., with `assume_constant_result`.
                    # In that case, we drop the constraint.
                    # TODO(avik) Maybe we should generate an assertion here?
                    continue
                arg, dc = local
                if arg in args_index:
                    groups[args_index[arg]].append(dc)
                else:
                    # This can happen, e.g., with decorators that change the signature.
                    # In that case, we drop the constraint. Seems hard to do better. :/
                    # TODO(avik) Maybe warn that `arg` in not in `signature`?
                    continue
            sorted_groups = []
            for idx, dcs in sorted(groups.items()):
                _, arg = idx
                sorted_groups.append((arg, sorted(dcs)))
            return sorted_groups

        signature = original_signature.replace(return_annotation=inspect.Signature.empty)
        # Maps parameter name -> (position, name); used as a sort key in group().
        args_index = {}
        for i, arg in enumerate(signature.parameters.keys()):
            args_index[arg] = (i, arg)

        def print_results(grouped, indent, result_fn):
            # Append each group to buf under a "# <arg>:" header, separated by blank lines.
            nonlocal buf

            space = False
            for arg, results in grouped:
                if space:
                    buf += "\n"
                else:
                    space = True
                buf += f"\n{indent}# {arg}:"
                for result in results:
                    buf += f"\n{indent}{result_fn(result)}"

        buf = ""
        if forced_specializations:
            buf += (
                "Some dynamic dimensions need to be specialized because "
                "the constraints inferred for them are too complex to specify.\n"
            )
            for s, val in forced_specializations.items():
                buf += f" - {s}, which was marked dynamic, must be specialized to {val}.\n"
        indent = 4 * " "
        if self._static_results:
            grouped_static_results = group(self._static_results, args_index)
            buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
            buf += f"\n```\ndef specializations{str(signature)}:"
            print_results(
                grouped_static_results,
                indent,
                lambda result: f"assert {result}",
            )
            buf += "\n```\n"
        if self._dynamic_results:
            grouped_dynamic_results = group(self._dynamic_results, args_index)
            buf += "\nThe following dimensions CAN be dynamic."
            buf += "\nPlease use the following code to specify the constraints they must satisfy:"
            buf += f"\n```\ndef specify_constraints{str(signature)}:"
            buf += f"\n{indent}return ["
            print_results(
                grouped_dynamic_results,
                indent * 2,
                lambda result: f"{result},",
            )
            buf += f"\n{indent}]\n```\n"
        return buf
# Module-level thread-local storage: attributes set on this object are only
# visible to the thread that set them.
TLS: threading.local = threading.local()
class ShapeEnv:
    # This is a wrapper over the actual __init__ function.
    #
    # Where to add a new constructor parameter to ShapeEnv?
    # =====================================================
    # This __init__ function should be used only for parameters related to event recording.
    # These are parameters that we don't wish to pass down the road to new ShapeEnv instances
    # created from replaying events.
    #
    # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
    # recording, do so in the _init function.
    def __init__(
        self, *,
        should_record_events: Optional[bool] = None,
        tracked_fakes: Optional[List[Any]] = None,
        **kwargs
    ) -> None:
        """Create a ShapeEnv.

        Args:
            should_record_events: whether calls on this ShapeEnv are recorded as
                ShapeEnvEvents; None means "record iff translation validation is on
                and bisection is not disabled".
            tracked_fakes: the list of tracked fakes, stored as-is.
            **kwargs: forwarded to ``_init`` (all non-recording settings).
        """
        self._init(**kwargs)

        # Disable event recording when replaying.
        # (kwargs is stored in the first ShapeEnvEvent below, so a replayed ShapeEnv
        # is constructed with should_record_events=False.)
        kwargs["should_record_events"] = False

        from torch.fx.experimental.validator import translation_validation_enabled
        self._translation_validation_enabled = translation_validation_enabled()

        # If not specified, enable event recording if both:
        #   - Translation validation is on
        #   - Translation validation bisection is not disabled
        self.should_record_events = (
            should_record_events
            if should_record_events is not None
            else (
                self._translation_validation_enabled
                and not config.translation_validation_no_bisect
            )
        )

        # Enable event recording check if both:
        #   - It should record events
        #   - The recording check is enabled
        self.check_recorded_events = (
            self.should_record_events and config.check_shape_env_recorded_events
        )

        # This will make sure we only record the top-level function call.
        # (When not recording at all, this starts as True so nothing is recorded.)
        self.is_recording = not self.should_record_events
        # Keep track of the list of tracked fakes.
        self.tracked_fakes = tracked_fakes
        # List of events for reconstructing ShapeEnv at arbitrary points in time.
        # The first event re-creates the ShapeEnv itself from kwargs.
        self.events: List[ShapeEnvEvent] = (
            [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else []
        )

    # Pro-tip: if you add new field to ShapeEnv, this affects some accept
    # tests.  Accept their output with:
    #
    #   EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal
    #
    def _init(
        self, *,
        allow_scalar_outputs=True,
        allow_dynamic_output_shape_ops=True,
        # NB: These are legacy configuration that help us make good choices
        # when the constraint/dynamic dims are not explicitly passed to us.
        # Ideally we will fix all call sites to be explicit and not have
        # implicit choices, but this apparently was pretty involved.
        assume_static_by_default=False,
        # Note - On 0/1 specialization
        #
        # The following options affect decisions we make about eager
        # specialization.  Disabling them will increase trace time (as we do
        # more symbolic reasoning) and can also harm the quality of generated
        # code (because inductor may not be able to specialize for bounds
        # being equal--although if we later respecialize because of a guard,
        # your code may be just as good as it was before.)
        #
        # When True, eagerly specialize input sizes which have 0/1.
        specialize_zero_one=True,
        # When True, assume input sizes which have the same size are
        # symbolically equal.
        duck_shape=True,
        # For debugging
        co_fields=None,
        # XXX Add any new settings that could affect FakeTensor evaluation
        # to: torch._subclasses.fake_tensor._ShapeEnvSettings
    ):
        """Initialize all non-event-recording state of the ShapeEnv."""
        # Not directly used by ShapeEnv; indirectly used by FakeTensor
        self.allow_scalar_outputs = allow_scalar_outputs
        self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops
        self.guards: List[ShapeGuard] = []
        # Maps symbolic ints to their original concrete values
        # Currently populated from tensors
        self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
        # Maps symbolic ints to their min/max range.  These ranges
        # are conservative: the int MUST fall in the range, but the
        # range may contain ints which may not actually appear in
        # practice
        self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
        self.source_name_to_debug_name: Dict[str, str] = {}
        self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
        self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
        # Maps from sympy ints to expressions representing them
        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
        self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
        # Set holds a % b expressions that evaluate to 0.
        self.divisible: Set[sympy.Expr] = set()
        # Set that holds "size-like" symbols.  When we perform
        # "size-oblivious" tests, these can be assumed to be >= 2.
        self.size_like: Set[sympy.Symbol] = set()
        # Duck-shaping says that if two input tensors have the same size,
        # they get assigned the same symbolic variable
        self.val_to_var: Dict[int, sympy.Expr] = {}
        if specialize_zero_one:
            # Pre-seed 0 and 1 with constants so that those sizes are never symbolic.
            self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}
        self.unbacked_symfloat_counter = itertools.count()
        self.unbacked_symint_counter = itertools.count()
        # Similar to guards, but these MUST evaluate to true and can
        # only be evaluated at runtime midway through (i.e., they always
        # involve unbacked symints)
        #
        # For efficiency reasons, we index in the following way.  Suppose you have
        # a runtime assert i0 + i1 <= s1.  We pick the most recently allocated
        # symbol in the source expression and add the assert to the list for
        # that symbol e.g., {i1: [i0 + i1 <= s1]}.
        #
        # We access the runtime asserts in two situations:
        #
        #   - When we are guarding on an expression, we will attempt to
        #     statically evaluate it, in case the unbacked SymInts can
        #     simplify away.  If we have a runtime assert, we may be able
        #     to discharge the guard entirely.  We only need to attempt
        #     runtime asserts that mention freevars of the expression in
        #     question.
        #
        #   - When we are performing codegen (in Inductor for eager, or
        #     when finalizing the export FX graph), we need to know what
        #     extra runtime asserts to insert.  Whenever an unbacked
        #     SymInt comes into scope, all runtime asserts involving it
        #     become eligible for insertion (so long as all of their other
        #     free unbacked symbols are also in scope).  We technically
        #     can handle any choice of key by kicking inexpressible asserts
        #     to the next unbacked symbol to wait on, but if we choose the
        #     latest key, an assert will only show up at the moment when
        #     we can actually codegen it.
        self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {}
        # This exists so we can efficiently invalidate the cache (it's used as
        # part of the cache key); otherwise we'd have to iterate through
        # deferred_runtime_asserts to compute its length
        self.num_deferred_runtime_asserts = 0
        self.assume_static_by_default = assume_static_by_default
        self.specialize_zero_one = specialize_zero_one
        self.duck_shape = duck_shape
        self.log = log
        self.log.debug("create_env")
        self.frozen = False
        self.dim_constraints: Optional[DimConstraints] = None
        self.counter = collections.Counter()
        # Mapping from sympy.Symbol to the number of guards which mention this
        # symbol
        self.symbol_guard_counter = collections.Counter()
        # A selection of important fields on co_field; solely used for
        # signpost_event
        self.co_fields = co_fields if co_fields else {}

        # Version counter used to invalidate cached values
        # (_get_key reads replacements / divisible / num_deferred_runtime_asserts,
        # so it must run after those fields are assigned above.)
        self._prev_cache_key = self._get_key()
        self._version_counter = 0

        # Cache for FX nodes.
        # Maps an already built node a tuple of:
        #   1. node's target
        #   2. list of arguments
        # This drastically reduces the size of the FX graph, avoiding
        # duplicated nodes.
        self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
        # Translation-validation symbols, keyed by source name.
        self.source_to_symbol: Dict[str, sympy.Symbol] = {}

        from torch.fx.experimental.validator import translation_validation_enabled
        self._translation_validation_enabled = translation_validation_enabled()

        if self._translation_validation_enabled:
            from torch.fx.experimental.validator import TranslationValidator

            self.validator = TranslationValidator()
            self.graph = torch.fx.Graph()
            # Create an output graph and start inserting before that.
            # This is needed when 'deepcopy'-ing this object.
            self.graph.inserting_before(self.graph.output(None))

            # Mapping of each node name to the node itself.
            #
            # This is useful for matching an FX node from a recorded ShapeEnv.graph
            # to the FX node of the ShapeEnv we are running the event on.
            #
            # Whenever you add a node to self.graph, you must add a mapping to this
            # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
            # not be valid.
            self.name_to_node: Dict[str, torch.fx.Node] = {}
def check_equal(self, other: "ShapeEnv") -> None:
    """Compare the guard-relevant state of ``other`` against ``self``.

    Fields that cannot influence the result of ``produce_guards`` (debugging,
    translation-validation and event-recording bookkeeping) are excluded, and a
    handful of fields are normalized before comparison. The comparison itself
    is delegated to ``shape_env_check_state_equal``.
    """
    # ShapeEnv fields that are not relevant for the outcome of
    # ShapeEnv.produce_guards call:
    #   - Debugging variables
    #   - Translation validation related variables
    #   - Events recording related variables
    non_state_variable_names = (
        "counter",
        "log",
        "var_to_stack",
        "fx_node_cache",
        "graph",
        "validator",
        "check_recorded_events",
        "should_record_events",
        "is_recording",
        "tracked_fakes",
        "events",
        "source_name_to_debug_name",
        "_prev_cache_key",
        "_version_counter",
    )

    from copy import copy

    # Per-field normalizers: each turns a stored value into the part of it that
    # is actually compared (e.g. ShapeGuard -> its expression, dropping the stack).
    normalizers = {
        # For itertools.count(), compare the next integer it would produce. Note
        # that we copy the iterator first so the original is not advanced.
        "unbacked_symfloat_counter": lambda counter: next(copy(counter)),
        "unbacked_symint_counter": lambda counter: next(copy(counter)),
        # Transform the list of ShapeGuard into a list of expressions.
        "guards": lambda guards: [g.expr for g in guards],
        # Transform the lists of RuntimeAsserts into lists of expressions.
        "deferred_runtime_asserts": lambda asserts: {
            s: [ra.expr for ra in ras] for s, ras in asserts.items()
        },
        # Only the set of node names must match.
        "name_to_node": lambda nodes: set(nodes.keys()),
        # Excluded from comparison.
        "symbol_guard_counter": lambda _counter: None,
    }

    def map_value(key: str, value: Any) -> Any:
        normalize = normalizers.get(key)
        if normalize is None:
            return value
        return normalize(value)

    shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
    """Return a recording-friendly copy of ``self.tracked_fakes``.

    SymInt fakes are kept as-is; every other fake is converted with
    ``FakeTensorMeta.from_fake``. Returns None when no fakes are tracked.
    """
    if self.tracked_fakes is None:
        return None

    from torch._dynamo.variables.builder import TrackedFake

    def maybe_transform_fake(fake: TrackedFake):
        inner_fake = (
            fake.fake
            if isinstance(fake.fake, torch.SymInt)
            else FakeTensorMeta.from_fake(fake.fake)
        )
        # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a
        # FakeTensorMeta for two reasons:
        #   1. this is all the information we need when recording ShapeEnvEvents.
        #   2. it works even if each TrackedFake changes its metadata.
        return TrackedFake(inner_fake, fake.source, fake.symbolic_context)  # type: ignore[arg-type]

    return [maybe_transform_fake(fake) for fake in self.tracked_fakes]

def _last_event_index(self) -> int:
    """Index of the most recently recorded event (-1 if there are none)."""
    return len(self.events) - 1

@contextmanager
def _recording(self):
    """Mark this ShapeEnv as currently recording for the duration of the block.

    The previous value of ``is_recording`` is restored on exit, including when
    the block raises. A nested use therefore cannot clear the flag of an
    enclosing recording section.
    """
    previous = self.is_recording
    self.is_recording = True
    try:
        yield
    finally:
        self.is_recording = previous
@record_shapeenv_event()
def freeze(self):
    """Stop this ShapeEnv from accumulating further guards.

    Guards generated on a frozen ShapeEnv are ignored and only produce a
    warning, which may lead to accuracy problems.
    """
    frozen_flag = True
    self.frozen = frozen_flag
def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
    """Return the translation-validation symbol for ``source``.

    Returns None when translation validation is disabled. Symbols are cached in
    ``self.source_to_symbol`` keyed by ``source.name()``.
    """
    if not self._translation_validation_enabled:
        return None
    srcname = source.name()
    # The cache is keyed by the source *name*: testing the Source object itself
    # against it never hits, which rebuilt the symbol on every call.
    if srcname not in self.source_to_symbol:
        self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
    return self.source_to_symbol[srcname]

def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
    """Register ``symbol`` with the translation validator (no-op if disabled)."""
    if self._translation_validation_enabled:
        self.validator.add_var(symbol, type)

def _add_target_expr(self, expr) -> None:
    """Add ``expr`` as a validator target expression (no-op if disabled)."""
    if self._translation_validation_enabled:
        self.validator.add_target_expr(expr)

def _add_assertion(self, expr) -> None:
    """Add ``expr`` as a validator assertion (no-op if disabled)."""
    if self._translation_validation_enabled:
        self.validator.add_assertion(expr)

def _check_translation_validate(self) -> None:
    """Run the translation validator (no-op if disabled)."""
    if self._translation_validation_enabled:
        self.validator.validate()

@record_shapeenv_event()
def _create_fx_call_function(
        self,
        op: Callable,
        args: Tuple,
) -> Tuple[Optional[torch.fx.Node], bool]:
    """Get or create the FX ``call_function`` node for ``op(*args)``.

    Returns ``(node, fresh)``. ``node`` is None when translation validation is
    disabled or any argument is None. ``fresh`` is True only when a new node was
    inserted into ``self.graph`` by this call.
    """
    # Cache this tuple in order to avoid duplicated nodes.
    node_key = (op, args)
    # Flags whether the returned node was cached or not.
    fresh = False

    if self._translation_validation_enabled and node_key not in self.fx_node_cache:
        from torch.fx.experimental.validator import z3op

        # Presence of None in the arguments implies that we should ignore this operation.
        if any(a is None for a in args):
            # We check if we are not mixing SymNode that should not be ignored
            # (fx_node is not None) with those that should (fx_node is None).
            assert all(not isinstance(a, torch.fx.Node) for a in args)
            return None, fresh

        fresh = True
        lifted_op = z3op(op, self.validator)

        # If translation validation is enabled, all arguments must have its
        # own FX node.
        assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}"
        node = self.fx_node_cache[node_key] = self.graph.call_function(lifted_op, args)
        self.name_to_node[node.name] = node

    return self.fx_node_cache.get(node_key, None), fresh

def _create_fx_placeholder_and_z3var(
        self,
        symbol: sympy.Symbol,
        type: Type,
) -> Optional[torch.fx.Node]:
    """Get or create the FX placeholder (and Z3 variable) for ``symbol``.

    Returns None when translation validation is disabled.
    """
    if not self._translation_validation_enabled:
        return None

    node_key = (self.graph.placeholder, (symbol,))

    # Check if we haven't added this symbol already.
    # If so, skip the placeholder creation, as it
    # generates invalid Python code.
    if node_key not in self.fx_node_cache:
        # Add a Z3 variable according to 'type'.
        self._add_z3var(symbol, type)
        # Create the FX placeholder out of a mangled name.
        mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name))
        node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name)
        self.name_to_node[node.name] = node
        # Attach the 'symbol' to the placeholder so that we can retrieve
        # the Z3 variable later.
        node.meta["symbol"] = symbol

    return self.fx_node_cache[node_key]

def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
    """Erase ``node`` from the translation-validation graph and its indices."""
    if self._translation_validation_enabled and node is not None:
        self.name_to_node.pop(node.name)
        self.graph.erase_node(node)
        # Also drop the fx_node_cache entry that points at this node. Otherwise
        # a later identical request would be handed a node that is no longer
        # part of self.graph.
        stale_keys = [key for key, cached in self.fx_node_cache.items() if cached is node]
        for key in stale_keys:
            del self.fx_node_cache[key]

def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
    """Attach the current event index and current node to ``node.meta``."""
    from torch._dynamo.utils import get_current_node

    if self.should_record_events:
        node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
        node.meta[CURRENT_NODE_KEY] = get_current_node()

def _suppress_guards_tls(self):
    """Whether guard suppression is active on the current thread."""
    return getattr(TLS, "suppress_guards", False)

@record_shapeenv_event()
def _suppress_guards_enter(self):
    """Enable guard suppression, remembering the previous state."""
    # Save the prior state on a thread-local stack so that nested
    # suppress_guards() blocks restore it on exit, instead of the inner block
    # unconditionally re-enabling guards for the outer one.
    stack = getattr(TLS, "suppress_guards_stack", None)
    if stack is None:
        stack = TLS.suppress_guards_stack = []
    stack.append(self._suppress_guards_tls())
    TLS.suppress_guards = True

@record_shapeenv_event()
def _suppress_guards_exit(self):
    """Restore the guard-suppression state saved by the matching enter call."""
    stack = getattr(TLS, "suppress_guards_stack", None)
    # With no saved state (unbalanced exit) fall back to the old behavior: off.
    TLS.suppress_guards = stack.pop() if stack else False
@contextmanager
def suppress_guards(self):
    """Context manager that ignores every guard generated inside its body.

    Suppression is switched on via ``_suppress_guards_enter`` before the body
    runs, and switched off via ``_suppress_guards_exit`` afterwards, even if the
    body raises.
    """
    begin, end = self._suppress_guards_enter, self._suppress_guards_exit
    begin()
    try:
        yield
    finally:
        end()
def _get_key(self):
    """
    Defines the current "state" of the guards we've accumulated in this ShapeEnv.
    Determines when we need to invalidate our cache
    """
    return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts)

def _update_version_counter(self):
    """Bump ``_version_counter`` if ``_get_key()`` changed since the last call."""
    # The shape environment is queried orders of magnitude more often than
    # it is changed, so we summarise the cache key into a linearly
    # increasing version counter which is cheaper to check in _lru_cache

    # Only update version counter if the state actually changed
    cur_key = self._get_key()
    if self._prev_cache_key != cur_key:
        self._prev_cache_key = cur_key
        self._version_counter += 1

def _produce_dyn_sizes(self,
                       ex_size: Sequence[int],
                       source: Source,
                       symbolic_context: SymbolicContext
                       ) -> List[sympy.Expr]:
    """Tuple-ize ``ex_size`` and forward to ``_produce_dyn_sizes_from_int_tuple``."""
    return self._produce_dyn_sizes_from_int_tuple(tuple(ex_size), source, symbolic_context)

def _produce_dyn_sizes_from_int_tuple(self,
                                      tensor_size: Tuple[int, ...],
                                      source: Source,
                                      symbolic_context: SymbolicContext,
                                      ) -> List[sympy.Expr]:
    """Create one size symbol per dimension of ``tensor_size``.

    Dimension ``i`` uses ``symbolic_context.dynamic_sizes[i]`` and
    ``symbolic_context.constraint_sizes[i]`` and is sourced from
    ``TensorPropertySource(source, TensorProperty.SIZE, i)``.

    ``tensor_size`` is a variable-length tuple of plain ints (one per dimension),
    hence ``Tuple[int, ...]`` rather than the 1-tuple ``Tuple[int]``.
    """
    assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}"
    from torch._dynamo.source import TensorPropertySource, TensorProperty
    _assert_symbol_context(symbolic_context)
    dynamic_dims = symbolic_context.dynamic_sizes
    constraint_dims = symbolic_context.constraint_sizes
    return [
        self.create_symbol(
            val,
            TensorPropertySource(source, TensorProperty.SIZE, i),
            dynamic_dims[i],
            constraint_dims[i],
            symbolic_context=symbolic_context,
        )
        for i, val in enumerate(tensor_size)
    ]
def create_symbolic_sizes_strides_storage_offset(
    self,
    ex: torch.Tensor,
    source: Source,
    *,
    symbolic_context: Optional[SymbolicContext] = None,
):
    """
    Returns a list of symbolic sizes and strides for the given tensor.
    We try our best to express stride in terms of the sizes, so as to not
    introduce new symbolic variables.
    """
    # Dynamo may want to wrap FakeTensors with SymInt sizes, e.g.
    # make_fx(opt_f(), tracing_mode="symbolic"). We create symbols in this
    # shape_env using the backed hints behind those SymInts.
    #   Case 1: when the SymInt is backed, dynamo can proceed with FakeTensors
    #           that have concrete shape; produce_guards will trigger
    #           specializations on the outer stuff.
    #   Case 2: when the SymInt is unbacked, require_hint() raises a
    #           data-dependent error.
    #
    # This is probably good for now, but note that it has implications for the
    # original shape_env when guards are checked in a different order. Example:
    #
    #     @torch.compile()
    #     def opt_f(x: bool, y: Tensor):
    #         if x == True:
    #             return y + torch.randn([4])
    #         else:
    #             return y
    #
    # Depending on the sequence of calls, we might install two different sets
    # of guards:
    #   1. opt_f(False, y): "x == False" (works for any size of y)
    #   2. opt_f(True, y): recompiles with guards like
    #      "x == True and y.size(0) == 4" (or "y.size(0) == 4 and x == True")
    # The order of checking the guards matters. Here, if the True-branch guard
    # is checked before the False-branch one, and y.size(0) is checked before
    # x == True, we may get an unnecessary shape specialization for y.
    def _concretize(value) -> int:
        assert isinstance(value, (int, torch.SymInt))
        if not is_symbolic(value):
            return value
        assert value.node.shape_env is not self, \
            "expect the symbol is created from an shape env other than current one."
        return value.node.require_hint()

    concrete_size = tuple(map(_concretize, ex.size()))
    concrete_stride = tuple(map(_concretize, ex.stride()))
    concrete_offset = _concretize(ex.storage_offset())
    dim_is_dynamic = [_is_dim_dynamic(ex, i) for i in range(ex.dim())]

    return self._create_symbolic_sizes_strides_storage_offset(
        concrete_size,
        concrete_stride,
        concrete_offset,
        dim_is_dynamic,
        source,
        symbolic_context=symbolic_context,
    )
@record_shapeenv_event()
def _create_symbolic_sizes_strides_storage_offset(
    self,
    ex_size: Sequence[int],
    ex_stride: Sequence[int],
    ex_storage_offset: int,
    is_dim_dynamic: Sequence[bool],
    source: Source,
    *,
    symbolic_context: Optional[SymbolicContext] = None,
):
    """Build symbolic sizes, strides and storage offset from concrete values.

    Sizes come from ``_produce_dyn_sizes_from_int_tuple``. Strides equal to 0 or
    1 become integer constants. Any other stride whose concrete value equals
    ``ex_size[j] * ex_stride[j]`` of an already-bound dimension ``j`` is
    expressed as ``size[j] * stride[j]``. Strides that cannot be expressed this
    way get fresh symbols, smallest first, and the process repeats. The storage
    offset is always allocated through ``create_symbol``.

    Returns ``(sizes, strides, storage_offset)``. ``sizes`` and ``strides`` are
    tuples of values returned by ``create_symintnode``, which gives a plain int
    for constant expressions and a SymInt otherwise.
    """
    dim = len(ex_size)

    # Reimplement the legacy behavior
    if symbolic_context is None:
        constraint_dims = [None] * dim
        dynamic_dims = []
        for i in range(dim):
            # NB: This is encapsulation breaking!  Legacy behavior was
            # bad.
            if is_dim_dynamic[i]:
                r = DimDynamic.DYNAMIC
            elif self.assume_static_by_default:
                r = DimDynamic.STATIC
            else:
                r = DimDynamic.DUCK
            dynamic_dims.append(r)
        # NOTE(review): this reassignment discards the per-dimension choices
        # computed by the loop above. On this path is_dim_dynamic and
        # assume_static_by_default therefore have no effect, and every dim is
        # DUCK. Confirm this is intentional before changing it, since callers
        # may depend on the current behavior.
        dynamic_dims = [DimDynamic.DUCK] * dim
        # symbolic_context is None - set one
        symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims)
    # We got a StatelessSymbolicContext
    _assert_symbol_context(symbolic_context)
    constraint_dims = symbolic_context.constraint_sizes
    dynamic_dims = symbolic_context.dynamic_sizes

    # TODO: make this configurable from outside symbolic_context; we made a symbolic_context
    # decision here where if all sizes are static, we are going to
    # specialize all of the inner strides/offset too. We don't have to
    # do this, and arguably we should ALWAYS allow for dynamic offset,
    # this is cheap.
    # TODO: This should be DYNAMIC, using DUCK for BC
    dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK

    assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
    assert len(constraint_dims) == dim

    from torch._dynamo.source import TensorPropertySource, TensorProperty
    size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context)
    stride: List[Optional[sympy.Expr]] = [None] * len(size)
    # Strides of 0 and 1 are always treated as constants.
    for i, val in enumerate(ex_stride):
        if val in (0, 1):
            stride[i] = sympy.Integer(val)
    while any(x is None for x in stride):
        # Map concrete value size*stride -> symbolic size*stride for every
        # dimension whose stride is already bound (and non-negative).
        candidates = {
            ex_size[i] * ex_stride[i]: size[i] * stride[i]
            for i in range(len(size))
            if stride[i] is not None and ex_stride[i] >= 0
        }

        # iterate over unbound strides in sorted order
        def _nested_int_aware_sort(tup):
            return (
                # Order nested ints by their coefficients.
                # 1 here to order nested ints after non-nested-ints.
                (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
                else (0, *tup)
            )
        val_list = sorted(
            [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
            key=_nested_int_aware_sort,
        )
        for _, i in val_list:
            # An unbound stride whose concrete value matches a known product is
            # expressed symbolically. That new binding may in turn serve as a
            # candidate for larger strides later in this sorted pass.
            if stride[i] is None and ex_stride[i] in candidates:
                stride[i] = candidates[ex_stride[i]]
                candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i]

        if any(x is None for x in stride):
            # bind the smallest unbound stride to a new variable
            val, i = min(
                [
                    (ex_stride[i], i)
                    for i in range(len(stride))
                    if stride[i] is None
                ], key=_nested_int_aware_sort
            )
            stride[i] = self.create_symbol(
                val,
                TensorPropertySource(source, TensorProperty.STRIDE, i),
                dynamic_dim=dynamic_strides_offset,
                constraint_dim=None,
                symbolic_context=symbolic_context,
            )
    assert all(x is not None for x in stride)

    sym_sizes = [
        self.create_symintnode(
            sym,
            hint=hint,
            source=TensorPropertySource(source, TensorProperty.SIZE, i),
        )
        for i, (sym, hint) in enumerate(zip(size, ex_size))
    ]
    sym_stride = []
    for i, stride_expr in enumerate(stride):
        # NB: Don't duck size the stride; instead use the expression
        # we computed
        assert stride_expr is not None
        sym_stride.append(self.create_symintnode(
            stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i)))
    sym_storage_offset = self.create_symintnode(
        self.create_symbol(
            ex_storage_offset,
            TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
            dynamic_dim=dynamic_strides_offset,
            constraint_dim=None,
            symbolic_context=symbolic_context
        ),
        hint=ex_storage_offset,
        source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET))
    return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
@record_shapeenv_event()
def create_symintnode(
        self,
        sym: "sympy.Expr",
        *,
        hint: Optional[int],
        source: Optional[Source] = None,
):
    """Create a SymInt value from a symbolic expression

    If you know what the current hint value of the SymInt to be created
    is, pass it into hint.  Otherwise, pass None and we will make our best
    guess.

    If ``sym`` is a concrete ``sympy.Integer``, a plain Python ``int`` is
    returned instead of a SymInt; when ``hint`` is given it must agree with
    that value. When translation validation is enabled and ``source`` is
    given, the equality ``symbol(source) == sym`` is registered with the
    validator.
    """
    # (The former ``source_name`` local was computed but never used.)
    if self._translation_validation_enabled and source is not None:
        # Create a new symbol for this source.
        symbol = self._create_symbol_for_source(source)
        assert symbol is not None

        # Create a new FX placeholder and Z3 variable for 'symbol'.
        fx_node = self._create_fx_placeholder_and_z3var(symbol, int)

        # Add an equality assertion for the newly created symbol and 'sym'.
        self._add_assertion(sympy.Eq(symbol, sym))
    else:
        fx_node = None

    if isinstance(sym, sympy.Integer):
        # Constant expressions are handed back as plain ints.
        if hint is not None:
            assert int(sym) == hint, f"{sym} != {hint}"
        return int(sym)
    return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
@record_shapeenv_event()
def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
    """Allocate a new unspecified symbol for ``value`` and wrap it in a SymInt.

    ``value`` is also used as the hint of the resulting SymInt.
    """
    unspecified = self.create_unspecified_symbol(
        value,
        source=source,
        dynamic_dim=dynamic_dim,
    )
    return self.create_symintnode(unspecified, hint=value, source=source)
def create_symboolnode(self, sym: "sympy.Expr"):
    """Wrap the sympy boolean expression ``sym`` in a SymBool (no hint)."""
    # Only serialization calls this, so it is deliberately not tracked for
    # translation validation (no FX node is attached).
    node = SymNode(sym, self, bool, None)
    return SymBool(node)
@record_shapeenv_event()
def create_unbacked_symfloat(self):
    """Create a symbolic float that has no hint value.

    The symbol is named ``f<N>`` using ``unbacked_symfloat_counter`` and starts
    with an unknown value range.
    """
    ordinal = next(self.unbacked_symfloat_counter)
    symbol: sympy.Symbol = sympy.Symbol(f"f{ordinal}")
    self.counter["create_unbacked_symbol"] += 1
    # Remember where the symbol was allocated (for later error reporting).
    self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
    value_range = ValueRanges.unknown()
    self.var_to_range[symbol] = value_range

    # Create a new FX placeholder and Z3 variable for 'symbol'.
    placeholder = self._create_fx_placeholder_and_z3var(symbol, float)

    self._log_create_unbacked_symbol("create_unbacked_symfloat", symbol, value_range)
    return SymFloat(SymNode(symbol, self, float, None, fx_node=placeholder))
@record_shapeenv_event()
def create_unbacked_symint(self):
    """Create a symbolic integer that has no hint value.

    The symbol is named ``u<N>`` using ``unbacked_symint_counter`` and starts
    with the default unspecified value range.
    """
    ordinal = next(self.unbacked_symint_counter)
    symbol: sympy.Symbol = sympy.Symbol(f"u{ordinal}", integer=True)
    self.counter["create_unbacked_symbol"] += 1
    # Remember where the symbol was allocated (for later error reporting).
    self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
    value_range = self._default_unspecified_value_range()
    self.var_to_range[symbol] = value_range

    # Create a new FX placeholder and Z3 variable for 'symbol'.
    placeholder = self._create_fx_placeholder_and_z3var(symbol, int)

    self._log_create_unbacked_symbol("create_unbacked_symint", symbol, value_range)
    return SymInt(SymNode(symbol, self, int, None, fx_node=placeholder))
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
    """Return True if ``symbol`` follows the naming scheme of unbacked symbols.

    Unbacked symbols are named ``u<N>``, so this only inspects the first
    character of the printed name.
    """
    # NB: keep synced with free_unbacked_symbols
    printed = str(symbol)
    return printed[:1] == "u"
@record_shapeenv_event()
def create_unbacked_symbool(self):
    """Create a symbolic boolean that has no hint value.

    The boolean is modeled as an integer symbol ``u<N>`` (sharing
    ``unbacked_symint_counter``) restricted to the range [0, 1]. The returned
    SymBool wraps the expression ``u<N> == 1``.
    """
    ordinal = next(self.unbacked_symint_counter)
    symbol: sympy.Symbol = sympy.Symbol(f"u{ordinal}", integer=True)
    self.counter["create_unbacked_symbol"] += 1
    # Remember where the symbol was allocated (for later error reporting).
    self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
    value_range = ValueRanges(0, 1)
    self.var_to_range[symbol] = value_range

    # Create a new FX placeholder and Z3 variable for 'symbol'.
    placeholder = self._create_fx_placeholder_and_z3var(symbol, bool)

    self._log_create_unbacked_symbol("create_unbacked_symbool", symbol, value_range)
    truth = sympy.Eq(symbol, 1)
    return SymBool(SymNode(truth, self, bool, None, fx_node=placeholder))
@record_shapeenv_event()
def create_unspecified_symbol(
    self,
    val: Union[int, SymInt],
    source: Source,
    dynamic_dim: DimDynamic = DimDynamic.DUCK,
    constraint_dim: DimConstraint = None,  # NB: includes None
) -> "sympy.Expr":
    """Create a symbol with an unspecified value.

    Unlike standard symbols, the new symbol is not assumed to be positive, and
    the values zero and one are not specialized to constants.
    """
    # positive=None: no sign assumption is attached, since the value may be
    # negative, zero or positive.
    # do_not_specialize_zero_one=True: 0 and 1 are not mapped to the constant
    # integers, so a symbol is produced for them as for any other value.
    return self.create_symbol(
        val,
        source,
        dynamic_dim,
        constraint_dim,
        positive=None,
        do_not_specialize_zero_one=True,
        symbolic_context=None,
    )
@record_shapeenv_event()
def create_symbol(
    self,
    val: int,
    source: Source,
    dynamic_dim: DimDynamic = DimDynamic.DUCK,
    constraint_dim: DimConstraint = None,  # NB: includes None
    positive: Optional[bool] = True,
    do_not_specialize_zero_one: bool = False,
    symbolic_context=None,
) -> "sympy.Expr":
    """Create a new symbol which is tracked by this ShapeEnv

    Returns one of the following:
      * a cached result, when ``symbolic_context`` is a StatefulSymbolicContext
        that already holds an entry for ``source.name()`` in this ShapeEnv;
      * ``sympy.Integer(val)`` for ``DimDynamic.STATIC``;
      * the constant 0/1 from ``val_to_var`` when 0/1 specialization applies;
      * a fresh ``s<N>`` symbol;
      * an existing symbol reused via duck shaping (same ``val`` seen before).

    Raises ConstraintViolationError if ``val`` falls outside the value range
    of a freshly created integer symbol.
    """
    # see note [Tensor Fakification and Symbol Caching]
    source_name = source.name()
    if (isinstance(symbolic_context, StatefulSymbolicContext)
            and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache):
        symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {}

    if (isinstance(symbolic_context, StatefulSymbolicContext)
            and source_name
            and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])):
        return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name]

    if do_not_specialize_zero_one:
        specialize_zero_one = False
    else:
        specialize_zero_one = self.specialize_zero_one

    assert isinstance(source, Source), f"{type(source)}{source}"
    assert not (positive and val < 0), f"positive set for negative value: {val}"
    # It's always sound to allocate a symbol as DYNAMIC.  If the user
    # constrained the symbol, force the symbolic_context to DYNAMIC, because our
    # constraint code will do weird stuff if, e.g., it's duck shaped
    if constraint_dim is not None:
        dynamic_dim = DimDynamic.DYNAMIC

    if dynamic_dim is DimDynamic.STATIC:
        out = sympy.Integer(val)
        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out
        return out

    elif dynamic_dim is DimDynamic.DUCK:
        # duck_shape can be used to globally turn off duck shaping, even
        # if it was requested
        duck = self.duck_shape
    elif dynamic_dim is DimDynamic.DYNAMIC:
        duck = False
    else:
        raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")

    if val in (0, 1) and specialize_zero_one:
        # val_to_var is seeded with constant 0/1 in _init when zero/one
        # specialization is enabled.
        r = self.val_to_var[val]
    elif not duck or val not in self.val_to_var:
        # If we're not duck shaping, we always create a new symbol
        # Even if we're duck shaping, if we haven't seen this particular
        # value before, we also create a new symbol
        sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True)
        # We always associate vars to vals
        if isinstance(val, int):
            self.var_to_val[sympy_expr] = sympy.Integer(val)
        else:
            # Only used for jagged layout nested tensors
            self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())

        # Do the appending later, because we always want to populate this
        self.var_to_sources[sympy_expr] = []
        # Create a Z3 variable for the new symbol.
        self._add_z3var(sympy_expr, int)

        if duck:
            # Make sure to reuse this symbol for subsequent duck shaping
            self.val_to_var[val] = sympy_expr

        if isinstance(val, int):
            if positive:
                # Add assertions for the newly created symbols
                self._add_assertion(sympy_expr > 1)

                # Apply default range, which assumes not zero-one
                self.var_to_range[sympy_expr] = self._default_value_range()
            else:
                self.var_to_range[sympy_expr] = self._default_unspecified_value_range()

            # Small performance optimization: if we have a min-max constraint,
            # we can proactively narrow to that range
            if isinstance(constraint_dim, StrictMinMaxConstraint):
                assert not duck
                self.var_to_range[sympy_expr] &= constraint_dim.vr

            vr = self.var_to_range[sympy_expr]

            if val not in vr:
                raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")

            range_str = f"[{vr.lower}, {vr.upper}]"
        else:
            # Skip var_range logic for SingletonInt
            # Only used for jagged layout nested tensors
            range_str = ""

        r = sympy_expr

        # Extended debug output (stack info) is enabled per symbol name via
        # the comma-separated config.extended_debug_create_symbol.
        is_debug = (
            config.extended_debug_create_symbol is not None and
            str(sympy_expr) in config.extended_debug_create_symbol.split(',')
        )
        fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
        self.log.info(
            "create_symbol %s = %s for %s%s%s (%s)%s",
            sympy_expr, val, source.name(), range_str,
            maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug
        )

        self.counter["create_symbol"] += 1
    else:
        # This implements duck-shaping: input sizes that match are assigned
        # the same symint
        r = self.val_to_var[val]
        self.log.debug("create_symbol %s duck sized %s", r, source.name())

    if isinstance(r, sympy.Symbol):
        r_sources = self.var_to_sources[r]
        r_sources.append(source)
        if not source.is_ephemeral() and r_sources[0].is_ephemeral():
            # prefer non-ephemeral source first since it may be guarded on later
            r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]

        # This ensures we get zeros in symbol_guard_counts, which makes
        # some queries simpler (since we will accumulate mass on 0 this
        # way)
        # NOTE(review): this also runs on the duck-shaping reuse path, so an
        # existing symbol's accumulated guard count is reset to 0. Confirm
        # whether that reset is intended or should only seed missing entries.
        self.symbol_guard_counter[r] = 0

    if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
        symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r
    return r
def _debug_name(self, source):
    """Return the user-facing debug name for ``source`` (falls back to its name)."""
    name = source.name()
    return self.source_name_to_debug_name.get(name, name)

def _render_range_for_constraint_violation(self, source, c):
    """Render constraint ``c`` on ``source`` for a constraint-violation message.

    For StrictMinMaxConstraint, a bound is shown only if it is tighter than
    ``_default_value_range()``. Any other constraint renders itself via
    ``c.render(source)``.
    """
    if not isinstance(c, StrictMinMaxConstraint):
        return c.render(source)

    default = self._default_value_range()
    lower = None if c.vr.lower <= default.lower else c.vr.lower
    upper = None if c.vr.upper >= default.upper else c.vr.upper

    debug_name = self._debug_name(source)
    c_render = f"{debug_name} = {source.name()} in the specified range"
    if lower is not None and upper is not None:
        c_render += f" {lower} <= {debug_name} <= {upper}"
    elif upper is not None:
        c_render += f" {debug_name} <= {upper}"
    elif lower is not None:
        c_render += f" {lower} <= {debug_name}"
    return c_render
[docs]defproduce_guards(self,placeholders,sources,source_ref=lambdan:n.name(),*,input_contexts:Optional[DimList[SymbolicContext]]=None,# Encodes user-specified input shape equations of the form s = s' and s = fn(s').# (See docs on EqualityConstraint for details of the encoding.)equalities_inputs:Optional[EqualityConstraint]=None,_simplified=False,# Indicates if we should produce guards for known static values.ignore_static=True,)->List[str]:""" Generates a list of guards strings which, when evaluated in a context that defines tensors for all the sources, returns True or False depending on if the guards in the list evaluated to True or not. Primarily used by Dynamo, but this is also helpful for manual testing of guards (see evaluate_guards_for_args) For convenience in testing, a source is allowed to be a str, in which case we will assume it is a LocalSource simplified lets you omit duck sizing, equality and 0/1 guards. This is useful for testing when you don't care about the boilerplate guards, and it may be helpful for user output too (be careful though; some equality guards are nontrivial! It would be nice to get simplified output to print them too). It's private because it's not intended for normal use """self.log.info("produce_guards")# Check if we get to the same ShapeEnv state by replaying the recorded events.# This will create a new ShapeEnv instance, and call all recorded function# calls on this new instance. Finally, it will check whether this new instance# has equal state.## It's important that we do it in the begining of this function, since it modifies# self.dim_constraints through its execution. Changes that happen in this method# aren't interesting, since this is the function call we wish to reproduce at the# end. 
If we wish to simply reproduce ShapeEnv instances even after this call,# this method should also be recorded.ifself.check_recorded_events:shape_env=replay_shape_env_events(self.events)self.check_equal(shape_env)assertlen(placeholders)==len(sources),f"len({placeholders}) != len({sources})"Tensorlike=(torch.Tensor,FakeTensorMeta)def_create_no_constraints_context(t):returnStatelessSymbolicContext(# Ignored; only the constraints part is relevant below.dynamic_sizes=[DimDynamic.DYNAMIC]*t.dim(),constraint_sizes=[None]*t.dim())# Expand optional inputs, or verify invariants are upheldifinput_contextsisNone:input_contexts=[_create_no_constraints_context(t)ifisinstance(t,Tensorlike)elseNonefortinplaceholders]else:assertlen(input_contexts)==len(placeholders)fori,(t,context)inenumerate(zip(placeholders,input_contexts)):ifisinstance(t,Tensorlike):ifcontextisNone:input_contexts[i]=_create_no_constraints_context(t)else:assertisinstance(t,(SymInt,int))assertnotisinstance(context,list)# It took a lot of sweat to figure out the algorithm here. 
Let's# explain how it works.## The ShapeEnv lifecycle looks something like this:## - For each input, you either generate a fresh Sympy symbol (s0) to# represent its value (a binding site), or you reuse some# preexisting symbol or expression, skipping the symbol allocation# (e.g., duck sizing to a preexisting symbol, or expressing a# stride as a multiplication of a separate stride and size.)# Naively, you might expect to bind a fresh Sympy symbol for# every input, but this is fairly wasteful as most of these# symbols immediately simplify away, and if you don't eagerly# specialize, e.g., 0/1 symbols, you end up with very complicated# expressions that are not optimizable in practice.## - You perform some compute on these symbols, occasionally# introducing guards on boolean expressions on these symbols.# In particular, whenever we guard on equality (_maybe_guard_rel),# we can simplify shapes; e.g., when s0 == s1 * 2, we can now# replace all occurrences of s0 with s1 * 2. Sometimes, a# boolean expression evaluation doesn't introduce a guard, as# the guard is already entailed by the simplifications we have# applied.## - In the end, you have a bunch of replacements (saying how to# simplify shapes) and a bunch of guards (all the equality guards# are trivial, because they're covered by the replacements).## From the ShapeEnv, we must generate a Python expression that, when# evaluated on a set of inputs, tells us whether or not these boolean# expressions would have evaluated in the same way. However,# we cannot easily compute this, as we elide recording boolean# expressions when we think they are vacuously true. Thus, we seek# an approximation: we must generate an expression, if true, would have# produced an "equivalent" ShapeEnv, which would answer guard# expressions in the same way.## Our notion of equivalence is a bit subtle. For example, consider# the ShapeEnv created from an input of size (5, 4) versus (4, 4)# (no other guards.) 
Duck sizing would generate (s0, s1) in the first# case but (s0, s0) in the second. We do NOT assume that size# variables are disjoint; so in fact a graph that assumes the input# could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not# vice versa. However, consider an analogous case (1,) versus (2,).# Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT# subsume the (1,) graph because we assume that any size variables# is NOT 0/1 (and make simplifications according to this; e.g., if# we queried s0 == 0, we would immediately return False without# returning a guard.)## So, it is perhaps easier to flip things on their head: the guard# expressions we generate here say what simplifications are valid,# and what are not. Below, we explain each of the guard expressions# we generate# TODO: Make this more efficient by binding all the size/stride/offsets# to locals before performing tests on them.fromtorch._dynamo.sourceimportTensorPropertySource,TensorProperty,NegateSource# Actual codegen must be delayed as we don't necessarily know what# the symbol mapping isinput_guards=[]symbol_to_source=collections.defaultdict(list)symbol_to_constraints=collections.defaultdict(set)constraint_violations:List[Tuple[bool,Callable[[],str]]]=[]defrecord_constraint_violation(warn_only,debug_name,msg,hint=None):constraint_violations.append((warn_only,debug_name,lambda:f"{msg}{hint()}"ifhintelsemsg))defis_dim(src):returnisinstance(src,TensorPropertySource)andsrc.propisTensorProperty.SIZEifequalities_inputs:source_index={}fori,srcinenumerate(sources):source_index[src.name()]=idefget_expression(tensor_dim_src):fake=placeholders[source_index[tensor_dim_src.base.name()]]symint=fake.shape[tensor_dim_src.idx]ifisinstance(symint,torch.SymInt):returnsymint.node.exprelse:asserttype(symint)isint,f"Expected int, got {type(symint)}"returnsymintforsrc1,src2inequalities_inputs.source_pairs:expr1,expr2=get_expression(src1),get_expression(src2)# Check whether given input shape values satisfy a 
specified equation s = s'.# - Raise when the equation was violated by the given input shape values.# - Otherwise issue a guard to constrain them.concrete_val=self.evaluate_expr(sympy.Eq(expr1,expr2))ifnotconcrete_val:raiseConstraintViolationError(f"{src1.name()} = {expr1.subs(self.var_to_val)}"" is not equal to "f"{src2.name()} = {expr2.subs(self.var_to_val)}")forsrc,root,fninequalities_inputs.derived_equalities:expr1=get_expression(src)# recall that root is either a phantom symbol or an input sourceexpr2,debug_name=((root,self.var_to_sources[root][0].name())ifisinstance(root,sympy.Symbol)else(get_expression(root),self._debug_name(root)))expr2_=fn(expr2)# Check whether given input shape values satisfy a specified equation s = fn(s').# - Raise when the equation was violated by the given input shape values.# - Otherwise issue a guard to constrain them.concrete_val=self.evaluate_expr(sympy.Eq(expr1,expr2_))ifnotconcrete_val:raiseConstraintViolationError(f"Expected input {src.name()} to be equal to "f"{fn(sympy.Symbol(debug_name))}, "f"where {debug_name} = {expr2.subs(self.var_to_val)}, "f"but got {expr1.subs(self.var_to_val)}")forphantom_symbolinequalities_inputs.phantom_symbols:# we created additional phantom symbols that are not input shape dimensionssymbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol])# How do we know what the value of s0 is? Fresh variables can only be# bound by inputs, so there MUST be some other input which binds the# variable. If there is no such input, this is an error in our# system. We record where all symbols come from, to help you diagnose# why those symbols didn't occur.## In fact, generally speaking it is only possible for the "outermost"# user of a ShapeEnv to evaluate the guards, because some inputs may# not be available to inner levels. For example, Dynamo can guard on# tensors that never actually become graph arguments (they are# pruned). 
In this case, only Dynamo knows about these arguments.deftrack_symint(source,val,constraint=None):log.debug("track_symint %s%s%s",LazyString(source.name),val,constraint)assertnotisinstance(val,SymInt)oris_symbolic(val)ifisinstance(val,SymInt)andval.node.maybe_as_int()isnotNone:val=val.node.maybe_as_int()ifisinstance(val,SymInt):s=val.node.exprifisinstance(s,sympy.Symbol):symbol_to_source[s].append(source)ifconstraintisnotNone:symbol_to_constraints[s].add(constraint)elifisinstance(-s,sympy.Symbol):symbol_to_source[-s].append(NegateSource(source))else:constraint_violated=Falseifisinstance(constraint,StrictMinMaxConstraint):# try inferring the ranges of the expr ssym_vrs={x:self.var_to_range.get(x,None)forxins.free_symbols}ifall(vrisnotNoneforvrinsym_vrs.values()):expr_vr=bound_sympy(s,sym_vrs)ifexpr_vr!=constraint.vr:# the expr and constrain ranges don't matchconstraint_violated=Trueelse:# some of the free symbols in s don't have rangesconstraint_violated=Trueelifisinstance(constraint,RelaxedUnspecConstraint):ifs.is_number:i=int(s)# Don't complain about 0/1 specialization, we# expect to have to compile in this case anywayifinotin(0,1):constraint_violated=Trueifconstraint_violated:defhint(s):sexpr=ShapeGuardPrinter(symbol_to_source,source_ref,self.var_to_sources).doprint(s)returnf"{sexpr}."var_with_range=self._render_range_for_constraint_violation(source,constraint)msg=(f"Not all values of {var_with_range} are valid because "f"{self._debug_name(source)} was inferred to be equal to ")record_constraint_violation(constraint.warn_only,self._debug_name(source),msg,hint=functools.partial(hint,s),)input_guards.append((source,s))else:s=sympy.Integer(val)input_guards.append((source,s))constraint_violated=Falseifisinstance(constraint,StrictMinMaxConstraint):constraint_violated=Trueelifisinstance(constraint,RelaxedUnspecConstraint):# Don't complain about 0/1 specialization, we# expect to have to compile in this case 
anywayifvalnotin(0,1):constraint_violated=Trueifconstraint_violated:var_with_range=self._render_range_for_constraint_violation(source,constraint)msg=(f"Not all values of {var_with_range} are valid because "f"{self._debug_name(source)} was inferred to be a constant ({val}).")record_constraint_violation(constraint.warn_only,self._debug_name(source),msg)fort,source,contextinzip(placeholders,sources,input_contexts):ifisinstance(source,str):fromtorch._dynamo.sourceimportLocalSourcesource=LocalSource(source)assertisinstance(source,Source)iftisNone:continueifisinstance(t,(SymInt,int)):track_symint(source,t)continueassertisinstance(t,Tensorlike)ifis_traceable_wrapper_subclass(t):fromtorch._dynamo.sourceimportAttrSourceassertisinstance(context,SubclassSymbolicContext)# For subclasses, we need to track symints on BOTH the outer# and inner tensors.sources_tensors_constraints=[(source,t,context.constraint_sizes)]attrs,_=t.__tensor_flatten__()forattrinattrs:inner_t=getattr(t,attr)inner_context=context.inner_contexts[attr]sources_tensors_constraints.append((AttrSource(source,attr),inner_t,inner_context.constraint_sizes))else:sources_tensors_constraints=[(source,t,context.constraint_sizes)]forsrc,curr_t,constraintinsources_tensors_constraints:ifis_sparse_any(curr_t):fori,ssinenumerate(curr_t.size()):property_source=TensorPropertySource(src,TensorProperty.SIZE,i)track_symint(property_source,ss,constraint[i])else:fori,ssinenumerate(curr_t.size()):property_source=TensorPropertySource(src,TensorProperty.SIZE,i)track_symint(property_source,ss,constraint[i])fori,ssinenumerate(curr_t.stride()):track_symint(TensorPropertySource(src,TensorProperty.STRIDE,i),ss)track_symint(TensorPropertySource(src,TensorProperty.STORAGE_OFFSET),curr_t.storage_offset())# 1. Every input must equal the final simplified symbolic expression# stored on the placeholder. 
Given a placeholder (s0*2, s1),# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.# This does a lot of work: it covers duck sizing and equality guards.exprs=[]self.dim_constraints=DimConstraints(symbol_to_source,self.var_to_val,set(symbol_to_constraints.keys()),self.source_name_to_debug_name,)ifnot_simplified:forsource,exprininput_guards:ifself._translation_validation_enabled:# Ignore sources that were not turned into SymInts.srcname=source.name()ifsrcnameinself.source_to_symbol:self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname],expr))# Small optimizationif(isinstance(expr,sympy.Symbol)andsymbol_to_source.get(expr)andsource==symbol_to_source[expr][0]):continue# This logic excludes static values found on tensors from guarding, because# dynamo's check_tensor_fn does that (see guards.cpp).# However, for non tensor sources, we still need to guard here.ifignore_staticandisinstance(source,TensorPropertySource):ifexpr.is_number:self.log.debug("Skipping guard %s",f"{source_ref(source)} == {expr}")continueifis_dim(source):self.dim_constraints.add_equality(source,expr)sexpr=ShapeGuardPrinter(symbol_to_source,source_ref,self.var_to_sources).doprint(expr)exprs.append(f"{source_ref(source)} == {sexpr}")if(isinstance(source,TensorPropertySource)andsource.propisTensorProperty.SIZEandequalities_inputsandlen(expr.free_symbols)==1):symbol=next(iter(expr.free_symbols))if(isinstance(expr,sympy.Symbol)andexprinsymbol_to_constraintsandnotequalities_inputs.is_equal(source,symbol_to_source[expr][0])):msg=(f"The values of {self._debug_name(source)} = {source.name()} and "f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} ""must always be 
equal.")record_constraint_violation(equalities_inputs.warn_only,self._debug_name(source),msg)if(notisinstance(expr,sympy.Symbol)andsymbolinsymbol_to_constraintsandnotequalities_inputs.is_derived(source,symbol_to_source[symbol][0],lambdax:expr.subs(symbol,x))):src=symbol_to_source[symbol][0]msg=(f"The values of {self._debug_name(source)} = {source.name()} must always be related to "f"the values of {self._debug_name(src)} = {src.name()} by "f"{self._debug_name(source)} = {expr.subs(symbol,sympy.sympify(self._debug_name(src)))}.")record_constraint_violation(equalities_inputs.warn_only,self._debug_name(source),msg)# NB: Not necessary to report constraint violations here:# constraints are guaranteed to be on symbols (we've already# caught constants and non-atomic expressions), so we only# have relational constraints, but we don't support those# at the moment# 2. Every guard must evaluate to True (but remember many guards# like s0 == s1*2 because trivial due to simplification)issued=set()defissue_guard(guard:ShapeGuard)->None:expr=self.simplify(guard.expr)# Avoid re-issueing the same guard.ifexprinissued:returnissued.add(expr)try:is_trivial=Falseifany(is_dim(source)forsinexpr.free_symbolsforsourceinsymbol_to_source[s]):is_trivial=self.dim_constraints.add(expr)guard_expr=ShapeGuardPrinter(symbol_to_source,source_ref,self.var_to_sources).doprint(expr)exprs.append(guard_expr)self._add_target_expr(expr)# A non-relational constraint on a single sizevar can violate# a constraintifnotis_trivialandlen(expr.free_symbols)==1:symbol=next(iter(expr.free_symbols))source=symbol_to_source[symbol][0]constraints=symbol_to_constraints[symbol]forcinconstraints:ifisinstance(c,StrictMinMaxConstraint):var_with_range=self._render_range_for_constraint_violation(source,c)msg=(f"Not all values of {var_with_range} "f"satisfy the generated guard {guard_expr}.")record_constraint_violation(c.warn_only,self._debug_name(source),msg)elifisinstance(c,RelaxedUnspecConstraint):# This is fine, we allow 
guards here as long as it# didn't constrain it to one value (we don't# actually know this; this depends on our# ValueRanges reasoning capability)passelse:raiseAssertionError(f"unrecognized constraint {c}")exceptException:self.log.warning("Failing guard allocated at: \n%s",''.join(guard.stack.format()))raise# First, issue all the non-trivial guards.forguardinself.guards:ifself._maybe_evaluate_static(guard.expr)isnotNone:continueissue_guard(guard)# 3. Every symbol must be within its value range (this handles 0/1# specialization too).forsymbol,sourcesinsymbol_to_source.items():r=self.var_to_range.get(symbol)ifrisNone:ifsymbolnotinself.var_to_range:continuer=self.var_to_range[symbol]assertsourcesassertsymbol.is_integerbounds=[]ifr.lower!=-sympy.oo:ifany(is_dim(source)forsourceinsources):self.dim_constraints.add(sympy.Ge(symbol,r.lower))# Only print lower bound in simplified mode if it is not the# defaultifnot_simplifiedorr.lower!=self._default_value_range().lower:bounds.append(str(r.lower))bounds.append(source_ref(sources[0]))# NB: This looks like an off-by-one error but it's not: the# upper bound may be sys.maxsize - 1 because we intentionally# exclude sys.maxsize from our bounds to deal with direct# == INT_MAX guards, but it's still dumb to actually test it.# Note that you can be off by a pretty large constant and it# won't matter because sizes in practice will be no where near# the 64-bit limit.ifr.upper!=sympy.ooandr.upper<sys.maxsize-1:ifany(is_dim(source)forsourceinsources):self.dim_constraints.add(sympy.Le(symbol,r.upper))# nontrivial upper bound is always interestingbounds.append(str(r.upper))iflen(bounds)>1:exprs.append(" <= ".join(bounds))# Check constraintsconstraints=symbol_to_constraints[symbol]forcinconstraints:ifisinstance(c,StrictMinMaxConstraint):# NB: By default, we have a restrictive range# 2 <= s0 <= sys.maxsize - 1. 
But export users generally# expect to be able to specify nice ranges like [0, oo]ifnot(c.vr&self._default_value_range()).issubset(r):source=sources[0]expr=sympy.And(sympy.Le(r.lower,symbol),sympy.Le(symbol,r.upper))guard_expr=ShapeGuardPrinter(symbol_to_source,source_ref,self.var_to_sources).doprint(expr)var_with_range=self._render_range_for_constraint_violation(source,c)msg=(f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}")record_constraint_violation(c.warn_only,self._debug_name(source),msg,)ifconstraint_violations:warn_msgs=[]error_msgs=[]debug_names=set()forwarn_only,debug_name,msginconstraint_violations:ifwarn_only:msg=f" {len(warn_msgs)+1}. {msg()}"warn_msgs.append(msg)else:msg=f" - {msg()}"error_msgs.append(msg)debug_names.add(debug_name)iflen(error_msgs)>0:debug_names=', '.join(debug_names)err='\n'.join(error_msgs)raiseConstraintViolationError(f"Constraints violated ({debug_names})! ""For more information, run with TORCH_LOGS=\"+dynamic\".\n"f"{err}")eliflen(warn_msgs)>0:log.debug("%s Warning only constraints violated",len(warn_msgs))signpost_event("dynamic","produce_guards",{**self.co_fields,**self.counter,"num_guards":len(exprs),"free_symbols":sum(1forvinsymbol_to_source.values()ifv),# The keys are meaningless from an aggregate perspective, so# don't include them. 
Biggest first."symbol_guard_counts":sorted(self.symbol_guard_counter.values(),reverse=True),},)ifself._translation_validation_enabled:fromtorch.fx.experimental.validatorimportPopulateValidator# Add all deferred runtime assertions; these are not technically# handled by produce_guards but we need to put them in the target# setforrasinself.deferred_runtime_asserts.values():forrainras:self._add_target_expr(ra.expr)# Add value range bound guards for all symbols with no trivial bounds.# Reason: '_maybe_evaluate_static' may eliminate guards based on the# refined value ranges.forsym,vrinself.var_to_range.items():ifvr.lower!=-sympy.oo:self._add_target_expr(sympy.Le(vr.lower,sym))ifvr.upper!=sympy.oo:self._add_target_expr(sympy.Le(sym,vr.upper))# Before validating, populate the input of the validator with the# built FX graph.withfx_traceback.preserve_node_meta():PopulateValidator(self.graph,self.validator).run()self._check_translation_validate()returnexprs
def produce_guards_expression(self, placeholders, ignore_static=True):
    """
    Build one Python boolean expression string that covers every guard for
    ``placeholders``.

    Intended to be paired with ``evaluate_guards_expression()``.  Each
    placeholder is given a synthetic ``LocalSource`` named ``t0``, ``t1``, ...
    (in order), the guards are produced against those sources, and the
    resulting guard strings are joined with ``" and "``.

    Returns:
        The combined expression string, or ``None`` when no guards were
        produced.
    """
    from torch._dynamo.source import LocalSource

    # One synthetic local per placeholder, named by position.
    sources = [LocalSource(f"t{idx}") for idx in range(len(placeholders))]
    guard_strs = self.produce_guards(
        placeholders, sources, ignore_static=ignore_static
    )
    return " and ".join(guard_strs) if guard_strs else None
def evaluate_guards_expression(self, code, args):
    """
    Evaluate a guard expression built by ``produce_guards_expression()``
    against concrete ``args``.

    ``args[i]`` is exposed to the expression as ``L["t<i>"]``, which matches
    the local names used when the expression was produced.  ``SYMPY_INTERP``
    serves as the evaluation globals.

    Returns:
        Whatever the guard expression evaluates to.
    """
    # NOTE: ``code`` is generated internally by produce_guards_expression;
    # it must never come from an untrusted source since it goes to eval().
    local_scope = {f"t{idx}": value for idx, value in enumerate(args)}
    return eval(code, SYMPY_INTERP, {"L": local_scope})
def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
    """
    Produce the guards for a graph's placeholders and check them against
    concrete ``args``.

    Returns:
        ``True`` if no guard expression was produced (nothing to check);
        otherwise the result of ``evaluate_guards_expression`` on the
        combined expression.
    """
    code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
    if not code:
        # An empty/absent guard expression is vacuously satisfied.
        return True
    return self.evaluate_guards_expression(code, args)
def bind_symbols(self, placeholders, args):
    """
    Given a paired list of placeholders (fake tensors with
    symbolic sizes) and concrete arguments (regular tensors
    with real sizes), returns a dictionary mapping each
    symbol to its real value.  So for example, if you
    have a placeholder with size (s0, s1), binding
    (2, 4) to it will give you {s0: 2, s1: 4}.  This is
    not guaranteed to bind ALL symbols in the ShapeEnv;
    we can't bind a symbol if it doesn't occur in any placeholder,
    and symbols that already have replacements won't get bindings.

    Placeholders may be:
      - ``None``: skipped.
      - ``SymInt``: bound directly against the paired concrete arg.
      - plain ``int``: skipped, since a static value carries no symbol.
      - ``torch.Tensor``: its sizes, strides and storage offset are bound.

    This is a little duplicative with evaluate_guards but
    it's different enough that it seemed cleanest to make
    another copy.  This assumes the guards are already checked,
    though if it's cheap we'll check for shenanigans

    Returns:
        Dict mapping each bound sympy Symbol to its concrete value.

    Raises:
        AssertionError: if a symbol is bound to two different values, or a
        placeholder is neither None, an int/SymInt, nor a Tensor.
    """
    bindings: Dict[sympy.Symbol, int] = {}

    def bind_symint(arg, val):
        # Only SymInts whose expression is a bare symbol (or its negation)
        # introduce a binding; anything else is ignored.
        if isinstance(val, SymInt):
            s = val.node.expr

            if isinstance(s, sympy.Symbol):
                if s in bindings:
                    assert bindings[s] == arg, f"{bindings[s]} != {arg}"
                else:
                    bindings[s] = arg
            elif isinstance(-s, sympy.Symbol):
                # Placeholder is -sN: bind sN to the negated concrete value.
                if -s in bindings:
                    assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}"
                else:
                    bindings[-s] = -arg

    for t, arg in zip(placeholders, args):
        if t is None:
            continue
        # Plain ints are accepted as placeholders elsewhere in ShapeEnv
        # (e.g. guard production); they have nothing to bind, and
        # bind_symint ignores non-SymInt values, so route them here instead
        # of tripping the Tensor assertion below.
        if isinstance(t, (SymInt, int)):
            bind_symint(arg, t)
            continue
        assert isinstance(t, torch.Tensor)
        for i, s in enumerate(t.size()):
            bind_symint(arg.size(i), s)
        for i, s in enumerate(t.stride()):
            bind_symint(arg.stride(i), s)
        bind_symint(arg.storage_offset(), t.storage_offset())

    return bindings
def get_nontrivial_guards(self):
    """
    Return the simplified expression of every guard that cannot be decided
    statically.

    A guard is "trivial" when ``_maybe_evaluate_static`` returns anything
    other than ``None`` for its expression; such guards are dropped.  The
    remaining ones are passed through ``simplify`` in their original order.
    """
    nontrivial = []
    for guard in self.guards:
        if self._maybe_evaluate_static(guard.expr) is not None:
            continue
        nontrivial.append(self.simplify(guard.expr))
    return nontrivial
def format_guards(self, verbose=False):
    """
    Render this shape env's guards as text, one ``" - <expr>"`` entry per
    guard, joined by newlines.

    When ``verbose`` is true, each entry is followed by the stack recorded
    for that guard (``guard.stack.format()``), with each frame line
    prefixed by a space.
    """
    entries = []
    for guard in self.guards:
        stack = guard.stack
        suffix = ""
        if verbose:
            frames = "".join(" " + frame_line for frame_line in stack.format())
            suffix = f"\n Guarded at:\n{frames}"
        entries.append(f" - {guard.expr}{suffix}")
    return "\n".join(entries)
def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges:
    """
    Compute a ``ValueRanges`` enclosing every value ``expr`` can take.

    Each free symbol of ``expr`` uses its entry in ``self.var_to_range``
    (``None`` when it has none).  With ``size_oblivious=True``, every
    size-like symbol that does have a range is further intersected with
    ``[2, oo]`` before bounding.
    """
    ranges = {sym: self.var_to_range.get(sym, None) for sym in expr.free_symbols}
    if size_oblivious:
        # Clamp values of size-like variables
        for sym in self.size_like & ranges.keys():
            if ranges[sym] is None:
                continue
            ranges[sym] &= ValueRanges(2, sympy.oo)
    return bound_sympy(expr, ranges)
@_lru_cache
def _maybe_evaluate_static(
    self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False,
    expect_rational: bool = True, size_oblivious: bool = False
) -> "Optional[sympy.Expr]":
    """
    Tries to evaluate expr without introducing guards

    If unbacked_only == True, then we only do substitutions on
    unbacked SymInts (leaving regular hinted integers alone).  This could
    result in an expression that still contains backed SymInts, which you
    could then potentially guard on.

    Use compute_hint == True if you are trying to compute a non-binding
    hint for the particular hint values of backed SymInts, e.g., if
    s0 happens to be 3 this run, compute_hint will substitute s0 with 3.

    With size_oblivious == True, size-like symbols are treated as having a
    lower bound of at least 2.

    With expect_rational == True, the computed bound is checked by
    _assert_bound_is_rational, and a singleton bound is returned as the
    value.

    Returns:
        - a sympy number if the expression reduces to one,
        - the lower end of the computed range when that range is a single
          point (only checked when expect_rational is True),
        - otherwise the substituted expression if unbacked_only, else None.
        Also returns None if sympy hits a RecursionError during substitution.
    """
    expr = self.simplify(expr)

    if compute_hint:
        expr = expr.xreplace(self.var_to_val)

    expr = canonicalize_bool_expr(expr)

    symbols = list(expr.free_symbols)

    # Apply known runtime asserts
    for s in symbols:
        # Unbacked symints only
        if s in self.var_to_val:
            continue

        # Maps sub-expressions known to hold (or not hold) to true/false.
        subst = {}

        def add_expr(expr):
            # Expr and negation
            subst[canonicalize_bool_expr(expr)] = sympy.true
            subst[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false
            if isinstance(expr, sympy.Rel):
                # multiplying by -1 changes the direction of the inequality
                dual = type(expr)(-expr.rhs, -expr.lhs)
                subst[canonicalize_bool_expr(dual)] = sympy.true
                subst[canonicalize_bool_expr(sympy.Not(dual))] = sympy.false

        # All guards plus the deferred runtime asserts registered for s are
        # treated as facts.
        for e in itertools.chain(self.guards, self.deferred_runtime_asserts.get(s, ())):
            e = e.expr
            if compute_hint:
                e = canonicalize_bool_expr(e.xreplace(self.var_to_val))
            add_expr(e)
            # Other relational expressions this expression implies
            if isinstance(e, sympy.Eq):
                add_expr(sympy.Le(e.lhs, e.rhs))
                add_expr(sympy.Ge(e.lhs, e.rhs))
            elif isinstance(e, sympy.Lt):
                add_expr(sympy.Le(e.lhs, e.rhs))
                add_expr(sympy.Ne(e.lhs, e.rhs))

        # NB: this helps us deal with And/Or connectives
        expr = expr.subs(subst)

    # Simplify making use of value range lower bound
    new_shape_env = {}
    new_range_env = {}
    for idx, k in enumerate(symbols):
        if isinstance(self.var_to_val.get(k, None), SingletonInt):
            # Skip var_to_range logic for SingletonInt which is only used
            # for jagged layout NestedTensors today
            continue
        vr = self.var_to_range[k]
        if size_oblivious and k in self.size_like:
            lower = max(2, vr.lower)
        else:
            lower = vr.lower
        # Don't do anything if we don't have a nontrivial lower bound
        # Also don't do anything if we asked only to simplify unbacked
        # SymInt
        if (
            lower < (-sys.maxsize - 1) // 2 or
            (unbacked_only and k in self.var_to_val)
        ):
            new_range_env[k] = vr
            continue
        # Positive means >= 1
        # Positive - 1 means >= 0
        # Positive + lower - 1 means >= lower
        # The new symbol 's' is "too low", so when we substitute it in
        # we have to increase it by offset (and conversely, the new
        # variables have to have their value range bounds adjusted as
        # well)
        s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
        offset = lower - 1
        new_shape_env[k] = s + offset
        new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset)

    def replace(expr, repl):
        return expr.xreplace(repl)

    try:
        new_expr = replace(expr, new_shape_env)
    except RecursionError:
        # Give up (no static answer) rather than crash on pathological exprs.
        log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
        self.counter["sympy_recursion_error"] += 1
        return None

    # Rewrite FloorDiv(a, b) as floor(a / b) before expanding.
    # NOTE(review): presumably so safe_expand/bound_sympy can see through
    # FloorDiv — confirm.
    floor_div_replace = {}
    for atom in new_expr.atoms(FloorDiv):
        floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
    new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
    # TODO: when unbacked_only, can sometimes early return even when there
    # are still free symbols
    if new_expr.is_number:
        return new_expr

    # Check if the range can solve it statically
    out = bound_sympy(new_expr, new_range_env)
    if expect_rational:
        _assert_bound_is_rational(new_expr, out)
        if out.is_singleton():
            return out.lower

    return new_expr if unbacked_only else None
@_lru_cache
def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
    """
    Apply the recorded symbol replacements to ``expr``.

    Every free symbol is mapped to its current representative via
    ``self._find``, the mapping is applied with ``xreplace``, and the
    result is passed through ``safe_expand``.
    """
    mapping = {}
    for sym in expr.free_symbols:
        mapping[sym] = self._find(cast(sympy.Symbol, sym))
    return safe_expand(expr.xreplace(mapping))
@_lru_cache
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
    """Use known constraints and replacements to simplify the given expr

    Steps:
      1. Apply symbol replacements (``self.replace``).
      2. For nested ``base // (base // d)`` where both divisions are known
         to be exact (their Mod is recorded in ``self.divisible``), rewrite
         the outer FloorDiv to ``d``.
      3. For any remaining ``a // b`` whose Mod is recorded in
         ``self.divisible``, rewrite it to ``a / b``.  Keep that rewrite
         only if it does not introduce Pow or non-integer Rational atoms
         that weren't already present.
    """
    expr = self.replace(expr)
    # TODO it would seem that this pass is not necessary given the
    # below replacement of // with /, but for nested FloorDivs
    # the non-recursive replacement doesn't work, and
    # recursive makes it hard to look up divisibility,
    # because existing divisibility info has FloorDiv in it, not /
    # for now just do a separate pass to catch common nested case
    if expr.has(FloorDiv):
        self._update_divisible()
        div_replacements = {}
        for atom in expr.atoms(FloorDiv):
            base, divisor = atom.args
            if isinstance(divisor, FloorDiv):
                base1, divisor1 = divisor.args
                # base // (base // divisor1) == divisor1 when both divisions
                # are exact.
                if self.replace(Mod(base, divisor)) in self.divisible and \
                        base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible:
                    div_replacements[atom] = divisor1
        expr = expr.xreplace(div_replacements)
        expr = safe_expand(expr)
    if expr.has(FloorDiv):
        div_replacements = {}
        # Snapshot the "division-like" atoms so we can tell whether the
        # rewrite below actually removed divisions or just moved them.
        pows = expr.atoms(sympy.Pow)
        rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
        for fd in expr.atoms(FloorDiv):
            base, divisor = fd.args
            if self.replace(Mod(base, divisor)) in self.divisible:
                div_replacements[fd] = base / divisor
        new_expr = expr.xreplace(div_replacements)
        new_expr = safe_expand(new_expr)
        new_pows = new_expr.atoms(sympy.Pow)
        new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer))
        # divisions simplified away
        if new_pows.issubset(pows) and new_rationals.issubset(rationals):
            expr = new_expr
    return expr
@lru_cache(256)
def size_hint(self, expr: "sympy.Expr", *, allow_none=False):
    """
    Return a concrete hint for ``expr`` based on the hint values of its
    backed symbols.

    No guard is introduced, so only use this where the calling code stays
    valid for arbitrary shapes (e.g. optimization decisions).

    Resolution order:
      1. Substitute ``self.var_to_val``; return the result if it is a number.
      2. If it is a ``SingletonInt``, return ``None``.
      3. Try ``_maybe_evaluate_static(..., compute_hint=True)``.
      4. Otherwise return ``None`` if ``allow_none``, else raise the error
         built by ``_make_data_dependent_error``.
    """
    hinted = safe_expand(expr).xreplace(self.var_to_val)
    if hinted.is_number:
        return hinted

    from torch.utils._sympy.singleton_int import SingletonInt

    if isinstance(hinted, SingletonInt):
        return None
    evaluated = self._maybe_evaluate_static(hinted, compute_hint=True)
    if evaluated is not None:
        return evaluated
    if allow_none:
        return None
    raise self._make_data_dependent_error(hinted, expr)
# NB: keep in sync with size_hint
@lru_cache(256)
def has_hint(self, expr: "sympy.Expr"):
    """
    Return whether a concrete hint is available for ``expr``.

    True when substituting the backed-symbol hints (``self.var_to_val``)
    yields a number, or when ``_maybe_evaluate_static`` can decide the
    substituted expression.

    NOTE(review): unlike size_hint, this calls _maybe_evaluate_static
    without compute_hint=True and has no SingletonInt special case, despite
    the "keep in sync" note above — confirm whether that divergence is
    intentional.
    """
    result_expr = safe_expand(expr).xreplace(self.var_to_val)
    return result_expr.is_number or self._maybe_evaluate_static(result_expr) is not None

def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None):
    """
    Build (but do not raise) a GuardOnDataDependentSymNode explaining why
    ``expr`` could not be guarded on.

    Logs the allocation stack of each free symbol of ``expr`` at debug
    level.  The message lists which free symbols are size-like and points
    at the framework code that triggered the guard (from
    ``_get_stack_summary``).  When ``size_oblivious_result`` is given, the
    message also suggests guard_size_oblivious and the value it would
    produce.
    """
    # TODO: in a Dynamo context, having user code, and having the
    # name of the local, will be much better
    size_like_symbols = []
    for s in expr.free_symbols:
        stacktrace = ''.join(self.var_to_stack[s].format())
        self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
        if s in self.size_like:
            size_like_symbols.append(s)
    size_oblivious_result_msg = ""
    if size_oblivious_result is not None:
        size_oblivious_result_msg = (
            f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n"
            "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n"
        )
    fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True)
    return GuardOnDataDependentSymNode(
        f"Could not guard on data-dependent expression {expr} (unhinted: {unhinted_expr}). "
        f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
        f"{size_oblivious_result_msg}"
        "Potential framework code culprit (scroll up for full backtrace):\n"
        f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n"
        "For more information, run with TORCH_LOGS=\"dynamic\"\n"
        "For extended logs when we create symbols, also add "
        f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n"
        "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
        "For more debugging help, see "
        "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" +
        maybe_extra_debug
        # TODO: Help text about how to use our runtime tests to fix this
        # problem
    )

def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
    """
    Adds or updates a replacement for a symbol.
    Use this instead of `self.replacements[a] = tgt`.

    Besides recording ``a -> tgt``, this refines ``var_to_range`` for ``a``
    (and, when tgt is univariate and invertible, for tgt's free symbol).
    It propagates size-like-ness across symbol-to-symbol replacements.  It
    may skip the replacement entirely (returning early) if tgt's range would
    not stay within a's range.  On success it bumps the version counter and
    records ``Eq(a, tgt)`` as a translation-validation target expression.
    """
    # Precondition: a == tgt
    assert isinstance(a, sympy.Symbol)

    # Handles nested tensor symbolic variables which don't have
    # var_to_range bounds
    tgt_bound = None
    if a in self.var_to_range:
        src_bound = self.var_to_range[a]

        # If you have x in [2, maxint], then 2*x in [4, 2*maxint].
        # But we don't really care that the max bound says we can
        # go beyond the maximum integer size, because we aren't
        # using bigints anyway.  Arguably, ValueRanges should know
        # to do this truncation automatically (to avoid doing
        # bigint compute in range analysis), but right now it doesn't
        # so we need to get rid of some unnecessary precision.
        int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)

        def issubset(x, y):
            # Subset check after clamping both ranges to int_range.
            return (x & int_range).issubset(y & int_range)

        # First, refine the value range of a based on the computed value range
        # of tgt.  This is always OK to do, even if we decide not to do the
        # substitution in the end.  This might be a no-op, if a already has
        # a tighter bound
        tgt_bound = self.bound_sympy(tgt)
        self.var_to_range[a] = src_bound & tgt_bound

        # Next, check if we can update the range of free symbols in tgt
        # based on the range in a. But only do it if:
        #  - the source bound non-trivially improves over what we get out of
        #    the existing bounds.
        #  - the replacement is univariate and we can invert the tgt expression
        if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1:
            b = next(iter(tgt.free_symbols))
            # Try to invert the equality
            r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
            if r is not None:
                b_bound = self.bound_sympy(r[1])
                self.var_to_range[b] = b_bound & self.var_to_range[b]
                tgt_bound = self.bound_sympy(tgt)
                assert issubset(tgt_bound, src_bound)

        # TODO: Should we propagate size-like-ness?
        #
        # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
        # to become size-like.
        #
        # Cons: if u0 is size-like, what about u0 - 1 == u1?  You CAN'T
        # propagate in this case, because what if u0 == 0, then u1 is negative
        # and clearly isn't a size.  So, at minimum, any f(x) whose value
        # range isn't [0, inf] given x in [0, inf] cannot propagate
        # size-like-ness.  But there are many situations where you could
        # imagine u1 is going to be size-like and actually you just didn't
        # have a refined enough value range on u0.  Since even innocuous
        # looking arithmetic operations can destroy size-like-ness, it's
        # best to not propagate it at all and force the user to annotate it
        # as necessary.
        #
        # Compromise: we preserve size-like-ness only for exact equality
        # and nothing else.
        if a in self.size_like and isinstance(tgt, sympy.Symbol):
            self.size_like.add(tgt)
        elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
            self.size_like.add(a)

        # Now, decide if we will do the substitution.
        #
        #  - If the source has a non-trivial range, only substitute if
        #    we preserve this range.  Note that we may have propagated
        #    the src_range to free variables in tgt when tgt is univariate
        #    and we could find an inverse, which helps us achieve this.
        #    This ensures we never "forget" about user defined ranges,
        #    even if they end up being defined on composite formulas
        #    like s0 + s1.
        #
        #  - If the variable is unbacked, only substitute if the substitution
        #    would preserve the bounds also under size-like-ness conditions.

        if not issubset(tgt_bound, src_bound):
            self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
            return
        elif a in self.size_like:
            tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
            # This is morally equivalent to self.bound_sympy(a, size_oblivious=True)
            # but handles substitutions like u0 == 0
            src_bound_so = self.var_to_range[a]
            if src_bound_so.upper >= 2:
                src_bound_so &= ValueRanges(2, sympy.oo)
            if not issubset(tgt_bound_so, src_bound_so):
                self.log.debug("skipped set_replacement %s = %s (%s) "
                               "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
                return

    if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)):
        # specializing to a constant, which is likely unexpected

        # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g.,
        # when adding a to self.replacements, and again when simplifying an expression containing a.
        # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is,
        # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant.
        if a not in self.replacements or tgt != self.replacements[a]:
            self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt)
            self.log.debug("SPECIALIZATION", stack_info=True)
    log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
    self.replacements[a] = tgt
    # Bump the version counter now that the replacement table changed.
    self._update_version_counter()

    # When specializing 'a == tgt', the equality should be also conveyed to
    # Z3, in case an expression uses 'a'.
    self._add_target_expr(sympy.Eq(a, tgt))

def _add_divisible(self, expr: "sympy.Expr"):
    # Record expr in self.divisible and bump the version counter.  In this
    # chunk the caller (_maybe_guard_rel) passes a Mod(...) expression that
    # was solved to equal 0; simplify() consults this set to treat the
    # matching FloorDiv as exact.
    self.divisible.add(expr)
    self._update_version_counter()

@_lru_cache
@record_shapeenv_event()
def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
    """
    Implements a DSU-like algorithm to find the variable that represents a
    Also handles transitive non-identity replacements.

    a: b + c
    c: d

    Symbols with no replacement are their own representative.  Otherwise
    each free symbol of the stored replacement is resolved recursively.
    The resolved expression is written back through _set_replacement
    (path compression) and then returned.
    """
    if a not in self.replacements:
        return a
    res = self.replacements[a]
    cur_replace = {s: self._find(s) for s in res.free_symbols}
    self._set_replacement(a, self.replacements[a].xreplace(cur_replace), "find")
    # Re-read: _set_replacement may have declined to update the entry.
    return self.replacements[a]

@lru_cache(256)
def _maybe_guard_rel(self, expr: "sympy.Rel") -> None:
    """
    The relational guard is guarded to be true.  Use this information to
    simplify shapes (i.e. a == b or a % 5 == 0)

    For any relational except Ne, value ranges are refined via
    _refine_ranges.  For Eq only, it also tries to solve for one free
    symbol and record a replacement.  If the equation involves Mod and
    solves to Mod(...) == 0, it records divisibility (possibly rewriting an
    unbacked i0 as d * i1 for a fresh unbacked i1).  Expressions with more
    than 5 free symbols are ignored.
    """
    assert isinstance(expr, sympy.Rel)

    # A good example of what goes wrong if you don't do this is
    # python test/functorch/test_aotdispatch.py -k
    # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
    if isinstance(expr, sympy.Ne):
        return

    free = list(expr.free_symbols)

    assert len(free) > 0, f"The expression should not be static by this point: {expr}"
    # In case of really gnarly expression, we don't blow up
    if len(free) > 5:
        return

    # Prioritize unbacked symints for solving by ordering them last.
    # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
    #   (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
    # Prefer to simplify out symbols with ephemeral sources.
    def _smart_symbol_sort(x):
        has_only_ephemeral_sources = (
            x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
        )
        # Symbols without a hint (or a falsy hint) sort as sys.maxsize.
        size = self.size_hint(x, allow_none=True) or sys.maxsize
        name = x.name
        # 1 puts ephemeral sourced symbols first when sorting in reverse
        return (1 if has_only_ephemeral_sources else 0, size, name)

    free = sorted(free, key=_smart_symbol_sort, reverse=True)  # type: ignore[attr-defined]
    lhs = expr.lhs
    rhs = expr.rhs

    self._refine_ranges(expr)

    # The rest of this stuff is for equality only
    if not isinstance(expr, sympy.Eq):
        return

    if not expr.has(Mod):
        try:
            floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
            if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms):
                # Handled by the `except NotImplementedError` below (no-op).
                raise NotImplementedError

            # short-circuit when no solving is needed
            if isinstance(lhs, sympy.Symbol) and free_unbacked_symbols(lhs):
                self._set_replacement(lhs, self._find(rhs), "trivial_lhs")
            elif isinstance(rhs, sympy.Symbol) and free_unbacked_symbols(rhs):
                self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
            else:
                # Solve for the highest-priority symbol after sorting above.
                r = try_solve(expr, free[0], floordiv_inequality=False)
                if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])):
                    new_var = self._find(r[1])
                    ok = False
                    if self.is_unbacked_symint(free[0]):
                        # If you have i0 + i1 + i2 = s0, don't substitute i2 =
                        # s0 - i0 - i1.  Arguably this should be OK but the
                        # runtime assert machinery is very delicate right now
                        # so this causes things to fail e.g.,
                        # test_split_unbacked_sizes
                        ok = len(free_unbacked_symbols(new_var)) <= 1
                        msg = "solve_unbacked"
                    else:
                        # Never substitute backed with unbacked
                        ok = len(free_unbacked_symbols(new_var)) == 0
                        msg = "solve_backed"
                    if ok:
                        self._set_replacement(cast(sympy.Symbol, free[0]), new_var, msg)
        except NotImplementedError:
            pass
    if expr.has(Mod):
        mod_expr = next(iter(expr.atoms(Mod)))
        try:
            r = try_solve(expr, mod_expr, floordiv_inequality=False)
            if r is not None and r[1] == 0:
                self._add_divisible(mod_expr)
                # This is a little bit of extra logic to make things like
                # torch.empty(i0, q).view(c, -1, q) work out
                p, q = mod_expr.args
                if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2:
                    c, i0 = p.args
                    # Given Mod(c * i0, q) == 0
                    if (
                        isinstance(c, sympy.Number) and
                        isinstance(i0, sympy.Symbol) and
                        self.is_unbacked_symint(i0)
                    ):
                        # We have Mod(i0, q / c) == 0, which means we can
                        # rewrite i0 as (q / gcd(q, c)) * i1
                        d = q / sympy.gcd(q, c)
                        i1 = self.create_unbacked_symint().node.expr
                        # Propagate the value ranges.  It doesn't really
                        # matter if we use truediv or floordiv, because we
                        # have established divisibility.
                        self.var_to_range[i1] = SymPyValueRangeAnalysis.truediv(
                            self.var_to_range[i0], ValueRanges.wrap(d)
                        )
                        # Propagate size-like-ness
                        if i0 in self.size_like:
                            self.size_like.add(i1)
                        self._set_replacement(i0, d * i1, "divisibility")

        except NotImplementedError:
            pass
    return

# See: Note - On 0/1 specialization
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
# as a sentinel sometimes.
# Your sizevar isn't going to be anywhere near the max 64-bit integer anyway.
def _default_value_range(self) -> ValueRanges:
    """Return the initial value range for a freshly allocated size variable.

    The lower bound is 2 when ``self.specialize_zero_one`` is set and 0
    otherwise.  The upper bound is ``sys.maxsize - 1`` so that
    ``sys.maxsize`` itself stays available as a sentinel value.
    """
    if self.specialize_zero_one:
        return ValueRanges(2, sys.maxsize - 1)
    return ValueRanges(0, sys.maxsize - 1)


def _default_unspecified_value_range(self) -> ValueRanges:
    """Return the range ``[-sys.maxsize - 1, sys.maxsize]`` (no refinement)."""
    smallest, largest = -sys.maxsize - 1, sys.maxsize
    return ValueRanges(smallest, largest)


@_lru_cache
def _simplify_floor_div(self, expr):
    """Require every FloorDiv in ``expr`` to divide exactly, then simplify.

    For each ``FloorDiv(base, divisor)`` atom, ``Eq(Mod(base, divisor), 0)``
    is handed to ``self.evaluate_expr`` so the exactness guard is recorded
    even when tracing would not otherwise need it.  Returns
    ``self.simplify(expr)``.
    """
    for floor_div in reversed(tuple(expr.atoms(FloorDiv))):
        numerator, denominator = floor_div.args
        # add necessary mod guards
        self.evaluate_expr(sympy.Eq(Mod(numerator, denominator), 0))
    return self.simplify(expr)


def _check_frozen(self, expr, concrete_val):
    """Warn (instead of guarding) when the ShapeEnv is frozen.

    Called right before a guard or runtime assert would be added.  When
    ``self.frozen`` is set, this bumps the ``ignored_backward_guard``
    counter, emits a signpost event and logs a warning; otherwise it does
    nothing.
    """
    if not self.frozen:
        return
    self.counter["ignored_backward_guard"] += 1
    event_payload = {
        **self.co_fields,
        "ignored_guard": f"{expr} == {concrete_val}",
        # no version = original state (this signpost is expected)
        # version 2 = dynamic backwards is eagerly compiled
        "version": 2,
    }
    signpost_event("dynamic", "evaluate_expr_frozen", event_payload)
    log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)


def _get_stack_summary(self, is_debug: bool = False):
    """Locate where a guard originated, for logging.

    Returns a triple ``(summary, user_loc, extra_debug)``:
      * ``summary``: ``traceback.FrameSummary`` for the innermost Python
        frame whose file is not in ``uninteresting_files()`` (``None`` if
        every frame is uninteresting).
      * ``user_loc``: ``" at <frame>"`` built from the last entry of
        ``TracingContext.extract_stack()``, or ``""`` when that is empty.
      * ``extra_debug``: with ``is_debug``, the formatted user stack and,
        if ``config.extended_debug_cpp`` is set, a C++ stack trace.
    """
    summary = None
    cur_frame = inspect.currentframe()
    try:
        while cur_frame is not None:
            code = cur_frame.f_code
            if code.co_filename not in uninteresting_files():
                summary = traceback.FrameSummary(
                    code.co_filename,
                    cur_frame.f_lineno,
                    code.co_name,
                )
                break
            cur_frame = cur_frame.f_back
    finally:
        # Release the frame reference explicitly to avoid keeping frames alive.
        del cur_frame

    # NB: this stack is truncated, but it's fine because the main
    # stack_info will give you the rest of the info you need
    user_tb = TracingContext.extract_stack()
    user_loc = " at " + format_frame(user_tb[-1]) if user_tb else ""

    extra_debug = ""
    if is_debug and user_tb:
        extra_debug = "".join([
            '\nUser Stack (most recent call last):\n',
            '  (snipped, see stack below for prefix)\n',
            *traceback.format_list(user_tb),
        ])
    if is_debug and config.extended_debug_cpp:
        cpp_lines = CapturedTraceback.extract(cpp=True).format()
        extra_debug += "\nC++ stack trace:\n" + "".join(cpp_lines)

    return summary, user_loc, extra_debug


def _log_guard(self, prefix: str, g, forcing_spec: bool):
    """Log an added guard ``g`` at INFO level, tagged with ``prefix``.

    Does nothing unless INFO is enabled on ``self.log``.  When ``str(g)``
    equals ``config.extended_debug_guard_added``, extra stack information
    is gathered and ``stack_info`` is requested from the logger.
    """
    if not self.log.isEnabledFor(logging.INFO):
        return
    guard_text = str(g)
    debug_target = config.extended_debug_guard_added
    is_debug = debug_target is not None and guard_text == debug_target
    summary, user_loc, extra_debug = self._get_stack_summary(is_debug)
    label = f"{prefix} (forcing_spec)" if forcing_spec else prefix
    self.log.info(
        "%s %s [guard added]%s (%s)%s",
        label,
        guard_text,
        user_loc,
        format_frame(summary),
        extra_debug,
        stack_info=is_debug,
    )
@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
                  expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False):
    """
    Given an expression, evaluates it, adding guards if necessary

    Args:
        orig_expr: sympy expression to evaluate.
        hint: optional concrete value for ``orig_expr``.  When ``None`` the
            value is obtained from ``self.size_hint(orig_expr)``; otherwise
            it is ``sympy.sympify(hint)``.
        fx_node: FX node for ``orig_expr``; only used to build a
            ``torch._assert`` node when translation validation is enabled.
        expect_rational: forwarded to ``self._maybe_evaluate_static``.
        size_oblivious: forwarded to ``self._maybe_evaluate_static``.  Also
            disables creation of the translation-validation node.  When
            False and the expression is data dependent, a size-oblivious
            evaluation is attempted and attached to the raised error.
        forcing_spec: set on the recursive call issued when a symbol exceeds
            ``config.symbol_guard_limit_before_specialize``; it stops that
            call from forcing specialization again and is reflected in the
            guard log prefix.

    Returns:
        ``orig_expr`` when it is already a number, the statically known
        value when one exists, otherwise the concrete value that the
        expression was guarded to.

    Raises:
        The exception built by ``self._make_data_dependent_error`` when free
        symbols without a known value remain after trying to eliminate
        unbacked symbols.  Any exception raised after a fresh FX assertion
        node was created removes that node before propagating.
    """

    # TODO: split conjunctions and evaluate them separately

    @lru_cache(None)
    def compute_concrete_val():
        if hint is None:
            return self.size_hint(orig_expr)
        else:
            return sympy.sympify(hint)

    # Check if:
    #   1. 'translation_validation' is set
    #   2. the corresponding 'fx_node' is not 'None'
    #   3. the guard should not be suppressed
    #
    # If all of the above check, we create an FX node representing the
    # actual expression to be guarded.
    node = None
    fresh = False
    if (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
        and not size_oblivious
    ):
        concrete_val = compute_concrete_val()
        if concrete_val is sympy.true:
            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
        elif concrete_val is sympy.false:
            neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
            node, fresh = self._create_fx_call_function(torch._assert, (neg,))
        else:
            eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
            node, fresh = self._create_fx_call_function(torch._assert, (eql,))

        assert node is not None
        # If this is a fresh node, we have to remember the event index that
        # corresponds to this assertion node.
        # Reason: so that, given an assertion node, we can replay the ShapeEnv
        # events until the point where this assertion node was freshly created.
        if fresh:
            self._add_fx_node_metadata(node)

    # After creating the FX node corresponding to orig_expr, we must make sure that
    # no error will be raised until the end of this function.
    #
    # Reason: the translation validation may become invalid otherwise.
    #
    # If an error is raised before the end of this function, we remove the FX node
    # inserted, and re-raise the error.
    guard = None
    try:
        if orig_expr.is_number:
            self.log.debug("eval %s [trivial]", orig_expr)
            # NB: don't test float as there may be precision issues
            if isinstance(hint, (int, bool)):
                assert orig_expr == hint, f"{orig_expr} != {hint}"
            return orig_expr

        expr = orig_expr

        static_expr = self._maybe_evaluate_static(expr,
                                                  expect_rational=expect_rational,
                                                  size_oblivious=size_oblivious)
        if static_expr is not None:
            self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
            # NB: don't test float as there may be precision issues
            if isinstance(hint, (int, bool)):
                assert static_expr == hint, f"{static_expr} != {hint}"
            return static_expr

        if not (expr.free_symbols <= self.var_to_val.keys()):
            # TODO: dedupe this with _maybe_evaluate_static
            # Attempt to eliminate the unbacked SymInt
            new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
            if not (new_expr.free_symbols <= self.var_to_val.keys()):
                size_oblivious_result = None
                if not size_oblivious:
                    size_oblivious_result = self._maybe_evaluate_static(
                        expr,
                        expect_rational=expect_rational,
                        size_oblivious=True
                    )

                raise self._make_data_dependent_error(
                    expr.xreplace(self.var_to_val),
                    expr,
                    size_oblivious_result=size_oblivious_result
                )
            expr = new_expr

        concrete_val = compute_concrete_val()
        self._check_frozen(expr, concrete_val)

        if (
            config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
            and isinstance(hint, bool)
            and isinstance(expr, (sympy.Eq, sympy.Ne))
        ):
            expr = sympy.Not(expr)

        # Turn this into a boolean expression, no longer need to consult
        # concrete_val
        suppress_maybe_guard_rel = False
        if concrete_val is sympy.true:
            g = expr
        elif concrete_val is sympy.false:
            g = sympy.Not(expr)
        else:
            # WARNING: we cannot actually do simplifications on guards
            # on floating point values, because Sympy generally does not
            # think expressions on integers can ever be equal to floating
            # point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False).  Without
            # very clear algebraic laws that hold for floating point, such
            # simplifications are error prone anyway, so be sure not to
            # maybe_guard_rel in those cases.
            if not isinstance(concrete_val, sympy.Integer):
                suppress_maybe_guard_rel = True
            g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]

        # Honor suppress_maybe_guard_rel: previously the flag was computed
        # but never read, so floating point equalities were still fed to
        # _maybe_guard_rel despite the warning above.
        if isinstance(g, sympy.Rel) and not suppress_maybe_guard_rel:
            # TODO: If we successfully eliminate a symbol via equality, it
            # is not actually necessary to save a guard for the equality,
            # as we will implicitly generate a guard when we match that
            # input against the symbol.  Probably the easiest way to
            # implement this is to have maybe_guard_rel return a bool
            # saying if it "subsumed" the guard (and therefore the guard
            # is no longer necessary)
            self._maybe_guard_rel(g)

        if not self._suppress_guards_tls():
            stack = CapturedTraceback.extract(skip=1)
            guard = ShapeGuard(g, stack)
            # TODO: deal with duplicate guards somehow
            self.guards.append(guard)
    except Exception:
        if fresh:
            self._remove_fx_node(node)
        raise
    else:
        if not self._suppress_guards_tls():
            assert guard is not None

            self._log_guard("eval", g, forcing_spec=forcing_spec)

            for s in g.free_symbols:
                self.symbol_guard_counter[s] += 1
                # Forcing_spec to avoid infinite recursion
                if (
                    not forcing_spec and
                    config.symbol_guard_limit_before_specialize is not None and
                    self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize
                ):
                    # Force specialization
                    self.log.info(
                        "symbol_guard_limit_before_specialize=%s exceeded on %s",
                        config.symbol_guard_limit_before_specialize,
                        s
                    )
                    self.evaluate_expr(s, forcing_spec=True)
        else:
            self.log.debug("eval %s [guard suppressed]", g)

    return concrete_val
def cleanup(self):
    """
    Break reference cycles.

    Calls ``cleanup()`` on every captured stack held by this object, in
    this order: the stacks of ``self.guards``, the values of
    ``self.var_to_stack``, then the stacks of every runtime assert in
    ``self.deferred_runtime_asserts``.  This destroys the stacks; keeping
    them would require another way of breaking references on code objects.
    """
    guard_stacks = (guard.stack for guard in self.guards)
    assert_stacks = (
        runtime_assert.stack
        for bucket in self.deferred_runtime_asserts.values()
        for runtime_assert in bucket
    )
    for captured in itertools.chain(guard_stacks, self.var_to_stack.values(), assert_stacks):
        captured.cleanup()
@record_shapeenv_event(save_tracked_fakes=True)
def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
    """Create an assert that is checked at runtime

    Args:
        orig_expr (sympy.Expr): Boolean expression to assert is true
        msg (str): Message to display on assertion failure
        fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding
            to the expression, if applicable

    Returns:
        The statically known value when ``self._maybe_evaluate_static``
        resolves the expression; the result of ``self.evaluate_expr`` when
        eliminating unbacked symbols leaves only symbols with known values;
        otherwise ``True`` once the runtime assert has been recorded (or
        skipped because guards are suppressed).
    """
    # TODO: split conjunctions and evaluate them separately

    known = self._maybe_evaluate_static(orig_expr)
    if known is not None:
        self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, known)
        return known

    # If eliminating unbacked SymInts leaves only symbols with known values,
    # this is just an ordinary guard.
    reduced = self._maybe_evaluate_static(orig_expr, unbacked_only=True)
    if reduced.free_symbols <= self.var_to_val.keys():
        return self.evaluate_expr(reduced, fx_node=fx_node)

    # From here on only orig_expr is used, never `reduced`: it could contain
    # gunk like shape0 which we don't want to guard on.  A runtime assert is
    # definitely being created now.
    if (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
    ):
        tv_node, is_new = self._create_fx_call_function(torch._assert, (fx_node,))
        assert tv_node is not None
        if is_new:
            self._add_fx_node_metadata(tv_node)

    self._check_frozen(orig_expr, sympy.true)

    # eliminate symbols on equality tests / refine ranges
    if isinstance(orig_expr, sympy.Rel):
        self._maybe_guard_rel(orig_expr)

    if self._suppress_guards_tls():
        self.log.debug("runtime_assert %s [guard suppressed]", orig_expr)
        return True

    # canonicalise to remove equations that are trivially equal
    canonical = canonicalize_bool_expr(orig_expr)
    stack = CapturedTraceback.extract(skip=1)
    ra = RuntimeAssert(canonical, msg, stack)
    # The assert is filed under the highest-numbered symbol named "u<N>".
    # TODO: Do this in a way that is less janky than int(s.name[1:])
    # NOTE(review): if no free symbol of `canonical` starts with "u", the
    # list below is empty and the [-1] lookup raises IndexError.
    unbacked = [sym for sym in canonical.free_symbols if sym.name.startswith("u")]
    unbacked.sort(key=lambda sym: int(sym.name[1:]))
    self.deferred_runtime_asserts.setdefault(unbacked[-1], []).append(ra)
    self.num_deferred_runtime_asserts += 1
    self._update_version_counter()
    self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
    return True
# Refines the ranges of the variables present in 'expr'.
#
# This function tries to refine the range of the variables inside
# 'expr' by reasoning about it, which only succeeds when 'expr' (after
# simplification) is a 'sympy.Relational' that can be solved for a symbol.
#
# It does mainly 3 things:
#   1. Tries to isolate a variable in the left-hand side
#   2. Compute the value range of the right-hand side
#   3. Update the value range of the variable, if better
def _refine_ranges(self, expr: sympy.Expr) -> None:
    expr = self.simplify(expr)

    for symbol in expr.free_symbols:
        assert isinstance(symbol, sympy.Symbol)

        if isinstance(self.var_to_val.get(symbol, None), SingletonInt):
            # Skip var_to_range logic for SingletonInt which is only used
            # for jagged layout NestedTensors today
            continue

        r = try_solve(expr, symbol)

        if r is None or not (symbol.is_integer and r[1].is_integer):
            # Range refinement only supports integer symbols for now.
            # There are lots of SymPy bugs when it comes to comparing
            # reals and integers, so we skip that for now.
            continue

        r_expr, rhs = r
        vr = self.var_to_range[symbol]
        lower, upper = vr.lower, vr.upper

        rhs_vr = bound_sympy(rhs, self.var_to_range)
        _assert_bound_is_rational(rhs, rhs_vr)

        # Let's suppose that we have a preexisting range for x [0, 100].
        # Now, we issue a guard x > y, where the range for y is [50, 150].
        # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen,
        # refining x to [51, 100], since x must be greater than y, but the lowest
        # y could be is 50.
        #
        # sympy.Eq may update both lower and upper bounds.
        # sympy.G{t,e} may update the lower bound, only.
        # sympy.L{t,e} may update the upper bound, only.
        if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)):
            # Strictly greater relations allow us to refine a bit more, since
            # (for integers) x > y implies that the lower bound for x is: y + 1.
            lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
        if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)):
            # Symmetrically, x < y implies that the upper bound for x is: y - 1.
            upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))

        # Do nothing if the new value range is no better than what we already have.
        if vr == ValueRanges(lower, upper):
            continue

        # Record the refined range for the symbol (no guards are added here).
        self.var_to_range[symbol] = ValueRanges(lower, upper)
        # Clears the cache, since this update can change the result.
        self._maybe_evaluate_static.cache_clear()
def _is_int(expr):
    """Return True only for a SymInt whose underlying sympy expression is a number."""
    if not isinstance(expr, SymInt):
        return False
    return expr.node.expr.is_number


# WARNING: This is legacy, DO NOT USE
def _is_dim_dynamic(t, d):
    """Return whether dimension ``d`` is listed in ``t._dynamo_dynamic_indices``.

    Objects without that attribute are never considered dynamic.
    """
    if not hasattr(t, "_dynamo_dynamic_indices"):
        return False
    return d in t._dynamo_dynamic_indices
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.