
ShapeEnv

class torch.fx.experimental.symbolic_shapes.ShapeEnv(*, should_record_events=None, tracked_fakes=None, **kwargs)

add_var_to_val(expr, val)

Adds a new symbol to the symbolic environment.

bind_symbols(placeholders, args)

Given a paired list of placeholders (fake tensors with symbolic sizes) and concrete arguments (regular tensors with real sizes), returns a dictionary mapping each symbol to its real value. So for example, if you have a placeholder with size (s0, s1), binding (2, 4) to it will give you {s0: 2, s1: 4}. This is not guaranteed to bind ALL symbols in the ShapeEnv; we can’t bind a symbol if it doesn’t occur in any placeholder, and symbols that already have replacements won’t get bindings.

This is a little duplicative with evaluate_guards, but it is different enough that it seemed cleanest to make another copy. It assumes the guards have already been checked, though where it is cheap to do so it also checks for inconsistent bindings.

Return type

Dict[sympy.Symbol, int]
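A minimal sketch of the binding described above. It assumes a recent PyTorch release in which `FakeTensorMode(shape_env=...)` (from the internal `torch._subclasses.fake_tensor` module) gives a fake tensor duck-sized symbolic dimensions. Symbol names differ between releases, so only the bound values are inspected.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

# Placeholder: a fake tensor whose sizes are symbolic, e.g. (s0, s1).
placeholder = fake_mode.from_tensor(torch.randn(4, 8))

# Concrete argument with a different first dimension.
real_arg = torch.randn(3, 8)

# Map each symbol that occurs in the placeholder to its real value.
bindings = shape_env.bind_symbols([placeholder], [real_arg])
for symbol, value in bindings.items():
    print(symbol, "->", value)
```

The stride of the contiguous placeholder is expressed in terms of its second size, so binding the real stride re-binds the same symbol to the same value rather than introducing a new one.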

bound_sympy(expr, size_oblivious=False)

Given a sympy expression, computes a ValueRanges bound on the values it can take.

Return type

ValueRanges

check_equal(other)

Compare another ShapeEnv for equivalence

cleanup()

Break reference cycles.

This destroys the recorded stacks. Keeping them would require some other way to break the references they hold on code objects.

create_symbol(val, source, dynamic_dim=DimDynamic.DUCK, constraint_dim=None, positive=True, do_not_specialize_zero_one=False, symbolic_context=None)

Create a new symbol which is tracked by this ShapeEnv

Return type

Expr

create_symbolic_sizes_strides_storage_offset(ex, source, *, symbolic_context=None)

Returns the symbolic sizes, strides, and storage offset for the given tensor. We try our best to express the strides in terms of the sizes, so as not to introduce new symbolic variables.

Return type

Tuple[Tuple[Union[int, SymInt], …], Tuple[Union[int, SymInt], …], Union[int, SymInt]]
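A sketch showing strides expressed through the sizes for a contiguous tensor. It assumes default duck sizing and uses the internal `torch._dynamo.source.ConstantSource` as the tensor's source.

```python
import torch
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
sizes, strides, offset = shape_env.create_symbolic_sizes_strides_storage_offset(
    torch.randn(4, 8), ConstantSource("x")
)
print(sizes, strides, offset)

# For a contiguous tensor, stride[0] reuses size[1]'s symbol instead of a
# fresh variable; stride[1] and the storage offset (1 and 0) stay plain ints.
print(strides[0].node.expr == sizes[1].node.expr)
```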

create_symboolnode(sym)

Create a SymBool object from a sympy boolean expression

Return type

SymBool

create_symfloatnode(sym, *, hint, source=None)

Create a SymFloat value from a symbolic expression

Return type

Union[float, SymFloat]

create_symintnode(sym, *, hint, source=None)

Create a SymInt value from a symbolic expression.

If you know the current hint value of the SymInt to be created, pass it as hint. Otherwise, pass None and we will make our best guess.

Return type

Union[int, SymInt]

create_unbacked_symbool()

Create a symbolic boolean without a hint value

Return type

SymBool

create_unbacked_symfloat()

Create a symbolic float without a hint value

Return type

SymFloat

create_unbacked_symint()

Create a symbolic integer without a hint value

Return type

SymInt

create_unspecified_symbol(val, source, dynamic_dim=DimDynamic.DUCK, constraint_dim=None, symbolic_context=None)

Create a symbol with an unspecified value.

Compared to standard symbols, we do not assume the value is positive, nor do we specialize on the values zero and one.

Return type

Expr

create_unspecified_symint_and_symbol(value, source, dynamic_dim)

Create a SymInt wrapping a new unspecified symbol

Return type

Union[int, SymInt]

defer_runtime_assert(orig_expr, msg, fx_node=None)

Create an assert that is checked at runtime

Parameters
  • orig_expr (sympy.Expr) – Boolean expression to assert is true

  • msg (str) – Message to display on assertion failure

  • fx_node (Optional, torch.fx.Node) – node in self.graph corresponding to the expression, if applicable

Return type

bool

evaluate_guards_expression(code, args)

Expected to be used with produce_guards_expression(). Evaluates an expression generated by produce_guards_expression for the given concrete args.

Return type

bool

evaluate_guards_for_args(placeholders, args, *, ignore_static=True)

Generate guards for a graph’s placeholder values and evaluate the guards with args

Return type

bool

evaluate_symexpr(code)

To be used by compile_fx to evaluate symexprs

Return type

Union[int, float, bool]

format_guards(verbose=False)

Format this shape env’s guard expressions with optional traceback info if verbose

Return type

str
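A sketch of how a guard gets recorded and then formatted. Branching on a SymInt consults its hint and records the outcome as a guard. It uses the internal `torch._dynamo.source.ConstantSource` and the internal `guards` list only for illustration.

```python
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
s = shape_env.create_symbol(5, ConstantSource("x"))
a = shape_env.create_symintnode(s, hint=5)

# bool(a == 5) is decided with the hint and recorded as an equality guard.
if a == 5:
    pass

print(shape_env.format_guards())
```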

freeze()

Freeze this ShapeEnv to stop accumulating guards.

A frozen ShapeEnv ignores any further guards generated on it and only emits a warning; because those guards are dropped, this may lead to accuracy problems.

freeze_runtime_asserts()

Freeze this ShapeEnv to stop adding deferred runtime asserts.

We will error if you try to install a new runtime assert when it is frozen. This would indicate a lowering violation, or perhaps something we know statically is already True but we are checking it again in a way that is not clearly dischargeable.

get_axioms(symbols=None, compute_hint=False)

Given the symbols in an expression, returns all the runtime asserts that mention those symbols, concatenated with all the guards. If symbols is None, returns all the runtime asserts (and all the guards).

Return type

Tuple[Boolean, …]

get_implications(e)

Given an expression, returns the predicates that follow from it, each paired with its truth value.

Return type

Tuple[Tuple[Boolean, BooleanAtom], …]

get_nontrivial_guards()

Returns a list of guard expressions that aren’t statically known (i.e. not trivial)

Return type

List[Boolean]

get_pruned_guards(symints)

Get the list of guards, pruned to only those that reference symbols from the passed-in symints.

Return type

List[ShapeGuard]

ignore_fresh_unbacked_symbols()

Context manager indicating that the unbacked SymInts newly allocated inside it are being discarded.

Return type

Iterator[None]

is_unbacked_symint(symbol)

Check if a sympy symbol matches the naming convention for unbacked symbols

Return type

bool

produce_guards(*args, **kwargs)

Like produce_guards_verbose, but only returns the non-verbose guard expressions (no verbose guards are produced).

Return type

List[str]

produce_guards_expression(placeholders, *, guards=None, ignore_static=True)

Expected to be used with evaluate_guards_expression(). Produces the guards for the given placeholders and returns a string expression to be evaluated by evaluate_guards_expression given concrete values for the placeholders.

Return type

Optional[str]

produce_guards_verbose(placeholders, sources, source_ref=<function ShapeEnv.<lambda>>, *, guards=None, input_contexts=None, equalities_inputs=None, _simplified=False, ignore_static=True)

Generates a list of guard strings, each of which evaluates to True or False in a context that defines tensors for all the sources; the inputs satisfy the guards only if every string evaluates to True. Primarily used by Dynamo, but this is also helpful for manual testing of guards (see evaluate_guards_for_args).

For convenience in testing, a source is allowed to be a str, in which case we will assume it is a LocalSource.

_simplified lets you omit duck sizing, equality, and 0/1 guards. This is useful for testing when you don’t care about the boilerplate guards, and it may be helpful for user output too (be careful, though: some equality guards are nontrivial! It would be nice to get simplified output to print them too). It’s private because it’s not intended for normal use.

Return type

Tuple[List[str], List[str]]

replace(expr)

Apply symbol replacements to any symbols in the given expression

Return type

_SympyT

set_unbacked_var_to_val(k, v)

Used only with propagate_real_tensors; registers a value for an unbacked symbol, which can be used as a last resort to resolve hints.

simplify(expr)

Use known constraints and replacements to simplify the given expr

Return type

_SympyT
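A sketch of how a replacement arises and how replace and simplify apply it. It assumes that guarding on equality with a constant specializes the symbol (the ShapeEnv records the replacement s -> 5), and uses the internal `torch._dynamo.source.ConstantSource` to name the symbol.

```python
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
s = shape_env.create_symbol(5, ConstantSource("x"))
a = shape_env.create_symintnode(s, hint=5)

# Guarding on a == 5 specializes s to the constant 5.
if a == 5:
    pass

print(shape_env.replace(s))           # the replacement is applied to s
print(shape_env.simplify(2 * s + 1))  # replacement first, then simplification
```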

size_hint(expr, *, allow_none=False)

Gets a size hint for a given expression from the example values of the underlying shapes. Does not introduce a guard, so only use this when you can guarantee that your code is still valid for arbitrary shapes (for example, when making optimization decisions).

Return type

Optional[Basic]
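A sketch showing that size_hint substitutes example values without recording a guard. It uses the internal `torch._dynamo.source.ConstantSource` to name the symbols and the internal `guards` list to confirm no guard was added.

```python
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
x = shape_env.create_symbol(5, ConstantSource("x"))
y = shape_env.create_symbol(7, ConstantSource("y"))

# Substitutes the example values 5 and 7: 5*7 + 1.
hint = shape_env.size_hint(x * y + 1)
print(hint, len(shape_env.guards))
```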

suppress_guards()

Context manager to ignore all guards generated inside

Return type

_GeneratorContextManager[None]
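A sketch comparing a check made inside suppress_guards with one made outside it. It uses the internal `torch._dynamo.source.ConstantSource` to name the symbol and the internal `guards` list to count recorded guards; the two checks use different constants so they are distinct expressions.

```python
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
a = shape_env.create_symintnode(shape_env.create_symbol(5, ConstantSource("x")), hint=5)

with shape_env.suppress_guards():
    # Decided with the hint (5 > 3), but no guard is recorded.
    suppressed = bool(a > 3)
n_suppressed = len(shape_env.guards)
print(suppressed, n_suppressed)

# Outside the context manager, a comparable check records a guard.
recorded = bool(a > 4)
print(recorded, len(shape_env.guards))
```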
