import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES

from .exported_program import ExportedProgram

if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source

    from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]


class _Dim(type):
    """
    Metaclass for :func:`Dim` types.
    """

    @staticmethod
    def readable(name, min_, max_):
        if min_ == 2:
            min_ = None
        if max_ == sys.maxsize - 1:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Currently we only support increasing linear expressions with integer coefficients.
    In other words, a derived Dim can always be written in the form Ax + B, where
    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)
    These restrictions on the form of derived Dims make the metatheory simpler: e.g.,
    it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
    deciding equalities between derived Dims, and so on.

    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
    """

    @property
    def min(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
        assert _min_symint >= 2, (
            f"Expected derived min value of {self.__name__} to be >= 2. "
            f"Please specify an appropriate min value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.min})."  # type: ignore[attr-defined]
        )
        return int(_min_symint)

    @property
    def max(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
        assert _max_symint <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.max})."  # type: ignore[attr-defined]
        )
        return int(_max_symint)

    def _derive(self, fn):
        # We support nesting, e.g., 2*dim + 1.
        # This is implemented by composing operations on the same root.
        # As a consequence, roots are always regular Dims (i.e., not derived Dims).
        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
        )
def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
    """
    _min = 2 if min is None else builtins.max(min, 2)
    _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
    dim = _Dim(name, (int,), {"min": _min, "max": _max})
    dim.__module__ = getattr(
        inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
    )
    return dim
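# Illustrative sketch (not part of the original module): typical use of Dim with
# torch.export. The module `M`, its forward argument name "x", and the example input
# are assumptions for the example only.
#
#   batch = Dim("batch", min=2, max=64)
#   ep = torch.export.export(
#       M(), (torch.randn(8, 16),), dynamic_shapes={"x": {0: batch}}
#   )
#   # Dimension 0 of input "x" is exported as a symbol ranging over [2, 64]
#   # instead of being specialized to the example size 8.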
def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types.
    """
    return tuple(Dim(name, min=min, max=max) for name in names)
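# Illustrative sketch (not part of the original module): dims() is shorthand for
# creating several Dims that share the same bounds. Names below are assumptions.
#
#   batch, seq = dims("batch", "seq", max=1024)
#   # equivalent to: batch = Dim("batch", max=1024); seq = Dim("seq", max=1024)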
@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents input tensor dimensions.  Don't create this
    class directly; instead, use :func:`dynamic_dim`.
    """

    w_tensor: Any  # weakref to torch.Tensor
    # TODO: We don't need t_id; we can get it off of w_tensor
    t_id: int
    dim: int


class _ConstraintFactory(type):
    """
    Metaclass that ensures a private constructor for :class:`_Constraint`
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
            f"Please use torch.export.dynamic_dim() to create one"
        )

    def _create(
        cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
    ):
        return super().__call__(
            w_tensor, t_id, dim, constraint_range, shared, debug_name
        )


def _create_constraint(
    w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
):
    return _Constraint._create(
        w_tensor, t_id, dim, constraint_range, shared, debug_name
    )


@dataclasses.dataclass
class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    """

    .. warning::
        Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead.

    This represents constraints on input tensor dimensions, e.g., requiring
    them to be fully polymorphic or within some range.

    """

    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
    constraint_range: "StrictMinMaxConstraint"
    # Represent that `constraint_range` is shared with another _ConstraintTarget, which
    # typically arises because of a specified equality with another dynamic dimension.
    shared: Optional[_ConstraintTarget] = None
    debug_name: Optional[str] = None

    def _clone_with_range(self, lower=2, upper=math.inf):
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.value_ranges import ValueRanges

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            self.shared,
            self.debug_name,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-exporting pass
        # that converts constraints to runtime assertions. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }

    def __eq__(self, other):
        if not isinstance(other, _Constraint):
            raise TypeError(
                "A dynamic dim can be specified equal only to another dynamic dim. "
                f"Equality with {type(other)} is not supported."
            )
        # import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & other.constraint_range.vr,
            warn_only=False,
        )
        if self.debug_name is None:
            debug_name = other.debug_name
        else:
            assert other.debug_name is None or self.debug_name == other.debug_name
            debug_name = self.debug_name
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
            debug_name=debug_name,
        )


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int


@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).
    """

    # NOTE: This is not currently a subclass of _Constraint because we do not support
    # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for
    # legacy constraints based on `dynamic_dim`: equality can be expressed simply by
    # reusing the same (derived or normal) `Dim`.
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable
    constraint_range: "StrictMinMaxConstraint"
    debug_name: Optional[str] = None

    @property
    def shared(self):
        # Some code paths expect a union of _Constraint and _DerivedConstraint.
        # Thus we expose a `shared` field that is always None.
        # TODO(avik): clean this up
        return None

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


Constraint = Union[_Constraint, _DerivedConstraint]
def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
    """
    .. warning::
        (This feature is DEPRECATED. See :func:`Dim` instead.)

    :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of
    a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to the
    ``constraints`` argument of :func:`export`.

    Args:
        t (torch.Tensor): Example input tensor that has dynamic dimension size(s)
        index (int): Index of dynamic dimension

    Returns:
        A :class:`_Constraint` object that describes shape dynamism. It can be passed to
        :func:`export` so that :func:`export` does not assume a static size for the specified
        tensor dimension, i.e. it is kept dynamic as a symbolic size rather than specialized
        according to the size of the example tracing input.

    Specifically :func:`dynamic_dim` can be used to express the following types of dynamism.

    - Size of a dimension is dynamic and unbounded::

        t0 = torch.rand(2, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size rather than always being static size 2
        constraints = [dynamic_dim(t0, 0)]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with a lower bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
        # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) >= 5,
            dynamic_dim(t1, 1) > 2,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with an upper bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with an upper bound of 16 (inclusive)
        # Second dimension of t1 can be dynamic size with an upper bound of 8 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) <= 16,
            dynamic_dim(t1, 1) < 8,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic and it is always equal to the size of another dynamic dimension::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # Sizes of the second dimension of t0 and the first dimension of t1 are always equal
        constraints = [
            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Mix and match all types above as long as they do not express conflicting requirements

    """
    from torch._dynamo.exc import UserError, UserErrorType

    if not isinstance(t, torch.Tensor):
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected tensor as input to dynamic_dim but got {type(t)}",
        )

    if t.dim() < 1:
        raise UserError(
            UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic"
        )

    if index >= t.dim():
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
            f" but got {index}, which is out of bounds for the given tensor.",
        )

    # Import sympy locally
    import sympy

    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
    from torch.utils._sympy.value_ranges import ValueRanges

    return _create_constraint(
        weakref.ref(t),
        id(t),
        index,
        StrictMinMaxConstraint(
            vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False
        ),
        debug_name=debug_name,
    )
def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.
    """

    source, *other_sources = get_sources(constraint.t_id, constraint.dim)
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if not isinstance(constraint, _DerivedConstraint):
        if constraint.shared is not None:
            # Moreover, when t.size()[dim] is specified equal to t'.size()[dim']
            # and t'.size()[dim'] maps to src1', ..., srcN', we add
            # constraints that also make src0 "equal" to src1', ..., srcN'.
            other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
    else:
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))


def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
    from collections import defaultdict
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for k, shape in dynamic_shapes.items():
                yield from tree_zip(combined_args[k], shape)
        elif type(combined_args) in SUPPORTED_NODES:
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
    phantom_roots: Dict[str, _PhantomRoot] = {}
    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def root_value():
            # given tensor.shape[i] is the value of dim = fn(root),
            # find the value of root
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is not None:
                return int(solution[1])  # type: ignore[call-overload]
            else:
                raise UserError(  # noqa: TRY200
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )

        if isinstance(dim, _DerivedDim):
            # generate a _DerivedConstraint where the root is:
            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
            # - or a _PhantomRoot (otherwise)
            dim_root = dim.root  # type: ignore[attr-defined]
            if dim_root.__name__ in symbols:
                # root represents an input shape dimension
                root_constraint = symbols[dim_root.__name__][0]
                root = _ConstraintTarget(
                    root_constraint.w_tensor,
                    root_constraint.t_id,
                    root_constraint.dim,
                )
            elif dim_root.__name__ not in phantom_roots:
                # create a phantom root
                root = _PhantomRoot(  # type: ignore[assignment]
                    name=dim_root.__name__,
                    constraint_range=StrictMinMaxConstraint(
                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
                        warn_only=False,
                    ),
                    val=root_value(),
                )
                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
            else:
                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
            constraint = _DerivedConstraint(
                weakref.ref(tensor),
                id(tensor),
                i,
                root,
                dim.fn,  # type: ignore[attr-defined]
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max),
                    warn_only=False,
                ),
                debug_name=dim.__name__,
            )
            if isinstance(root, _PhantomRoot):
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        else:
            constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
            if dim.min != 2:
                constraint = constraint >= dim.min
            if dim.max != sys.maxsize - 1:
                constraint = constraint <= dim.max
        return constraint

    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        if dim.__name__ in symbols:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )
        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def update_symbols(tensor, shape):
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        else:
            if shape is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Unexpected dynamic_shape {shape} of Tensor, " "try None instead",
                )

    import inspect

    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    combined_args = signature.bind(*args, **kwargs).arguments

    # This means the user didn't specify dynamic shapes with argument names.
    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
        update_symbols(tensor, shape)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source. This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols are, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        if all(
            isinstance(dynamic_dim, _DerivedConstraint) for dynamic_dim in dynamic_dims
        ):
            constraints.extend(dynamic_dims)
        else:
            primary, *others = dynamic_dims
            if others:
                for other in others:
                    constraints.append(primary == other)  # type: ignore[arg-type]
            else:
                constraints.append(primary)

    return constraints  # type: ignore[return-value]


def _process_constraints(
    fake_mode,
    graph_module: torch.fx.GraphModule,
    num_lifted_params_buffers: int,
    example_inputs: List[torch.Tensor],
) -> Dict:
    """
    Process the constraints stored in the graph module to return something more readable.

    Args:
        graph_module (torch.fx.GraphModule): GraphModule returned from
            dynamo.export, which contains the "input_shape_constraints" and
            "inline_constraints" metadata

        example_inputs: Flattened list of example inputs used to export the graph module

    Returns:
        range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
            symbols (from SymInts) appearing in the fake tensors in
            node.meta["val"] to their range constraints, which are a tuple
            containing (lower, upper) constraints.
    """
    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        InputDim,
    )

    # Import sympy locally
    from torch.fx.experimental.symbolic_shapes import SymInt
    from torch.utils._sympy.value_ranges import ValueRanges

    input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
    inline_constraints = graph_module.meta.get("inline_constraints", [])

    # Create dict mapping tensor_id to node names
    tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
    # Create dict mapping placeholder node names to their nodes
    placeholder_nodes: Dict[str, torch.fx.Node] = {}
    for i, node in enumerate(graph_module.graph.nodes):
        if node.op != "placeholder":
            # All placeholder nodes should be together in the beginning of the
            # graph
            break
        if i >= num_lifted_params_buffers:
            example_input = example_inputs[i - num_lifted_params_buffers]
            tensor_id_to_nodes[id(example_input)].append(node.name)
            placeholder_nodes[node.name] = node

    # Create dict mapping (node name, dim) to a list of range (lower, upper)
    # constraints
    multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
    for constraint in input_shape_constraints:
        for node in tensor_id_to_nodes[constraint["t_id"]]:
            node_dim = InputDim(node, constraint["dim"])

            # Accumulate range constraints
            multi_range_constraints[node_dim].append(
                ValueRanges(constraint["min"], constraint["max"])
            )

    # Create dict mapping symbol to a singular range (lower, upper)
    range_constraints: Dict[Any, ValueRanges] = {}

    # Add inline constraints to range_constraints
    range_constraints = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }

    free_symbols: Set["Symbol"] = set()
    # Add input range constraints to range_constraints
    for input_dim, multi_range_constraint in multi_range_constraints.items():  # type: ignore[assignment]
        # Simplify the range constraints into a single range constraint
        # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
        min_vals = [rc.lower for rc in multi_range_constraint]
        max_vals = [rc.upper for rc in multi_range_constraint]
        min_val = max(min_vals)  # type: ignore[type-var]
        max_val = min(max_vals)  # type: ignore[type-var]
        assert min_val <= max_val  # type: ignore[operator]

        # Add input node range constraints
        val = placeholder_nodes[input_dim.input_name].meta["val"]
        assert isinstance(val, FakeTensor)
        symint = val.shape[input_dim.dim]
        assert isinstance(
            symint, SymInt
        ), f"Expected SymInt but got {symint}: {type(symint)}"
        symbol = symint.node.expr
        range_constraints[symbol] = ValueRanges(min_val, max_val)
        free_symbols.update(symbol.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]

    return range_constraints