# Guide for using `scan` and `scan_layers`

This is a guide for using `scan` and `scan_layers` in PyTorch/XLA.

## When should you use this

You should consider using [`scan_layers`][scan_layers] if you have a model with
many homogeneous (same shape, same logic) layers, for example, LLMs. These models
can be slow to compile. `scan_layers` is a drop-in replacement for a for loop over
homogeneous layers, such as a stack of decoder layers. `scan_layers` traces the
first layer and reuses the compiled result for all subsequent layers, significantly
reducing the model's compile time.

[`scan`][scan], on the other hand, is a lower-level higher-order op modeled after
[`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to help implement
`scan_layers` under the hood. However, you may find it useful if you would like
to program some sort of loop logic where the loop itself has a first-class
representation in the compiler (specifically, an XLA `While` op).

## `scan_layers` example

Typically, a transformer model passes the input embedding through a sequence of
homogeneous decoder layers like the following:

```python
def run_decoder_layers(self, hidden_states):
  for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
  return hidden_states
```

When this function is lowered into an HLO graph, the for loop is unrolled into a
flat list of operations, resulting in long compile times. To reduce compile
times, you can replace the for loop with a call to `scan_layers`, as shown in 
[`decoder_with_scan.py`][decoder_with_scan]:

```python
def run_decoder_layers(self, hidden_states):
  from torch_xla.experimental.scan_layers import scan_layers
  return scan_layers(self.layers, hidden_states)
```
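
For a self-contained illustration, here is a minimal sketch that applies
`scan_layers` to a toy stack of identical `nn.Linear` layers (the layer count and
sizes are made up for illustration):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan_layers import scan_layers

device = xm.xla_device()

# A toy stack of homogeneous layers, standing in for real decoder layers.
# Every layer must have the same structure and parameter shapes.
layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(16)]).to(device)
hidden_states = torch.randn(4, 128, device=device)

# Only the first layer is traced; the compiled loop body is reused for the rest.
output = scan_layers(layers, hidden_states)
```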

You can train this decoder model by running the following command from the root 
directory of a `pytorch/xla` source checkout.

```sh
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
```

## `scan` example

[`scan`][scan] takes a combine function and applies that function over the leading
dimension of tensors while carrying along state:

```python
def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  ...
```

You can use it to loop over the leading dimension of tensors efficiently. If `xs`
is a single tensor, this function is roughly equivalent to the following Python code:

```python
def scan(fn, init, xs):
  ys = []
  carry = init
  for i in range(xs.size(0)):
    carry, y = fn(carry, xs[i])
    ys.append(y)
  return carry, torch.stack(ys, dim=0)
```

Under the hood, `scan` is implemented much more efficiently by lowering the loop
into an XLA `While` operation. This ensures that only one iteration of the loop
is compiled by XLA.
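
For example, a cumulative sum can be written as a scan in which the carry is the
running total and each step also emits that running total. A minimal sketch
(input values are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

device = xm.xla_device()

def step(carry, x):
  # The carry is the running total; each step also emits it as an output.
  new_carry = carry + x
  return new_carry, new_carry

init = torch.zeros(1, device=device)
xs = torch.tensor([[1.0], [2.0], [3.0]], device=device)

final_sum, history = scan(step, init, xs)
# final_sum is tensor([6.]); history stacks the running totals:
# tensor([[1.], [3.], [6.]])
```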

[`scan_examples.py`][scan_examples] contains example code showing how to use
`scan`. In that file, `scan_example_cumsum` uses `scan` to implement a cumulative
sum, and `scan_example_pytree` demonstrates how to pass PyTrees to `scan`.
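
Concretely, `init` and `xs` may be nested containers (dicts, lists, tuples) of
tensors, and `fn` receives and returns matching structures. Here is a minimal
sketch of a running mean with a dict carry (the field names are illustrative,
not an API requirement):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

device = xm.xla_device()

def step(carry, x):
  # The carry is a dict; scan only requires its structure to stay consistent.
  new_carry = {
      'sum': carry['sum'] + x,
      'count': carry['count'] + 1.0,
  }
  mean = new_carry['sum'] / new_carry['count']
  return new_carry, mean

init = {
    'sum': torch.zeros(1, device=device),
    'count': torch.zeros(1, device=device),
}
xs = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]], device=device)

final_carry, means = scan(step, init, xs)
# final_carry is {'sum': tensor([15.]), 'count': tensor([5.])};
# means stacks the running means over time.
```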

You can run the examples with:

```sh
python3 examples/scan/scan_examples.py
```

The output should look something like the following:

```
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
        [3.],
        [6.]], device='xla:0')


Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
        [1.5000],
        [2.0000],
        [2.5000],
        [3.0000]], device='xla:0')
```

## Limitations

### AOTAutograd compatibility requirement

The functions/modules passed to `scan` and `scan_layers` must be traceable by
AOTAutograd. In particular, as of PyTorch/XLA 2.6, `scan` and `scan_layers` cannot
trace functions with custom Pallas kernels. That means if your decoder uses, for
example, flash attention, then it is incompatible with `scan`. We are working on
[supporting this important use case][flash-attn-issue] in nightly builds and
upcoming releases.

### AOTAutograd overhead

Because `scan` uses AOTAutograd to figure out the backward pass of the input
function/module on every iteration, it is easy to become tracing-bound compared to
a for loop implementation. In fact, the `train_decoder_only_base.py` example runs
slower under `scan` than with a for loop as of PyTorch/XLA 2.6 due to this overhead.
We are working on [improving tracing speed][retracing-issue]. This is less of a
problem when your model is very large or has many layers, which are exactly the
situations where you would want to use `scan` anyway.

## Compile time experiments

To demonstrate the compile time savings, we'll train a simple decoder with many
layers on a single TPU chip, first with for loops and then with `scan_layers`.

- Run the for loop implementation:

```sh
❯ python3 examples/train_decoder_only_base.py \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 02m57s694ms418.595us
  ValueRate: 02s112ms586.097us / second
  Rate: 0.054285 / second
  Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
  99%=01m03s028ms571.841us
```

- Run the `scan_layers` implementation:

```sh
❯ python3 examples/train_decoder_only_base.py \
    scan.decoder_with_scan.DecoderWithScan \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 29s996ms941.409us
  ValueRate: 02s529ms591.388us / second
  Rate: 0.158152 / second
  Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
  99%=18s995ms301.667us
```

We can see that the maximum compile time dropped from `1m03s` to `19s` by
switching to `scan_layers`.

## References

See https://github.com/pytorch/xla/issues/7253 for the design of `scan` and
`scan_layers`.

See the function doc comments of [`scan`][scan] and [`scan_layers`][scan_layers]
for details on how to use them.

<!-- xrefs -->

[scan]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py
[scan_layers]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py
[flash-attn-issue]: https://github.com/pytorch/xla/issues/8633
[retracing-issue]: https://github.com/pytorch/xla/issues/8632
[jax-lax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
[decoder_with_scan]: /examples/scan/decoder_with_scan.py
[scan_examples]: /examples/scan/scan_examples.py