.. _cond:

Control Flow - Cond
====================

`torch.cond` is a structured control flow operator. It can be used to specify if-else-like control flow,
and it can logically be seen as implemented as follows:

.. code-block:: python

    from typing import Callable, Tuple, Union

    import torch

    def cond(
        pred: Union[bool, torch.Tensor],
        true_fn: Callable,
        false_fn: Callable,
        operands: Tuple[torch.Tensor, ...],
    ):
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)
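
As a quick sanity check of these semantics, the sketch below calls `torch.cond` eagerly with a 0-dim boolean tensor as the predicate and compares the result with a plain-Python restatement of the pseudocode. The names `reference_cond`, `add_one`, and `sub_one` are hypothetical helpers introduced only for this illustration.

.. code-block:: python

    import torch

    def reference_cond(pred, true_fn, false_fn, operands):
        # Hypothetical helper: a direct restatement of the pseudocode above.
        if pred:
            return true_fn(*operands)
        return false_fn(*operands)

    def add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def sub_one(x: torch.Tensor) -> torch.Tensor:
        return x - 1

    x = torch.arange(4.0)   # tensor([0., 1., 2., 3.])
    pred = x.sum() > 3.0    # 0-dim bool tensor; 6.0 > 3.0 is True

    out = torch.cond(pred, add_one, sub_one, (x,))
    expected = reference_cond(bool(pred), add_one, sub_one, (x,))

Both branch functions here return a tensor with the same shape and dtype, which `torch.cond` expects of its branches.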

Its unique power lies in its ability to express **data-dependent control flow**: it lowers to a conditional
operator (`torch.ops.higher_order.cond`) that preserves the predicate, the true function, and the false function.
This unlocks great flexibility in writing and deploying models that change their architecture based on
the **value** or **shape** of inputs or of intermediate outputs of tensor operations.
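
To make the contrast concrete, the sketch below exports two hypothetical modules, `PythonIf` and `WithCond`, that branch on the value of `x.sum()`. A plain Python `if` needs a concrete boolean at trace time, so exporting it is expected to fail; the exact exception type differs between PyTorch versions, so the sketch catches `Exception` broadly. The `torch.cond` version exports, and the branch shows up as a `torch.ops.higher_order.cond` node.

.. code-block:: python

    import torch

    def cos_branch(x: torch.Tensor) -> torch.Tensor:
        return x.cos()

    def sin_branch(x: torch.Tensor) -> torch.Tensor:
        return x.sin()

    class PythonIf(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Python control flow on a tensor *value*: not capturable by export.
            if x.sum() > 0:
                return cos_branch(x)
            return sin_branch(x)

    class WithCond(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The same logic with torch.cond keeps both branches in the graph.
            return torch.cond(x.sum() > 0, cos_branch, sin_branch, (x,))

    example_args = (torch.randn(3),)

    try:
        torch.export.export(PythonIf(), example_args)
        python_if_exported = True
    except Exception:  # the error type varies across PyTorch versions
        python_if_exported = False

    ep = torch.export.export(WithCond(), example_args)
    has_cond = any(node.target == torch.ops.higher_order.cond for node in ep.graph.nodes)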

.. warning::
    `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
    doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
    Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples
~~~~~~~~

Below is an example that uses `torch.cond` to branch based on the input shape:

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Branch on the size of dim 0, using the module-level true_fn/false_fn
            # so the eager checks and the exported graph match this module.
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

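We can run the model eagerly and expect the results to vary with the input shape. Here `inp` has size 3 along
dimension 0, so the false branch runs, while `inp2` has size 5, so the true branch runs: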

.. code-block:: python

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

We can export the model for further transformations and deployment:

.. code-block:: python

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

This gives us an exported program as shown below:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

Notice that `torch.cond` is lowered to `torch.ops.higher_order.cond`, its predicate becomes a symbolic expression over the shape of the input,
and the branch functions become two sub-graph attributes of the top-level graph module.
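
This structure can also be checked programmatically. The sketch below re-exports the same module, locates the `torch.ops.higher_order.cond` node in `ep.graph`, and resolves its branch arguments to the sub-graph attributes on `ep.graph_module`. It assumes the operator's positional arguments are `(pred, true_fn, false_fn, operands)`, mirroring `torch.cond`; since attribute names such as `true_graph_0` may differ between versions, it reads them from the graph instead of hard-coding them.

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(
        DynamicShapeCondPredicate(),
        (torch.randn(4, 3),),
        dynamic_shapes={"x": {0: dim_batch}},
    )

    # Find the call to the higher-order cond operator in the top-level graph.
    cond_node = next(n for n in ep.graph.nodes if n.target == torch.ops.higher_order.cond)
    pred_node, true_attr, false_attr, operands = cond_node.args

    # The branch arguments are get_attr nodes that name sub-graph modules.
    true_gm = getattr(ep.graph_module, true_attr.target)
    false_gm = getattr(ep.graph_module, false_attr.target)

    print(pred_node.format_node())  # the symbolic predicate over the batch size
    print(true_gm.code)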

Here is another example that shows how to express data-dependent control flow:

.. code-block:: python

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on data dependent predicate.
        """
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

Exporting this module with the same input and dynamic shapes as before produces the following program:

.. code-block::

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin
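
Because the predicate is recorded in the graph rather than evaluated once at export time, a single exported program serves both branches. The sketch below assumes `DataDependentCondPredicate` is exported with the same dynamic batch dimension as the earlier example, then calls `ep.module()` on one input whose sum exceeds `4.0` and one whose sum does not.

.. code-block:: python

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DataDependentCondPredicate(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(
        DataDependentCondPredicate(),
        (torch.randn(4, 3),),
        dynamic_shapes={"x": {0: dim_batch}},
    )
    exported = ep.module()

    big = torch.ones(5, 3)     # sum is 15.0 > 4.0, so the true branch runs
    small = torch.zeros(2, 3)  # sum is 0.0, so the false branch runs
    out_big = exported(big)
    out_small = exported(small)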


Invariants of torch.ops.higher_order.cond
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There are several useful invariants for `torch.ops.higher_order.cond`:

- For the predicate:
    - The dynamism of the predicate is preserved (e.g. `gt` shown in the above example).
    - If the predicate in the user program is a constant (e.g. a Python bool constant), the `pred` of the operator will be a constant.

- For the branches:
    - The input and output signatures will be flattened tuples.
    - They are `torch.fx.GraphModule` instances.
    - Closures in the original functions become explicit inputs; the branches contain no closures.
    - No mutations on inputs or globals are allowed.

- For the operands:
    - They will also be a flat tuple.

- Nested `torch.cond` calls in the user program become nested graph modules.
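
The closure invariant can be observed directly. In the sketch below, a hypothetical module `ClosureCond` defines branches that read a second input `y` through a closure instead of receiving it as an operand. After export, the `cond` node is expected to list `y` among its flat operands, and each branch sub-graph is expected to take it as an explicit placeholder input.

.. code-block:: python

    import torch

    class ClosureCond(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # Both branches capture `y` from the enclosing scope.
            def true_fn(x: torch.Tensor):
                return x + y

            def false_fn(x: torch.Tensor):
                return x - y

            # Only `x` is passed explicitly as an operand.
            return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

    ep = torch.export.export(ClosureCond(), (torch.randn(3), torch.randn(3)))

    cond_node = next(n for n in ep.graph.nodes if n.target == torch.ops.higher_order.cond)
    operands = cond_node.args[3]
    true_gm = getattr(ep.graph_module, cond_node.args[1].target)
    true_inputs = [n for n in true_gm.graph.nodes if n.op == "placeholder"]

    # The captured `y` now appears as an explicit operand and branch input.
    print(len(operands), len(true_inputs))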


API Reference
~~~~~~~~~~~~~

.. autofunction:: torch._higher_order_ops.cond.cond