ShapeEnv
- class torch.fx.experimental.symbolic_shapes.ShapeEnv(*, should_record_events=None, tracked_fakes=None, **kwargs)
- bind_symbols(placeholders, args)
Given a paired list of placeholders (fake tensors with symbolic sizes) and concrete arguments (regular tensors with real sizes), returns a dictionary mapping each symbol to its real value. So for example, if you have a placeholder with size (s0, s1), binding (2, 4) to it will give you {s0: 2, s1: 4}. This is not guaranteed to bind ALL symbols in the ShapeEnv; we can’t bind a symbol if it doesn’t occur in any placeholder, and symbols that already have replacements won’t get bindings.
This is a little duplicative with evaluate_guards, but it’s different enough that it seemed cleanest to make another copy. This assumes the guards have already been checked, though if it’s cheap we’ll check for shenanigans.
- Return type
Dict[sympy.Symbol, int]
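The binding described above can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions, not PyTorch's implementation: `toy_bind_symbols` is a hypothetical helper, symbols are represented as strings, and placeholder sizes are plain tuples rather than fake tensors.

```python
from typing import Dict, Sequence, Tuple, Union

Dim = Union[str, int]  # a symbol name such as "s0", or a static integer size


def toy_bind_symbols(
    placeholder_sizes: Sequence[Tuple[Dim, ...]],
    arg_sizes: Sequence[Tuple[int, ...]],
) -> Dict[str, int]:
    """Map each symbol that appears in a placeholder size to its concrete value."""
    bindings: Dict[str, int] = {}
    for sym_size, real_size in zip(placeholder_sizes, arg_sizes):
        if len(sym_size) != len(real_size):
            raise ValueError("rank mismatch between placeholder and argument")
        for dim, real in zip(sym_size, real_size):
            if isinstance(dim, int):
                # Static dimension: nothing to bind, but do the cheap sanity check.
                if dim != real:
                    raise ValueError(f"static size {dim} does not match {real}")
            elif bindings.setdefault(dim, real) != real:
                # The same symbol was bound to two different values.
                raise ValueError(f"conflicting bindings for {dim}")
    return bindings


# Hypothetical example: two placeholders that share the symbol s1.
bindings = toy_bind_symbols([("s0", "s1"), ("s1", 3)], [(2, 4), (4, 3)])
print(bindings)  # {'s0': 2, 's1': 4}
```

A symbol that never appears in any placeholder simply does not show up in the result, which mirrors the caveat that not all symbols are guaranteed to be bound.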
- bound_sympy(expr, size_oblivious=False)
Given a sympy expression, computes a ValueRanges bound on the values it can take.
- Return type
ValueRanges
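To make the idea of bounding an expression concrete, here is a minimal interval-arithmetic sketch. It is not how ValueRanges is implemented; `ToyRange` and `const` are hypothetical names, and only addition and multiplication over closed integer intervals are modelled.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyRange:
    """Closed integer interval [lower, upper]; a stand-in for a value range."""

    lower: int
    upper: int

    def __add__(self, other: "ToyRange") -> "ToyRange":
        return ToyRange(self.lower + other.lower, self.upper + other.upper)

    def __mul__(self, other: "ToyRange") -> "ToyRange":
        # The extremes of a product lie at the corners of the two intervals.
        corners = [
            a * b
            for a in (self.lower, self.upper)
            for b in (other.lower, other.upper)
        ]
        return ToyRange(min(corners), max(corners))


def const(value: int) -> ToyRange:
    """Wrap a constant as a degenerate interval."""
    return ToyRange(value, value)


# Assume s0 is known to lie in [2, 10]; bound the expression 2*s0 + 1.
s0 = ToyRange(2, 10)
bound = const(2) * s0 + const(1)
print(bound)  # ToyRange(lower=5, upper=21)
```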
- cleanup()
Break reference cycles.
This destroys the stacks. If you really want to keep them, we just need some way to break references on code objects.
- create_symbol(val, source, dynamic_dim=DimDynamic.DUCK, constraint_dim=None, positive=True, do_not_specialize_zero_one=False, symbolic_context=None)
Create a new symbol which is tracked by this ShapeEnv.
- Return type
Expr
- create_symbolic_sizes_strides_storage_offset(ex, source, *, symbolic_context=None)
Returns a list of symbolic sizes and strides for the given tensor. We try our best to express stride in terms of the sizes, so as to not introduce new symbolic variables.
- create_symboolnode(sym)
Create a SymBool object from a sympy boolean expression.
- Return type
SymBool
- create_symfloatnode(sym, *, hint, source=None)
Create a SymFloat value from a symbolic expression.
- create_symintnode(sym, *, hint, source=None)
Create a SymInt value from a symbolic expression.
If you know the current hint value of the SymInt to be created, pass it as hint. Otherwise, pass None and we will make our best guess.
- create_unbacked_symbool()
Create a symbolic boolean without a hint value.
- Return type
SymBool
- create_unbacked_symfloat()
Create a symbolic float without a hint value.
- Return type
SymFloat
- create_unbacked_symint(source=None)
Create a symbolic integer without a hint value.
- Return type
SymInt
- create_unspecified_symbol(val, source, dynamic_dim=DimDynamic.DUCK, constraint_dim=None, symbolic_context=None)
Create a symbol with an unspecified value.
Compared to standard symbols, we do not assume the value is positive, nor do we specialize on zero or one values.
- Return type
Expr
- create_unspecified_symint_and_symbol(value, source, dynamic_dim)
Create a SymInt wrapping a new unspecified symbol.
- defer_runtime_assert(orig_expr, msg, fx_node=None)
Create an assert that is checked at runtime.
- Parameters
orig_expr (sympy.Expr) – Boolean expression to assert is true
msg (str) – Message to display on assertion failure
fx_node (Optional, torch.fx.Node) – node in self.graph corresponding to the expression, if applicable
- Return type
bool
- evaluate_guards_expression(code, args)
Expected to be used with produce_guards_expression(). Evaluates an expression generated by produce_guards_expression for the given concrete args.
- Return type
bool
- evaluate_guards_for_args(placeholders, args, *, ignore_static=True)
Generate guards for a graph’s placeholder values and evaluate the guards with args.
- Return type
bool
- format_guards(verbose=False)
Format this shape env’s guard expressions, with optional traceback info if verbose.
- Return type
str
- freeze()
Freeze this ShapeEnv to stop accumulating guards.
A frozen ShapeEnv ignores any further guards generated on it and only emits a warning. Because those guards are dropped, this may lead to accuracy problems.
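The freezing behaviour can be pictured with a small stand-in class. This is a hedged sketch only: `ToyGuardEnv` and `add_guard` are hypothetical names invented for illustration, and the real ShapeEnv records guards through its own internal machinery.

```python
import warnings
from typing import List


class ToyGuardEnv:
    """Hypothetical stand-in that records guards until it is frozen."""

    def __init__(self) -> None:
        self.guards: List[str] = []
        self.frozen = False

    def freeze(self) -> None:
        self.frozen = True

    def add_guard(self, guard: str) -> None:
        if self.frozen:
            # The guard is dropped: later checks will not know about it.
            warnings.warn(f"ignoring guard {guard!r} on frozen environment")
            return
        self.guards.append(guard)


env = ToyGuardEnv()
env.add_guard("s0 > 1")
env.freeze()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    env.add_guard("s0 == s1")  # ignored, only a warning is emitted

print(env.guards)   # ['s0 > 1']
print(len(caught))  # 1
```

The dropped guard `s0 == s1` is exactly the kind of condition that, if it later turns out to be false, would go unnoticed.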
- freeze_runtime_asserts()
Freeze this ShapeEnv to stop adding deferred runtime asserts.
We will error if you try to install a new runtime assert when it is frozen. This would indicate a lowering violation, or perhaps something we know statically is already True but we are checking it again in a way that is not clearly dischargeable.
- get_axioms(symbols=None, compute_hint=False)
Given the symbols in an expression, returns all the runtime asserts that involve those symbols, concatenated with all the guards. If symbols is None, returns all the runtime asserts (and all the guards).
- Return type
Tuple[Boolean, …]
- get_implications(e)
Given an expression, returns a list of predicates that follow from it.
- get_nontrivial_guards()
Returns a list of guard expressions that aren’t statically known (i.e. not trivial).
- Return type
List[Boolean]
- get_pruned_guards(symints)
Get a list of guards, pruned so that it only contains guards that reference symints from the passed-in input.
- Return type
List[ShapeGuard]
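A rough picture of pruning, in plain Python. The names `ToyGuard` and `toy_prune_guards` are hypothetical, and the sketch assumes that a guard is kept only when every symbol it mentions is among the inputs; the description above does not spell out that detail, so treat it as an assumption.

```python
from typing import FrozenSet, Iterable, List, NamedTuple


class ToyGuard(NamedTuple):
    expr: str
    symbols: FrozenSet[str]  # free symbols appearing in expr


def toy_prune_guards(guards: Iterable[ToyGuard], keep: Iterable[str]) -> List[ToyGuard]:
    """Keep guards whose symbols are all contained in `keep` (an assumption)."""
    keep_set = set(keep)
    return [g for g in guards if g.symbols <= keep_set]


guards = [
    ToyGuard("s0 > 1", frozenset({"s0"})),
    ToyGuard("s0 == s1", frozenset({"s0", "s1"})),
    ToyGuard("s2 < 8", frozenset({"s2"})),
]
pruned = toy_prune_guards(guards, ["s0", "s1"])
print([g.expr for g in pruned])  # ['s0 > 1', 's0 == s1']
```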
- ignore_fresh_unbacked_symbols()
Indicates that the newly allocated unbacked SymInts are being discarded.
- Return type
Iterator[None]
- is_unbacked_symint(symbol)
Check if a sympy symbol matches the naming convention for unbacked symbols.
- Return type
bool
- produce_guards(*args, **kwargs)
Like produce_guards_verbose, but only returns the non-verbose guard expressions (no verbose guards are produced).
- produce_guards_expression(placeholders, *, guards=None, ignore_static=True)
Expected to be used with evaluate_guards_expression(). Produces the guards for the given placeholders and returns a string expression to be evaluated by evaluate_guards_expression given concrete values for the placeholders.
- Return type
Optional[str]
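The produce/evaluate pairing can be sketched as a round trip through a string. This is a toy model, not the real API: both function names are hypothetical, guards are given as ready-made strings, and concrete values are passed as a name-to-value mapping rather than as tensor arguments.

```python
from typing import Dict, List, Optional


def toy_produce_guards_expression(guards: List[str]) -> Optional[str]:
    """Join guard strings into one boolean expression, or None if there are none."""
    if not guards:
        return None
    return " and ".join(f"({g})" for g in guards)


def toy_evaluate_guards_expression(code: str, bindings: Dict[str, int]) -> bool:
    """Evaluate the expression with symbol names bound to concrete values."""
    # eval is acceptable here only because the strings are generated by us.
    return bool(eval(code, {"__builtins__": {}}, dict(bindings)))


code = toy_produce_guards_expression(["s0 > 1", "s0 == 2 * s1"])
print(code)  # (s0 > 1) and (s0 == 2 * s1)
print(toy_evaluate_guards_expression(code, {"s0": 4, "s1": 2}))  # True
print(toy_evaluate_guards_expression(code, {"s0": 4, "s1": 3}))  # False
```

Returning None when there is nothing to check matches the Optional[str] return type listed above; a caller can then skip evaluation entirely.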
- produce_guards_verbose(placeholders, sources, source_ref=<function ShapeEnv.<lambda>>, *, guards=None, input_contexts=None, equalities_inputs=None, _simplified=False, ignore_static=True)
Generates a list of guard strings which, when evaluated in a context that defines tensors for all the sources, return True or False depending on whether the guards in the list evaluate to True. Primarily used by Dynamo, but this is also helpful for manual testing of guards (see evaluate_guards_for_args).
For convenience in testing, a source is allowed to be a str, in which case we will assume it is a LocalSource.
_simplified lets you omit duck sizing, equality and 0/1 guards. This is useful for testing when you don’t care about the boilerplate guards, and it may be helpful for user output too (be careful though; some equality guards are nontrivial! It would be nice to get simplified output to print them too). It’s private because it’s not intended for normal use.
- replace(expr)
Apply symbol replacements to any symbols in the given expression.
- Return type
_SympyT
- set_unbacked_var_to_val(k, v)
Used only when propagate_real_tensors is enabled; registers a value for an unbacked symbol, which can be used as a last resort to resolve hints.
- simplify(expr)
Use known constraints and replacements to simplify the given expr.
- Return type
_SympyT
- size_hint(expr, *, allow_none=False)
Gets a size hint for a given expression from the underlying shapes we had. Does not introduce a guard, so only use this when you can guarantee that your code is still valid for arbitrary shapes (such as optimization decisions).
- Return type
Optional[Basic]
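As a rough illustration of using a hint without guarding on it, the sketch below substitutes example values into an expression string. `toy_size_hint` and `pick_block_size` are hypothetical names, expressions are plain Python strings rather than sympy objects, and the handling of a missing hint via allow_none is an assumption about what that flag means.

```python
from typing import Dict, Optional


def toy_size_hint(
    expr: str, hints: Dict[str, int], *, allow_none: bool = False
) -> Optional[int]:
    """Evaluate expr using example values; record nothing about the result."""
    try:
        return int(eval(expr, {"__builtins__": {}}, dict(hints)))
    except NameError:
        # Some symbol has no example value available.
        if allow_none:
            return None
        raise


def pick_block_size(numel_hint: Optional[int]) -> int:
    """An optimization decision: any answer is valid, some are just faster."""
    if numel_hint is None or numel_hint < 1024:
        return 64
    return 256


hints = {"s0": 32, "s1": 128}
hint = toy_size_hint("s0 * s1", hints)
print(hint)                   # 4096
print(pick_block_size(hint))  # 256
print(toy_size_hint("s0 * s2", hints, allow_none=True))  # None
```

The block size chosen here affects only performance, so the code stays correct for shapes other than the hinted ones, which is the situation the description above has in mind.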