Control Flow - Cond
torch.cond is a structured control flow operator. It can be used to express if-else-like control flow and can logically be seen as implemented as follows:
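from typing import Callable, Tuple, Union

import torch

def cond(
    pred: Union[bool, torch.Tensor],
    true_fn: Callable,
    false_fn: Callable,
    operands: Tuple[torch.Tensor],
):
    # Eager semantics: evaluate pred and call exactly one branch on the operands.
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)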
Its unique power lies in its ability to express data-dependent control flow: it lowers to a conditional operator (torch.ops.higher_order.cond), which preserves the predicate, the true function, and the false function. This unlocks great flexibility in writing and deploying models that change model architecture based on the value or shape of inputs or intermediate outputs of tensor operations.
Warning
torch.cond is a prototype feature in PyTorch. It has limited support for input and output types and does not currently support training. A more stable implementation is expected in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Examples
Below is an example that uses cond to branch based on input shape:
import torch

def true_fn(x: torch.Tensor):
    return x.cos() + x.sin()

def false_fn(x: torch.Tensor):
    return x.sin()

class DynamicShapeCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on dynamic shape predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Uses the module-level branches: size of dim 0 > 4 -> true_fn, else false_fn.
        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

dyn_shape_mod = DynamicShapeCondPredicate()
We can run the model eagerly and expect the results to vary based on the input shape:
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
We can export the model for further transformations and deployment:
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
This gives us an exported program as shown below:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
        gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            return sin
Notice that torch.cond is lowered to torch.ops.higher_order.cond, its predicate becomes a symbolic expression over the shape of the input, and the branch functions become two sub-graph attributes of the top-level graph module.
Here is another example that shows how to express data-dependent control flow:
class DataDependentCondPredicate(torch.nn.Module):
    """
    A basic usage of cond based on data dependent predicate.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
Exporting this module with a dynamic first dimension, as in the first example, gives the following program:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: f32[s0, 3]):
        sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
        gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None
        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
        return (conditional,)

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
            return add

    class <lambda>(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
            return sin
Invariants of torch.ops.higher_order.cond
There are several useful invariants for torch.ops.higher_order.cond:
- For predicate:
  - The dynamism of the predicate is preserved (e.g. gt in the examples above).
  - If the predicate in the user program is constant (e.g. a Python bool constant), the pred of the operator will be a constant.
- For branches:
  - The input and output signatures will be flattened tuples.
  - They are torch.fx.GraphModule instances.
  - Closures in the original functions become explicit inputs; the branches contain no closures.
  - No mutations on inputs or globals are allowed.
- For operands:
  - They will also be a flat tuple.
- Nested torch.cond calls in the user program become nested graph modules.
API Reference
- torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands)
Conditionally applies true_fn or false_fn.
cond is a structured control flow operator. That is, it is like a Python if-statement, but has restrictions on true_fn, false_fn, and operands that enable it to be captured by torch.compile and torch.export.
Assuming the constraints on cond’s arguments are met, cond is equivalent to the following:
def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
- Parameters
pred (Union[bool, torch.Tensor]) – A boolean expression or a tensor with one element, indicating which branch function to apply.
true_fn (Callable) – A callable function (a -> b) that is within the scope that is being traced.
false_fn (Callable) – A callable function (a -> b) that is within the scope that is being traced. The true branch and false branch must have consistent inputs and outputs, meaning the inputs have to be the same, and the outputs have to be the same type and shape.
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – A tuple of inputs to the true/false functions.
Example:
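import torch
from torch._higher_order_ops.cond import cond

def true_fn(x: torch.Tensor):
    return x.cos()

def false_fn(x: torch.Tensor):
    return x.sin()

# f is an illustrative wrapper; the original snippet is the body of such a function.
def f(x: torch.Tensor) -> torch.Tensor:
    return cond(x.shape[0] > 4, true_fn, false_fn, (x,))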
- Restrictions:
The conditional statement (aka pred) must meet one of the following constraints:
It’s a torch.Tensor with exactly one element and torch.bool dtype
It’s a boolean expression, e.g. x.shape[0] > 10 or x.dim() > 1 and x.shape[1] > 10
The branch function (aka true_fn/false_fn) must meet all of the following constraints:
The function signature must match the operands.
The function must return a tensor whose metadata (e.g. shape, dtype) matches that of the other branch's output.
The function cannot have in-place mutations on inputs or global variables. (Note: in-place tensor operations such as add_ for intermediate results are allowed in a branch)
Warning
Temporary Limitations:
cond only supports inference right now. Autograd will be supported in the future.
The output of branches must be a single Tensor. Pytree of tensors will be supported in the future.