torch.cond ============== cond_branch_class_method ^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond class MySubModule(torch.nn.Module): def foo(self, x): return x.cos() def forward(self, x): return self.foo(x) class CondBranchClassMethod(torch.nn.Module): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using class method in cond(). NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized. """ def __init__(self): super().__init__() self.subm = MySubModule() def bar(self, x): return x.sin() def forward(self, x): return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [l_x_]); true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[3]" = conditional[0]; conditional = None return (getitem,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[3]"): cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[3]"): sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_branch_nested_function ^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond def cond_branch_nested_function(x): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using nested function in cond(). NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized. 
""" def true_fn(x): def inner_true_fn(y): return x + y return inner_true_fn(x) def false_fn(x): def inner_false_fn(y): return x - y return inner_false_fn(x) return cond(x.shape[0] < 10, true_fn, false_fn, [x]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [l_x_]); true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[3]" = conditional[0]; conditional = None return (getitem,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[3]"): add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None return (add,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[3]"): sub: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1); arg0_1 = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_branch_nonlocal_variables ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond def cond_branch_nonlocal_variables(x): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. 
shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. The code below will not work because capturing closure variables is not supported. ``` my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(y): nonlocal my_tensor_var, my_primitive_var return y + my_tensor_var + my_primitive_var def false_fn(y): nonlocal my_tensor_var, my_primitive_var return y - my_tensor_var - my_primitive_var return cond(x.shape[0] > 5, true_fn, false_fn, [x]) ``` NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized. """ my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(x, y, z): return x + y + z def false_fn(x, y, z): return x - y - z return cond( x.shape[0] > 5, true_fn, false_fn, [x, my_tensor_var, torch.tensor(my_primitive_var)], ) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, _lifted_tensor_constant0: "f32[]", l_x_: "f32[6]"): add: "f32[6]" = torch.ops.aten.add.Tensor(l_x_, 100) lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0); _lifted_tensor_constant0 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [l_x_, add, lift_fresh_copy]); true_graph_0 = false_graph_0 = l_x_ = add = lift_fresh_copy = None getitem: "f32[6]" = conditional[0]; conditional = None return (getitem,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"): add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None return (add_1,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"): sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1); sub = arg2_1 = None return (sub_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_closed_over_variable ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.cond <torch.cond>`, :doc:`python.closure <python.closure>` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch from functorch.experimental.control_flow import cond class CondClosedOverVariable(torch.nn.Module): """ torch.cond() supports branches closed over arbitrary variables. """ def forward(self, pred, x): def true_fn(val): return x * 2 def false_fn(val): return x - 2 return cond(pred, true_fn, false_fn, [x + 1]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_pred_: "b8[]", l_x_: "f32[3, 2]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(l_pred_, true_graph_0, false_graph_0, [l_x_]); l_pred_ = true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[3, 2]" = conditional[0]; conditional = None return (getitem,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None return (mul,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[3, 2]"): sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, 2); arg0_1 = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_pred_'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_operands ^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch from torch.export import Dim from functorch.experimental.control_flow import cond x = torch.randn(3, 2) y = torch.ones(2) dim0_x = Dim("dim0_x") def cond_operands(x, y): """ The operands passed to cond() must be: - a list of tensors - match arguments of `true_fn` and `false_fn` NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized. """ def true_fn(x, y): return x + y def false_fn(x, y): return x - y return cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[s0, 2]", l_y_: "f32[2]"): sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(l_x_, 0) gt: "Sym(s0 > 2)" = sym_size_int > 2; sym_size_int = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [l_x_, l_y_]); gt = true_graph_0 = false_graph_0 = l_x_ = l_y_ = None getitem: "f32[s0, 2]" = conditional[0]; conditional = None return (getitem,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"): add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (add,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"): sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_y_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)} Equality constraints: [] cond_predicate ^^^^^^^^^^^^^^ .. 
note:: Tags: :doc:`torch.dynamic-shape <torch.dynamic-shape>`, :doc:`torch.cond <torch.cond>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond def cond_predicate(x): """ The conditional statement (aka predicate) passed to cond() must be one of the following: - torch.Tensor with a single element - boolean expression NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized. """ pred = x.dim() > 2 and x.shape[2] > 10 return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[6, 4, 3]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [l_x_]); true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[6, 4, 3]" = conditional[0]; conditional = None return (getitem,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[6, 4, 3]"): cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: "f32[6, 4, 3]"): sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: []