ShapeEnv
- class torch.fx.experimental.symbolic_shapes.ShapeEnv(*, should_record_events=None, tracked_fakes=None, **kwargs)
- bind_symbols(placeholders, args)
Given a paired list of placeholders (fake tensors with symbolic sizes) and concrete arguments (regular tensors with real sizes), returns a dictionary mapping each symbol to its real value. So for example, if you have a placeholder with size (s0, s1), binding (2, 4) to it will give you {s0: 2, s1: 4}. This is not guaranteed to bind ALL symbols in the ShapeEnv; we can’t bind a symbol if it doesn’t occur in any placeholder, and symbols that already have replacements won’t get bindings.
This is a little duplicative with evaluate_guards, but it is different enough that it seemed cleanest to make a separate copy. It assumes the guards have already been checked, though where it is cheap it still checks for inconsistencies, such as one symbol being bound to two different values.
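The binding logic can be sketched in plain Python. This sketch assumes placeholder shapes are given as tuples of symbol names (str) and static sizes (int), not as fake tensors. `bind_symbols_sketch` and its `replaced` parameter are hypothetical names used only for illustration. The real method works on FakeTensors and sympy symbols and handles cases this toy ignores.

```python
# Toy model (hypothetical names), not the PyTorch implementation.
from typing import AbstractSet, Dict, Sequence, Union

# Assumed representation: each dim is a symbol name (str) or a static size (int).
Dim = Union[str, int]


def bind_symbols_sketch(
    placeholder_shapes: Sequence[Sequence[Dim]],
    concrete_shapes: Sequence[Sequence[int]],
    replaced: AbstractSet[str] = frozenset(),
) -> Dict[str, int]:
    """Map each symbol to the real size it lines up with."""
    bindings: Dict[str, int] = {}
    for sym_shape, real_shape in zip(placeholder_shapes, concrete_shapes):
        if len(sym_shape) != len(real_shape):
            raise ValueError(f"rank mismatch: {sym_shape} vs {real_shape}")
        for dim, real in zip(sym_shape, real_shape):
            if isinstance(dim, int):
                continue  # static size: nothing to bind
            if dim in replaced:
                continue  # symbols that already have replacements get no binding
            if dim in bindings and bindings[dim] != real:
                # cheap consistency check: one symbol, two different values
                raise ValueError(f"{dim} bound to both {bindings[dim]} and {real}")
            bindings[dim] = real
    return bindings


print(bind_symbols_sketch([("s0", "s1")], [(2, 4)]))  # {'s0': 2, 's1': 4}
```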
- bound_sympy(expr, size_oblivious=False)
Given a sympy expression, computes a ValueRanges bound on the values it can take.
- Return type
ValueRanges
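The idea behind a value-range bound is interval arithmetic. Each symbol carries a range, and the range of a compound expression is derived from the ranges of its parts. The toy sketch below shows this for `+` and `*`. `Range`, `const`, and `_mul` are hypothetical helpers, not the ValueRanges API, and size_oblivious is not modeled.

```python
# Toy interval arithmetic (hypothetical helpers), not the ValueRanges API.
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    """Closed interval [lower, upper]."""
    lower: float
    upper: float

    def __add__(self, other: "Range") -> "Range":
        return Range(self.lower + other.lower, self.upper + other.upper)

    def __mul__(self, other: "Range") -> "Range":
        # The product's extremes lie among the products of the endpoints.
        products = [_mul(a, b) for a in (self.lower, self.upper)
                    for b in (other.lower, other.upper)]
        return Range(min(products), max(products))


def _mul(a: float, b: float) -> float:
    # Treat 0 * inf as 0 so bounds never become NaN.
    if a == 0 or b == 0:
        return 0
    return a * b


def const(c: float) -> Range:
    return Range(c, c)


s0 = Range(2, math.inf)  # a size known to be >= 2
s1 = Range(2, 10)        # a size known to lie in [2, 10]
print(s0 * s1 + const(1))  # Range(lower=5, upper=inf)
```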
- cleanup()
Break reference cycles.
This destroys the stacks. If you really want to keep them, we just need some way to break references on code objects.
- create_symbol(val, source, dynamic_dim=DimDynamic.DUCK, constraint_dim=None, positive=True, do_not_specialize_zero_one=False, symbolic_context=None)
Create a new symbol which is tracked by this ShapeEnv
- Return type
Expr
- create_symbolic_sizes_strides_storage_offset(ex, source, *, symbolic_context=None)
Returns the symbolic sizes, strides, and storage offset for the given tensor. We try our best to express strides in terms of the sizes, so as not to introduce new symbolic variables.
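For a contiguous tensor, expressing strides in terms of sizes is straightforward: each stride is the product of the sizes to its right. The sketch below covers only that contiguous case and uses strings as a stand-in for sympy expressions. `contiguous_stride_exprs` is a hypothetical helper. The real method also handles non-contiguous layouts and the storage offset.

```python
# Toy helper (hypothetical), covering only contiguous row-major layouts.
from typing import List, Sequence


def contiguous_stride_exprs(sizes: Sequence[str]) -> List[str]:
    """Express each stride as the product of the size symbols to its right."""
    strides: List[str] = []
    inner: List[str] = []  # sizes of the dims to the right, innermost first
    for size in reversed(sizes):
        strides.append("*".join(reversed(inner)) if inner else "1")
        inner.append(size)
    return list(reversed(strides))


print(contiguous_stride_exprs(["s0", "s1", "s2"]))  # ['s1*s2', 's2', '1']
```

No new symbols appear in the output, which is the point of the approach.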
- create_symfloatnode(sym, *, hint, source=None)
Create a SymFloat value from a symbolic expression
- create_symintnode(sym, *, hint, source=None)
Create a SymInt value from a symbolic expression
If you know the current hint value of the SymInt to be created, pass it as hint. Otherwise, pass None and we will make our best guess.
- create_unspecified_symbol(val, source, dynamic_dim=DimDynamic.DUCK, constraint_dim=None)
Create a symbol with an unspecified value
Compared to standard symbols, we do not assume the value is positive, nor do we specialize on zero or one values.
- Return type
Expr
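The toy model below illustrates the two documented differences between standard and unspecified symbols: positivity and 0/1 specialization. `ToySymbolFactory` is hypothetical. It ignores duck sizing, sources, and constraints, and it records the positivity assumption in a plain dict instead of on a sympy symbol.

```python
# Toy model (hypothetical class), not the PyTorch implementation.
import itertools
from typing import Dict, Union


class ToySymbolFactory:
    def __init__(self) -> None:
        self._counter = itertools.count()
        self.assumed_positive: Dict[str, bool] = {}

    def _fresh(self, positive: bool) -> str:
        name = f"s{next(self._counter)}"
        self.assumed_positive[name] = positive
        return name

    def create_symbol(self, val: int) -> Union[str, int]:
        if val in (0, 1):
            return val  # specialized: 0 and 1 become constants, not symbols
        return self._fresh(positive=True)

    def create_unspecified_symbol(self, val: int) -> str:
        # No positivity assumption and no 0/1 specialization.
        return self._fresh(positive=False)


f = ToySymbolFactory()
print(f.create_symbol(1), f.create_symbol(5), f.create_unspecified_symbol(1))  # 1 s0 s1
```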
- create_unspecified_symint_and_symbol(value, source, dynamic_dim)
Create a SymInt wrapping a new unspecified symbol
- defer_runtime_assert(orig_expr, msg, fx_node=None)
Create an assert that is checked at runtime
- Parameters
orig_expr (sympy.Expr) – Boolean expression to assert is true
msg (str) – Message to display on assertion failure
fx_node (Optional[torch.fx.Node]) – node in self.graph corresponding to the expression, if applicable
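Conceptually, a deferred runtime assert is recorded at trace time and checked only once concrete values exist. The sketch below models this with Python callables over a dict of symbol values. `DeferredAsserts` is a hypothetical class; ShapeEnv itself records sympy boolean expressions instead.

```python
# Toy model (hypothetical class), not the PyTorch implementation.
from typing import Callable, Dict, List, Tuple

Values = Dict[str, int]


class DeferredAsserts:
    """Record assertions at trace time; check them when real values arrive."""

    def __init__(self) -> None:
        self._asserts: List[Tuple[Callable[[Values], bool], str]] = []

    def defer(self, pred: Callable[[Values], bool], msg: str) -> None:
        # Nothing is evaluated yet; the check is only remembered.
        self._asserts.append((pred, msg))

    def check(self, values: Values) -> None:
        # Run every recorded check against concrete symbol values.
        for pred, msg in self._asserts:
            if not pred(values):
                raise AssertionError(msg)


asserts = DeferredAsserts()
asserts.defer(lambda v: v["u0"] >= 0, "u0 must be non-negative")
asserts.check({"u0": 3})  # passes silently
```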
- evaluate_guards_expression(code, args)
Expected to be used with produce_guards_expression(). Evaluates an expression generated by produce_guards_expression for the given concrete args.
- evaluate_guards_for_args(placeholders, args, *, ignore_static=True)
Generate guards for a graph’s placeholder values and evaluate the guards with args.
- format_guards(verbose=False)
Format this ShapeEnv’s guard expressions, including traceback info when verbose is True.
- freeze()
Freeze this ShapeEnv to stop accumulating guards
A frozen ShapeEnv will ignore any further guards generated on it and only emit a warning which may lead to accuracy problems.
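Below is a minimal sketch of the frozen behavior, assuming guards are plain strings. Once the log is frozen, new guards are dropped with a warning instead of being recorded. `GuardLog` is a hypothetical class used only for illustration.

```python
# Toy model (hypothetical class), not the PyTorch implementation.
import warnings
from typing import List


class GuardLog:
    """Accumulates guard strings until frozen."""

    def __init__(self) -> None:
        self.guards: List[str] = []
        self.frozen = False

    def freeze(self) -> None:
        self.frozen = True

    def add_guard(self, guard: str) -> None:
        if self.frozen:
            # Drop the guard and only warn, as described for a frozen ShapeEnv.
            warnings.warn(f"ignoring guard on frozen env: {guard}")
            return
        self.guards.append(guard)


log = GuardLog()
log.add_guard("s0 > 1")
log.freeze()
log.add_guard("s0 < 10")  # emits a UserWarning and is not recorded
print(log.guards)  # ['s0 > 1']
```

The dropped guard is why a frozen env can silently produce incorrect results: code specialized on `s0 < 10` is no longer protected by a guard.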
- freeze_runtime_asserts()
Freeze this ShapeEnv to stop adding deferred runtime asserts.
We will error if you try to install a new runtime assert when it is frozen. This would indicate a lowering violation, or perhaps something we know statically is already True but we are checking it again in a way that is not clearly dischargeable.
- get_axioms(symbols=None, compute_hint=False)
Given the symbols in an expression, returns all the runtime asserts that mention those symbols, concatenated with all the guards. If symbols is None, returns all the runtime asserts (and all the guards).
- Return type
Tuple[Expr]
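Here is a toy version of the selection rule. It assumes runtime asserts are stored in a dict keyed by symbol name, and that asserts and guards are plain strings. `get_axioms_sketch` is hypothetical and does not model compute_hint.

```python
# Toy model (hypothetical function), not the PyTorch implementation.
from typing import AbstractSet, Dict, Optional, Tuple


def get_axioms_sketch(
    runtime_asserts: Dict[str, Tuple[str, ...]],
    guards: Tuple[str, ...],
    symbols: Optional[AbstractSet[str]] = None,
) -> Tuple[str, ...]:
    """Runtime asserts for `symbols` (all of them if None), followed by all guards."""
    if symbols is None:
        symbols = runtime_asserts.keys()
    selected = tuple(a for s in sorted(symbols) for a in runtime_asserts.get(s, ()))
    return selected + guards


asserts = {"u0": ("u0 >= 0",), "u1": ("u1 < 8",)}
print(get_axioms_sketch(asserts, ("s0 > 1",), {"u1"}))  # ('u1 < 8', 's0 > 1')
```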
- get_implications(e)
Given an expression, returns a list of predicates that follow from it.
- get_nontrivial_guards()
Returns a list of guard expressions that aren’t statically known (i.e. not trivial)
- get_pruned_guards(symints)
Get a list of guards, pruned to only the guards that reference symints from the passed-in input.
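The pruning can be sketched with strings, assuming symbol names match `s<digits>` or `u<digits>`. The sketch further assumes a guard is kept only when every symbol it mentions comes from the given input. `prune_guards` is a hypothetical helper that parses text, whereas a real implementation would inspect sympy free symbols.

```python
# Toy helper (hypothetical), based on an assumed symbol naming scheme.
import re
from typing import AbstractSet, List, Sequence

_SYMBOL = re.compile(r"\b[su]\d+\b")


def prune_guards(guards: Sequence[str], keep: AbstractSet[str]) -> List[str]:
    """Keep guards whose symbols are all contained in `keep`."""
    return [g for g in guards if set(_SYMBOL.findall(g)) <= keep]


print(prune_guards(["s0 > 1", "s0 == s1", "s2 < 5"], {"s0", "s1"}))  # ['s0 > 1', 's0 == s1']
```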
- ignore_fresh_unbacked_symbols()
Indicates that the newly allocated unbacked SymInts are being discarded
- is_unbacked_symint(symbol)
Check if a sympy symbol matches the naming convention for unbacked symbols
- Return type
bool
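Below is a sketch of a naming-convention check. It assumes unbacked SymInt symbols are named `u0`, `u1`, and so on, while backed sizes use `s0`, `s1`, and so on. `looks_unbacked` is a hypothetical helper that works on a plain name string rather than a sympy symbol.

```python
# Toy check (hypothetical), based on an assumed "u<digits>" naming scheme.
import re

_UNBACKED_INT = re.compile(r"u\d+")


def looks_unbacked(symbol_name: str) -> bool:
    """True if the whole name is 'u' followed by one or more digits."""
    return _UNBACKED_INT.fullmatch(symbol_name) is not None


print(looks_unbacked("u0"), looks_unbacked("s0"))  # True False
```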
- produce_guards(placeholders, sources, source_ref=<function ShapeEnv.<lambda>>, *, guards=None, input_contexts=None, equalities_inputs=None, _simplified=False, ignore_static=True)
Generates a list of guard strings which, when evaluated in a context that defines tensors for all the sources, return True or False depending on whether the guards hold. Primarily used by Dynamo, but this is also helpful for manual testing of guards (see evaluate_guards_for_args).
For convenience in testing, a source is allowed to be a str, in which case we will assume it is a LocalSource.
_simplified lets you omit duck sizing, equality and 0/1 guards. This is useful for testing when you don’t care about the boilerplate guards, and it may be helpful for user output too (be careful though; some equality guards are nontrivial! It would be nice to get simplified output to print them too). It’s private because it’s not intended for normal use.
- produce_guards_expression(placeholders, *, guards=None, ignore_static=True)
Expected to be used with evaluate_guards_expression(). Produces the guards for the given placeholders and returns a string expression to be evaluated by evaluate_guards_expression given concrete values for the placeholders.
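The sketch below shows the division of labor between the two functions: one returns source text, and the other evaluates that text against concrete arguments. Both function names and the `L[i][d]` naming for "dimension d of argument i" are hypothetical, and the code ShapeEnv actually generates has its own format.

```python
# Toy model (hypothetical names and format), not the PyTorch implementation.
from typing import Dict, List, Sequence


def produce_guards_expression_sketch(guards: Sequence[str]) -> str:
    """Join guard snippets into one boolean Python expression."""
    if not guards:
        return "True"
    return " and ".join(f"({g})" for g in guards)


def evaluate_guards_expression_sketch(code: str, args: Sequence[List[int]]) -> bool:
    """Evaluate `code` with L[i] bound to the shape of the i-th argument."""
    scope: Dict[str, object] = {"L": list(args)}
    # Builtins are removed so the expression can only see L.
    return bool(eval(code, {"__builtins__": {}}, scope))


code = produce_guards_expression_sketch(["L[0][0] == L[1][0]", "L[0][1] >= 2"])
print(code)  # (L[0][0] == L[1][0]) and (L[0][1] >= 2)
print(evaluate_guards_expression_sketch(code, [[3, 4], [3]]))  # True
print(evaluate_guards_expression_sketch(code, [[3, 1], [3]]))  # False
```

Producing the text once and re-evaluating it per call avoids regenerating guards for every new set of arguments.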
- replace(expr)
Apply symbol replacements to any symbols in the given expression
- Return type
Expr
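Replacements can chain (for example s2 to s1 to s0), so applying them means substituting until a fixed point is reached. This string-based sketch stands in for sympy substitution. It assumes each replacement is a bare symbol name or an integer literal, and `replace_sketch` is a hypothetical function.

```python
# Toy model (hypothetical function), not the PyTorch implementation.
import re
from typing import Dict

# Assumption: symbols are a lowercase letter followed by digits (s0, u1, ...).
_SYMBOL = re.compile(r"\b[a-z]\d+\b")


def replace_sketch(expr: str, replacements: Dict[str, str], max_rounds: int = 100) -> str:
    """Substitute replacements repeatedly until the expression stops changing."""
    for _ in range(max_rounds):
        new = _SYMBOL.sub(lambda m: replacements.get(m.group(0), m.group(0)), expr)
        if new == expr:
            return new
        expr = new
    raise RuntimeError("replacements did not reach a fixed point")


print(replace_sketch("s2 + s3", {"s2": "s1", "s1": "s0", "s3": "4"}))  # s0 + 4
```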
- set_unbacked_var_to_val(k, v)
Used only when propagate_real_tensors is enabled; registers a value for an unbacked symbol, which can be used as a last resort to resolve hints.
- simplify(expr)
Use known constraints and replacements to simplify the given expr
- Return type
Expr