# mypy: allow-untyped-defs
import dataclasses
import inspect
import logging
import sys
from collections import defaultdict
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch.utils._pytree import (
    _get_node_type,
    BUILTIN_TYPES,
    keystr,
    LeafSpec,
    MappingKey,
    SequenceKey,
    SUPPORTED_NODES,
    tree_flatten,
    tree_map_with_path,
)

from .exported_program import ExportedProgram


if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

__all__ = [
    "Constraint",
    "Dim",
    "dims",
    "refine_dynamic_shapes_from_suggested_fixes",
]


log = logging.getLogger(__name__)


class _DimHint(Enum):
    """
    Enum for dynamic shape hints.
    - AUTO means automatic inference of shape (static or dynamic).
    - STATIC means static shape (always specialized).
    - DYNAMIC means dynamic, will error out if specialized.
    """

    AUTO = auto()
    STATIC = auto()
    DYNAMIC = auto()


class _Dim(type):
    """
    Metaclass for :func:`Dim` types.
    """

    @staticmethod
    def readable(name, min_, max_):
        from torch.utils._sympy.numbers import int_oo

        if min_ == 2:
            min_ = None
        if max_ == int_oo:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _StaticDim(_Dim):
    """
    Metaclass for static :func:`Dim` types.

    This class is only for setting and checking static dim constraints,
    and the user should never interact with it.
    """

    @property
    def min(self):
        return self.value  # type: ignore[attr-defined]

    @property
    def max(self):
        return self.value  # type: ignore[attr-defined]


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Currently we only support increasing linear expressions with integer coefficients.
    In other words, a derived Dim can always be written in the form Ax + B, where
    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)

    These restrictions on the form of derived Dims makes the metatheory simpler: e.g.,
    it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
    deciding equalities between derived Dims, and so on.
    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
    """

    @property
    def min(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        from torch.utils._sympy.numbers import int_oo

        if self.root.min is -int_oo:  # type: ignore[attr-defined]
            return -int_oo  # fn not needed cuz increasing

        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
        root = self.root  # type: ignore[attr-defined]
        assert _min_symint >= 0, (
            f"Expected derived min value of {self.__name__} to be >= 0. "
            f"Please specify an appropriate min value for {root.__name__} "
            f"(currently {root.min})."
        )
        return int(_min_symint)

    @property
    def max(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        from torch.utils._sympy.numbers import int_oo

        if self.root.max is int_oo:  # type: ignore[attr-defined]
            return int_oo  # fn not needed cuz increasing

        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
        root = self.root  # type: ignore[attr-defined]
        assert _max_symint <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {root.__name__} "
            f"(currently {root.max})."
        )
        return int(_max_symint)

    def _derive(self, fn):
        # We support nesting, e.g., 2*dim + 1.
        # This is implemented by composing operations on the same root.
        # As a consequence, roots are always regular Dims (i.e., not derived Dims).
        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
        )
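

# Usage sketch (illustrative, not part of the module): arithmetic on a Dim goes
# through the metaclass operators above and produces _DerivedDim types; only
# increasing linear forms A*x + B with a positive integer A are accepted.
#
#     from torch.export import Dim
#
#     dx = Dim("dx", min=1, max=512)
#     dy = dx + 1      # derived Dim with root dx and fn x -> x + 1
#     dz = 2 * dx + 1  # nesting composes fns over the same root dx
#     dz.min, dz.max   # (3, 1025): fn mapped over dx's range
#
#     dx * -1          # raises NotImplementedError (coefficient must be positive)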
def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
    """
    from torch.utils._sympy.numbers import int_oo

    _min = 0 if min is None else min
    _max = int_oo if max is None else max
    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
    assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
    dim = _Dim(name, (int,), {"min": _min, "max": _max})
    dim.__module__ = getattr(
        inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
    )
    return dim
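

# Usage sketch (illustrative, not part of the module): the same Dim can mark
# matching dimensions of different inputs in a `dynamic_shapes` spec.
#
#     import torch
#     from torch.export import Dim, export
#
#     class M(torch.nn.Module):
#         def forward(self, x, y):
#             return x + y
#
#     batch = Dim("batch")
#     ep = export(
#         M(),
#         (torch.randn(4, 8), torch.randn(4, 8)),
#         dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
#     )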
def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types.
    """
    return tuple(Dim(name, min=min, max=max) for name in names)
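

# Usage sketch (illustrative, not part of the module): create several Dim types
# sharing the same bounds in one call.
#
#     from torch.export import dims
#
#     batch, seq = dims("batch", "seq", min=1, max=2048)
#     # equivalent to: Dim("batch", min=1, max=2048), Dim("seq", min=1, max=2048)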
@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents input tensor dimensions.
    """

    t_id: int
    dim: int


@dataclasses.dataclass
class _Constraint(_ConstraintTarget):
    """
    This represents a Dim describing a constraint target.

    `name` is the name of the Dim.

    `constraint_range` contains the min/max bounds of the Dim.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"

    def _clone_with_range(self, lower=0, upper=None):
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.numbers import int_oo
        from torch.utils._sympy.value_ranges import ValueRanges

        if upper is None:
            upper = int_oo

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _Constraint(
            self.t_id,
            self.dim,
            self.name,
            constraint_range,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-exporting pass
        # that converts constraints to runtime assertion. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int


@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).

    It can be thought of as a subclass of `_Constraint`, except that it does not
    support <, <=, >, >= operations.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


@dataclasses.dataclass
class _RelaxedConstraint(_ConstraintTarget):
    """
    This represents a dim marked with Dim.AUTO/DYNAMIC (i.e.
    mark_dynamic() or maybe_mark_dynamic()), which leaves relations & min/max
    ranges for inference, instead of requiring explicit specification.
    The intention is for constraint violations to not be raised if produce_guards()
    finds equalities or relations between a _RelaxedConstraint and another type of _Constraint.
    """

    @property
    def serializable_spec(self):
        return {
            "t_id": self.t_id,
            "dim": self.dim,
        }


Constraint = Union[_Constraint, _DerivedConstraint, _RelaxedConstraint]


def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    names: Dict[str, Tuple[int, int]],
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
    relaxed_sources: Set["Source"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.
    """

    sources = get_sources(constraint.t_id, constraint.dim)
    if not sources:  # empty sources due to unused shapes
        return

    source, *other_sources = sources
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if isinstance(constraint, _Constraint):
        if constraint.name in names:
            shared_t_id, shared_dim = names[constraint.name]
            other_sources = get_sources(shared_t_id, shared_dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
        else:
            names[constraint.name] = (constraint.t_id, constraint.dim)
    elif isinstance(constraint, _DerivedConstraint):
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))
    elif isinstance(constraint, _RelaxedConstraint):
        relaxed_sources.add(source)


def _tree_map_with_path(
    func: Callable[..., Any],
    tree: Any,
    *dynamic_shapes: Any,
    tree_name: Optional[str] = None,
) -> Any:
    """
    Customized tree_map for mapping pytrees to dynamic_shapes.

    For built-in types (e.g., standard collections) this behaves exactly like tree_map.

    OTOH for a user-defined class C registered with pytree, we cannot assume that a C
    containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not
    be a polymorphic container). In that case we use the flattened form of C instead.
    Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes).

    Args:
        func: function to apply to each (int, float, str, bool, None, torch.Tensor)
        tree: input pytree
        dynamic_shapes: zero or more (typically one) dynamic_shapes to match

    Returns:
        output pytree mapping func to each (int, float, str, bool, None, torch.Tensor)
    """

    def is_leaf(t):
        # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types
        # registered with pytree.
        # Types *not* in BUILTIN_TYPES include primitive types
        # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES,
        # as well as user-defined classes registered with pytree, which are.
        return _get_node_type(t) not in BUILTIN_TYPES

    def f(path, t, *dynamic_shapes):
        typ = _get_node_type(t)
        # typ is not in BUILTIN_TYPES
        if typ in SUPPORTED_NODES:
            # thus typ is a user-defined class registered with pytree,
            # in which case flatten and recurse
            return tree_map_with_path(
                f,
                SUPPORTED_NODES[typ].flatten_fn(t)[0],
                *dynamic_shapes,
                is_leaf=is_leaf,
            )
        else:
            return func(path, t, *dynamic_shapes)

    try:
        return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
    except ValueError as e:
        if "mismatch" in e.args[0]:
            # When PyTree finds a structural mismatch between tree and dynamic_shapes,
            # the error message is unfortunately quite horrible. Let's fix that.
            assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes"
            assert tree_name, "Must provide a tree_name when there might be a mismatch"

            def _key(type_, context, i):
                # derive a PyTree key given the type, context, and child # of a TreeSpec
                if type_ is dict:
                    return MappingKey(context[i])
                if type_ in (list, tuple):
                    assert context is None
                    return SequenceKey(i)
                raise AssertionError(f"Did not expect type {type_}")

            def raise_mismatch_error(msg):
                from torch._dynamo.exc import UserError, UserErrorType

                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}",
                    case_name="dynamic_shapes_validation",
                )

            def _compare(tree, dynamic_shapes, path):
                # raise an error at the point where tree and dynamic_shapes differ,
                # including the path to that point and the reason for the difference
                rendered_path = keystr(path)
                if isinstance(tree, LeafSpec):
                    return
                if isinstance(dynamic_shapes, LeafSpec):
                    raise_mismatch_error(
                        f"`{tree_name}{rendered_path}` is a {tree.type}, "
                        f"but `dynamic_shapes{rendered_path}` is not"
                    )
                if tree.type != dynamic_shapes.type:
                    raise_mismatch_error(
                        f"`{tree_name}{rendered_path}` is a {tree.type}, "
                        f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}"
                    )
                if len(tree.children_specs) != len(dynamic_shapes.children_specs):
                    raise_mismatch_error(
                        f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, "
                        f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements"
                    )
                if tree.type is dict:
                    # context, children could be out of order
                    if sorted(tree.context) != sorted(dynamic_shapes.context):
                        raise_mismatch_error(
                            f"`{tree_name}{rendered_path}` has keys {tree.context}, "
                            f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}"
                        )
                    _remap = dict(
                        zip(dynamic_shapes.context, dynamic_shapes.children_specs)
                    )
                    dynamic_shapes_children_specs = [_remap[k] for k in tree.context]
                else:
                    dynamic_shapes_children_specs = dynamic_shapes.children_specs
                for i, (tree_, dynamic_shapes_) in enumerate(
                    zip(tree.children_specs, dynamic_shapes_children_specs)
                ):
                    _compare(
                        tree_,
                        dynamic_shapes_,
                        path + [_key(tree.type, tree.context, i)],
                    )

            _, tree_spec = tree_flatten(tree, is_leaf=is_leaf)
            for other_tree in dynamic_shapes:
                _, other_tree_spec = tree_flatten(other_tree, is_leaf)
                _compare(tree_spec, other_tree_spec, [])
        raise


def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]:
    # combine args and kwargs following the signature of f, as it happens
    # in the body of f when called with *args, **kwargs
    if isinstance(f, ExportedProgram):
        f = f.module()
    if not _is_torch_jit_trace:
        signature = (
            inspect.signature(f.forward)
            if isinstance(f, torch.nn.Module)
            else inspect.signature(f)
        )
        kwargs = kwargs if kwargs is not None else {}
        return signature.bind(*args, **kwargs).arguments
    return args
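

# Illustrative sketch (not part of the module): _combine_args mirrors Python's
# own argument binding, so the top-level keys of a dict-form `dynamic_shapes`
# are the same names that inspect.signature(...).bind(...) produces.
#
#     import inspect
#     import torch
#
#     class M(torch.nn.Module):
#         def forward(self, x, mask=None):
#             return x
#
#     sig = inspect.signature(M().forward)
#     sig.bind(torch.randn(2, 3), mask=None).arguments
#     # -> {'x': tensor(...), 'mask': None}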
class ShapesCollection:
    """
    Builder for dynamic_shapes.
    Used to assign dynamic shape specifications to tensors that appear in inputs.

    Example::
        args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})
        dim = torch.export.Dim(...)
        dynamic_shapes = torch.export.ShapesCollection()
        dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
        dynamic_shapes[tensor_y] = {0: dim * 2}
        # This is equivalent to the following (now auto-generated):
        # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

        torch.export(..., args, dynamic_shapes=dynamic_shapes)
    """

    def __init__(self):
        self._shapes = {}

    def __setitem__(self, t, shape):
        assert isinstance(
            t, torch.Tensor
        ), f"Cannot assign shape to non-tensor type {type(t)}"
        # TODO(avik): check that shape is indeed a Shape

        t_id = id(t)
        if t_id in self._shapes:
            _shape = self._shapes[t_id]
            assert (
                shape == _shape
            ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}"
        else:
            self._shapes[id(t)] = shape

    def __getitem__(self, t):
        t_id = id(t)
        if t_id in self._shapes:
            return self._shapes[t_id]
        else:
            return None

    def __len__(self):
        return len(self._shapes)
    def dynamic_shapes(self, m, args, kwargs=None):
        """
        Generate dynamic_shapes.
        """

        t_ids = set()

        def find_shape(path, t):
            t_id = id(t)
            if t_id in self._shapes:
                t_ids.add(t_id)
                return self._shapes[t_id]
            else:
                return None

        combined_args = _combine_args(m, args, kwargs)
        dynamic_shapes = _tree_map_with_path(find_shape, combined_args)
        if any(t_id not in t_ids for t_id in self._shapes):
            raise ValueError(
                "Some tensors that were assigned shapes were not found in args. "
                "Maybe such tensors were copied when passing them as args? "
                "Maybe such tensors are contained in classes that were not registered with pytree?"
            )
        return dynamic_shapes
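

# Usage sketch (illustrative, not part of the module): assign shapes to the
# tensors themselves and let ShapesCollection generate the nested
# `dynamic_shapes` spec that matches the input structure.
#
#     import torch
#     from torch.export import Dim, ShapesCollection, export
#
#     class M(torch.nn.Module):
#         def forward(self, x, others):
#             return x * 2 + others[0].sum()
#
#     tensor_x = torch.randn(4, 5, 8)
#     tensor_y = torch.randn(4)
#     dim = Dim("dim", max=512)
#
#     collection = ShapesCollection()
#     collection[tensor_x] = (dim, dim + 1, 8)
#     collection[tensor_y] = {0: dim}
#
#     m = M()
#     args = (tensor_x, [tensor_y])
#     ep = export(m, args, dynamic_shapes=collection.dynamic_shapes(m, args))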
def _warn_on_None_dynamic_shape_dimension():
    msg = (
        "Using None as a dynamic shape dimension is deprecated. "
        "Please use Dim.STATIC instead"
    )
    # TODO(avik): raise an error in the future
    log.warning(msg)


def _check_dynamic_shapes(
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
):
    """
    Checks the dynamic_shapes specification for correctness,
    using combined args + kwargs as reference for inputs structure.
    """
    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return
    if isinstance(dynamic_shapes, (tuple, list)):
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]

    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        if dim.__name__ in bounds:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )
        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def check_symbols(path, tensor, shape):
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                elif not (isinstance(dim, (int, _DimHint))):
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
                        f"specified at `dynamic_shapes{keystr(path)}` "
                        f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, "
                        f" but got {dim} instead)",
                        case_name="dynamic_shapes_validation",
                    )
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                elif dim is None:
                    _warn_on_None_dynamic_shape_dimension()
                elif not (isinstance(dim, (int, _DimHint))):
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Unexpected dimension #{i} in input tensor shape {shape} "
                        f"specified at `dynamic_shapes{keystr(path)}` "
                        f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, "
                        f"but got {dim} instead)",
                        case_name="dynamic_shapes_validation",
                    )
        elif shape is not None:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
                f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
                f" where each dimension is an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC)",
                case_name="dynamic_shapes_validation",
            )

    assert isinstance(dynamic_shapes, (dict, tuple, list))
    if isinstance(dynamic_shapes, dict):
        got_keys = list(dynamic_shapes.keys())
        expected_arg_names = list(combined_args.keys())
        if sorted(got_keys) != sorted(expected_arg_names):
            msg = (
                f"When `dynamic_shapes` is specified as a dict, its top-level keys "
                f"must be the arg names {expected_arg_names} of `inputs`, but "
                f"here they are {got_keys}. "
")if(len(combined_args)==1andexpected_arg_names[0]notingot_keysandisinstance(combined_args[expected_arg_names[0]],dict)):msg+=("Since here `inputs` is a list/tuple enclosing a single dict, ""maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?")else:msg+=("Alternatively, you could also ignore arg names entirely ""and specify `dynamic_shapes` as a list/tuple matching `inputs`.")raiseUserError(UserErrorType.INVALID_INPUT,msg,case_name="dynamic_shapes_validation")defcheck_shape(path,t,dynamic_shape):ifisinstance(t,torch.Tensor):check_symbols(path,t,dynamic_shape)else:ifdynamic_shapeisnotNone:rendered_path=keystr(path)raiseUserError(UserErrorType.INVALID_INPUT,f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` "f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)",case_name="dynamic_shapes_validation",)_tree_map_with_path(check_shape,combined_args,dynamic_shapes,tree_name="inputs")def_process_dynamic_shapes(combined_args:Dict[str,Any],dynamic_shapes:Union[Dict[str,Any],Tuple[Any],List[Any],None],)->List[Constraint]:""" Reads the dynamic_shapes specification and produces a list of constraints. """fromtorch._dynamo.excimportUserError,UserErrorTypeifdynamic_shapesisNoneorlen(dynamic_shapes)==0:# we run with dynamic by default, so no need to produce constraintsreturn[]ifisinstance(dynamic_shapes,(tuple,list)):combined_args=type(dynamic_shapes)(combined_args.values())# type: ignore[assignment, misc]# map of Dim names representing input shape dimensions to constraints on themsymbols:Dict[str,List[Constraint]]=defaultdict(list)# track roots that do not directly represent input shape dimensionsphantom_roots:Dict[str,_PhantomRoot]={}derived_constraints_with_phantom_root:List[_DerivedConstraint]=[]# list of constraints to returnconstraints:List[Constraint]=[]defto_constraint(dim,tensor,i):importsympyfromtorch.fx.experimental.symbolic_shapesimportStrictMinMaxConstraintfromtorch.utils._sympy.solveimporttry_solvefromtorch.utils._sympy.value_rangesimportValueRangesdefroot_value():# given tensor.shape[i] is the value of dim = fn(root),# find the value of rootsymbol=sympy.Symbol(dim.root.__name__,integer=True)expr=dim.fn(symbol)solution=try_solve(sympy.Eq(expr,tensor.shape[i]),symbol)ifsolutionisnotNone:returnint(solution[1])else:raiseUserError(# noqa: B904UserErrorType.CONSTRAINT_VIOLATION,f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "f"of the form {expr}, where {symbol} is an integer",)ifisinstance(dim,_DerivedDim):# generate a _DerivedConstraint where the root is:# - either a _ConstraintTarget (if dim.root directly describes an input shape)# - or a _PhantomRoot (otherwise)dim_root=dim.root# type: ignore[attr-defined]ifdim_root.__name__insymbols:# root represents an input shape dimensionroot_constraint=symbols[dim_root.__name__][0]root=_ConstraintTarget(root_constraint.t_id,root_constraint.dim,)elifdim_root.__name__notinphantom_roots:# create a phantom rootroot=_PhantomRoot(# type: ignore[assignment]name=dim_root.__name__,constraint_range=StrictMinMaxConstraint(vr=ValueRanges(lower=dim_root.min,upper=dim_root.max),warn_only=False,),val=root_value(),)phantom_roots[dim_root.__name__]=root# type: ignore[assignment]else:root=phantom_roots[dim_root.__name__]# type: ignore[assignment]constraint=_DerivedConstraint(id(tensor),i,dim.__name__,StrictMinMaxConstraint(vr=ValueRanges(lower=dim.min,upper=dim.max),warn_only=False,),root,dim.fn,# type: ignore[attr-defined])ifisinstance(root,_PhantomRoot):# NOTE(avik): since we have not 
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        elif isinstance(dim, _StaticDim):
            constraint = _Constraint(  # type: ignore[assignment]
                id(tensor),
                i,
                dim.__name__,
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False  # type: ignore[attr-defined]
                ),
            )
        else:
            assert isinstance(dim, _Dim)
            constraint = _Constraint(  # type: ignore[assignment]
                id(tensor),
                i,
                dim.__name__,
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False  # type: ignore[attr-defined]
                ),
            )
        return constraint

    def update_symbols(path, tensor, shape):
        def _create_static_dim(tensor, i, value):
            return _StaticDim(str(value), (int,), {"value": value})

        # clean out decorators from user side, or previous export call
        # we also delete these attributes in non_strict_utils.py/make_constraints()
        tensor._dynamo_weak_dynamic_indices = set()
        tensor._dynamo_dynamic_indices = set()
        tensor._dynamo_dynamic_range = set()
        tensor._dynamo_static_indices = set()
        tensor._dynamo_unbacked_indices = set()

        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, (int, _Dim)):
                    if isinstance(dim, int):
                        dim = _create_static_dim(tensor, i, dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                elif isinstance(dim, _DimHint):
                    if dim == _DimHint.AUTO:
                        torch._dynamo.maybe_mark_dynamic(tensor, i)
                    elif dim == _DimHint.STATIC:
                        torch._dynamo.mark_static(tensor, i)
                    elif dim == _DimHint.DYNAMIC:
                        torch._dynamo.mark_dynamic(tensor, i)
                    constraints.append(_RelaxedConstraint(id(tensor), i))
                elif dim is None:
                    torch._dynamo.mark_static(tensor, i)
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, (int, _Dim)):
                    if isinstance(dim, int):
                        dim = _create_static_dim(tensor, i, dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                elif isinstance(dim, _DimHint):
                    if dim == _DimHint.AUTO:
                        torch._dynamo.maybe_mark_dynamic(tensor, i)
                    elif dim == _DimHint.STATIC:
                        torch._dynamo.mark_static(tensor, i)
                    elif dim == _DimHint.DYNAMIC:
                        torch._dynamo.mark_dynamic(tensor, i)
                    constraints.append(_RelaxedConstraint(id(tensor), i))
                elif dim is None:
                    torch._dynamo.mark_static(tensor, i)
        elif shape is None:
            for i in range(tensor.dim()):
                torch._dynamo.mark_static(tensor, i)

    def assoc_shape(path, t, dynamic_shape):
        if isinstance(t, torch.Tensor):
            update_symbols(path, t, dynamic_shape)

    _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs")

    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source.
            # This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols are, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        constraints.extend(dynamic_dims)

    return constraints


def _get_dim_name_mapping(
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
):
    name_to_dim = {}
    for dim in tree_flatten(
        dynamic_shapes,
        is_leaf=lambda x: isinstance(x, _Dim),
    )[0]:
        if dim is None:
            # NOTE: this must denote a non-Tensor or automatic at this point.
            continue
        if isinstance(dim, int):
            continue
        elif isinstance(dim, _Dim):
            name_to_dim[dim.__name__] = dim
            if isinstance(dim, _DerivedDim):
                name_to_dim[dim.root.__name__] = dim.root  # type: ignore[attr-defined]
        else:
            assert isinstance(dim, _DimHint)
    return name_to_dim
def refine_dynamic_shapes_from_suggested_fixes(
    msg: str,
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> Union[Dict[str, Any], Tuple[Any], List[Any]]:
    """
    For working with export's dynamic shapes suggested fixes, and/or automatic dynamic shapes.
    Refines the given dynamic shapes spec, given a ConstraintViolation error message and the original dynamic shapes.

    For most cases behavior is straightforward - i.e. for suggested fixes that specialize or refine a Dim's range,
    or fixes that suggest a derived relation, the new dynamic shapes spec will be updated as such.

    e.g.
    Suggested fixes:

        dim = Dim('dim', min=3, max=6) -> this just refines the dim's range
        dim = 4 -> this specializes to a constant
        dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation

    However, suggested fixes associated with derived dims can be more complicated.
    For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root.

    e.g.
    dx = Dim('dx')
    dy = dx + 2
    dynamic_shapes = {"x": (dx,), "y": (dy,)}

    Suggested fixes:

        dx = 4  # specialization will lead to dy also specializing = 6
        dx = Dim('dx', max=6)  # dy now has max = 8

    Derived dims suggested fixes can also be used to express divisibility constraints.
    This involves creating new root dims that aren't tied to a particular input shape.
    In this case the root dims won't appear directly in the new spec, but as a root of one of the dims.

    e.g.
    Suggested fixes:

        _dx = Dim('_dx', max=1024)  # this won't appear in the return result, but dx will
        dx = 4*_dx  # dx is now divisible by 4, with a max value of 4096
    """
    import re

    import sympy

    from torch._dynamo.exc import UserError, UserErrorType
    from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence

    try:
        shape_fixes_msg = msg.split("Suggested fixes:")[1].strip()
    except Exception as exc:
        raise UserError(
            UserErrorType.INVALID_INPUT,
            "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()",
        ) from exc

    # build shape_fixes dictionary
    shape_fixes = {}
    for fix in shape_fixes_msg.split("\n"):
        fix = fix.strip()
        if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix):
            name = match.group(1)
            _min, _max = None, None
            if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix):
                _min = int(match_min.group(1))
            if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix):
                _max = int(match_max.group(1))
            shape_fixes[name] = Dim(name, min=_min, max=_max)
        else:
            name, expr = fix.split(" = ")
            expr = sympy.sympify(expr)
            if isinstance(expr, sympy.Number):
                # static, integer
                shape_fixes[name] = int(expr)
            else:
                # relation or derived dim
                shape_fixes[name] = expr

    name_to_dim = _get_dim_name_mapping(dynamic_shapes)

    # track derived dim roots
    roots: Set[str] = set()
    for k, c in shape_fixes.items():
        assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
        if isinstance(c, sympy.Expr):
            # check dim/derived dim expression
            assert _is_supported_equivalence(c)
            shape_fixes[k] = c
            roots.add(str(next(iter(c.free_symbols))))
        if isinstance(c, _DerivedDim):
            roots.add(c.root.__name__)  # type: ignore[attr-defined]

    # check keys are existing dims or new roots
    for k, c in shape_fixes.items():
        assert k in name_to_dim or k in roots

    # cache so we don't produce multiple derived dim objects
    derived_dim_cache: Dict[str, _DerivedDim] = {}

    def apply_fixes(path, dim, dummy):
        if dim is None or isinstance(dim, int):
            # not dynamic
            return dim
        elif dim.__name__ in shape_fixes:
            # directly fix
            fix = shape_fixes[dim.__name__]
            if isinstance(fix, sympy.Expr):
                # now derived or related
                if str(fix) in derived_dim_cache:
                    return derived_dim_cache[str(fix)]
                else:
                    symbol = next(iter(fix.free_symbols))
                    # try to locate symbol
                    if symbol.name in shape_fixes:
                        root = shape_fixes[symbol.name]
                    else:
                        assert symbol.name in name_to_dim
                        root = name_to_dim[symbol.name]
                    # figure out value of fix
                    modulus, remainder = sympy.polys.polytools.div(fix, symbol)
                    dim = root
                    if modulus != 1:
                        dim = int(modulus) * dim
                    if remainder != 0:
                        dim = dim + int(remainder)
                    derived_dim_cache[str(fix)] = dim
                    return dim
            else:
                return fix
        elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes:  # type: ignore[attr-defined]
            if dim.__name__ in derived_dim_cache:
                return derived_dim_cache[dim.__name__]
            else:
                # evaluate new derived value based on root
                _dim = dim.fn(shape_fixes[dim.root.__name__])  # type: ignore[attr-defined]
                derived_dim_cache[dim.__name__] = _dim
                return _dim
        return dim  # unchanged dim

    return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes)
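

# Usage sketch (illustrative, not part of the module; `m` and `args` are
# placeholders): retry an export with the spec refined from the suggested
# fixes carried by a ConstraintViolation error message.
#
#     import torch
#     from torch.export import Dim, export, refine_dynamic_shapes_from_suggested_fixes
#
#     dynamic_shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
#     try:
#         ep = export(m, args, dynamic_shapes=dynamic_shapes)
#     except Exception as exc:  # e.g. a ConstraintViolation error raised by export
#         dynamic_shapes = refine_dynamic_shapes_from_suggested_fixes(
#             str(exc), dynamic_shapes
#         )
#         ep = export(m, args, dynamic_shapes=dynamic_shapes)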