
Guide for using scan and scan_layers

This is a guide for using scan and scan_layers in PyTorch/XLA.

When should you use this

You should consider using scan_layers (https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py) if you have a model with many homogeneous (same shape, same logic) layers, for example, LLMs. These models can be slow to compile. scan_layers is a drop-in replacement for a for loop over homogeneous layers, such as a stack of decoder layers. scan_layers traces the first layer and reuses the compiled result for all subsequent layers, significantly reducing the model's compile time.

scan (https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py), on the other hand, is a lower-level higher-order op modeled after jax.lax.scan (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html). Its primary purpose is to implement scan_layers under the hood. However, you may find it useful if you want to write loop logic where the loop itself has a first-class representation in the compiler (specifically, an XLA While op).

scan_layers example

Typically, a transformer model passes the input embedding through a sequence of homogeneous decoder layers like the following:

def run_decoder_layers(self, hidden_states):
  for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
  return hidden_states

When this function is lowered into an HLO graph, the for loop is unrolled into a flat list of operations, resulting in long compile times. To reduce compile times, you can replace the for loop with a call to scan_layers, as shown in examples/scan/decoder_with_scan.py:

def run_decoder_layers(self, hidden_states):
  from torch_xla.experimental.scan_layers import scan_layers
  return scan_layers(self.layers, hidden_states)
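To see why tracing one layer is enough, here is a pure-Python sketch of the idea behind replacing the loop: since every layer runs the same logic, the layers' parameters can be gathered into a sequence (in PyTorch terms, stacked along a new leading dimension) and fed to scan as xs, with the hidden state as the carry. The affine_layer function, its (weight, bias) parameters, and the other helper names are hypothetical; this illustrates the concept and is not the scan_layers source.

```python
from typing import Any, Callable, List, Optional, Tuple

Params = Tuple[float, float]  # hypothetical per-layer parameters: (weight, bias)


def affine_layer(params: Params, hidden: float) -> float:
    """Hypothetical homogeneous layer: same logic for every layer, different parameters."""
    weight, bias = params
    return weight * hidden + bias


def run_with_for_loop(layer_params: List[Params], hidden: float) -> float:
    # Mirrors the original run_decoder_layers: apply each layer in turn.
    for params in layer_params:
        hidden = affine_layer(params, hidden)
    return hidden


def scan(fn: Callable[[Any, Any], Tuple[Any, Any]], init: Any, xs: List[Any]) -> Tuple[Any, List[Any]]:
    # Reference scan semantics over a Python list: thread the carry, collect outputs.
    carry, ys = init, []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys


def run_with_scan(layer_params: List[Params], hidden: float) -> float:
    # The per-layer parameters are the scanned input xs; the hidden state is the carry.
    # The step function is written once, so a compiler would only need to trace one layer.
    def one_layer(carry: float, params: Params) -> Tuple[float, Optional[float]]:
        return affine_layer(params, carry), None  # no per-layer output is needed

    final_hidden, _ = scan(one_layer, hidden, layer_params)
    return final_hidden


layers = [(2.0, 1.0), (0.5, -1.0), (3.0, 0.0)]
print(run_with_for_loop(layers, 1.0), run_with_scan(layers, 1.0))
```

Both functions compute the same result; with real tensors the parameter list would instead be tensors carrying an extra leading "layer" dimension.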

You can train this decoder model by running the following command from the root directory of a pytorch/xla source checkout.

python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan

scan example

scan (https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py) takes a combine function and applies it over the leading dimension of tensors while carrying along state:

def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  ...

You can use it to loop over the leading dimension of tensors efficiently. If xs is a single tensor, this function is roughly equivalent to the following Python code:

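import torch

def scan(fn, init, xs):
  # Reference semantics only: iterate over the leading dimension of xs.
  ys = []
  carry = init
  for i in range(xs.size(0)):
    carry, y = fn(carry, xs[i])
    ys.append(y)
  # Stack the per-step outputs along a new leading dimension.
  return carry, torch.stack(ys, dim=0)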

Under the hood, scan is implemented much more efficiently by lowering the loop into an XLA While operation. This ensures that only one iteration of the loop is compiled by XLA.
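As a rough mental model of that lowering, the pure-Python sketch below rewrites scan as one generic while loop whose state is (iteration index, carry, preallocated output buffer). Only one body function exists no matter how long xs is, which is the property that lets XLA compile a single iteration. The while_loop and scan_via_while helpers and the state layout are illustrative assumptions that mirror the general condition/body shape of a While op, not PyTorch/XLA's actual lowering code.

```python
from typing import Any, Callable, List, Tuple

State = Tuple[int, Any, List[Any]]  # (iteration index, carry, output buffer)


def while_loop(cond: Callable[[Any], bool], body: Callable[[Any], Any], state: Any) -> Any:
    """Generic while loop: a single condition and a single body over a loop-carried state."""
    while cond(state):
        state = body(state)
    return state


def scan_via_while(fn: Callable[[Any, Any], Tuple[Any, Any]], init: Any, xs: List[Any]) -> Tuple[Any, List[Any]]:
    n = len(xs)

    def cond(state: State) -> bool:
        i, _, _ = state
        return i < n

    def body(state: State) -> State:
        i, carry, ys = state
        carry, y = fn(carry, xs[i])
        # Write step i's output into the buffer allocated up front (a functional update).
        ys = ys[:i] + [y] + ys[i + 1:]
        return i + 1, carry, ys

    _, final_carry, ys = while_loop(cond, body, (0, init, [None] * n))
    return final_carry, ys


def running_max(carry: float, x: float) -> Tuple[float, float]:
    m = max(carry, x)
    return m, m


print(scan_via_while(running_max, 0, [3, 1, 4, 1, 5]))
```

By contrast, the unrolled for loop produces one copy of the step per element of xs, so its graph grows with the loop length.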

examples/scan/scan_examples.py contains example code showing how to use scan. In that file, scan_example_cumsum uses scan to implement a cumulative sum, and scan_example_pytree demonstrates how to pass PyTrees to scan.

You can run the examples with:

python3 examples/scan/scan_examples.py

The output should look something like the following:

Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
        [3.],
        [6.]], device='xla:0')


Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
        [1.5000],
        [2.0000],
        [2.5000],
        [3.0000]], device='xla:0')
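To make the PyTree handling concrete without a TPU, here is a pure-Python reference sketch of scan's semantics that accepts nested dicts/tuples for init and xs, slices every xs leaf along the leading dimension, and stacks the per-step outputs leaf-wise. It is not PyTorch/XLA code: Python lists stand in for tensors, and the step functions and inputs (1 to 3 for the sum, 1 to 5 for the running mean) are assumptions inferred from the sample output above, not copied from scan_examples.py.

```python
from typing import Any, Callable, List, Tuple

# Toy PyTree convention assumed for this sketch: dicts and tuples are containers,
# anything else is a leaf. In xs, each leaf is a Python list indexed along the
# leading (scanned) dimension, standing in for a tensor.


def tree_map(f: Callable[..., Any], *trees: Any) -> Any:
    """Apply f leaf-wise across trees that share the same structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(f, *(t[k] for t in trees)) for k in first}
    if isinstance(first, tuple):
        return tuple(tree_map(f, *children) for children in zip(*trees))
    return f(*trees)


def tree_leaves(tree: Any) -> List[Any]:
    """Flatten a tree into its leaves in a deterministic order."""
    if isinstance(tree, dict):
        return [leaf for k in tree for leaf in tree_leaves(tree[k])]
    if isinstance(tree, tuple):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    return [tree]


def scan(fn: Callable[[Any, Any], Tuple[Any, Any]], init: Any, xs: Any) -> Tuple[Any, Any]:
    """Reference scan over PyTrees: slice every xs leaf at step i, then stack outputs leaf-wise."""
    lengths = {len(leaf) for leaf in tree_leaves(xs)}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError("all xs leaves must share one non-zero leading dimension")
    (length,) = lengths
    carry, ys = init, []
    for i in range(length):
        x_i = tree_map(lambda leaf: leaf[i], xs)  # the i-th slice of every leaf
        carry, y = fn(carry, x_i)
        ys.append(y)
    # A list of per-step trees becomes a single tree whose leaves are lists.
    return carry, tree_map(lambda *steps: list(steps), *ys)


def cumsum_step(carry: float, x: float) -> Tuple[float, float]:
    total = carry + x
    return total, total


def mean_step(carry: dict, x: dict) -> Tuple[dict, float]:
    new_carry = {"sum": carry["sum"] + x["value"], "count": carry["count"] + 1.0}
    return new_carry, new_carry["sum"] / new_carry["count"]


final_sum, history = scan(cumsum_step, 0.0, [1.0, 2.0, 3.0])
final_carry, means = scan(mean_step, {"sum": 0.0, "count": 0.0}, {"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
print("Final sum:", final_sum, "History of sums:", history)
print("Final carry:", final_carry, "Means over time:", means)
```

The printed values match the sample output above, as plain floats rather than XLA tensors. In the real scan, the same slicing and stacking is applied to tensors, and the loop body is compiled once as an XLA While op.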

Limitations

AOTAutograd compatibility requirement

The functions/modules passed to scan and scan_layers must be traceable by AOTAutograd. In particular, as of PyTorch/XLA 2.6, scan and scan_layers cannot trace functions that contain custom Pallas kernels. This means that if your decoder uses, for example, flash attention, it is incompatible with scan. We are working on supporting this important use case in nightly builds and upcoming releases.

AOTAutograd overhead

Because scan uses AOTAutograd to figure out the backward pass of the input function/module on every iteration, it is easy to become tracing-bound compared to a for loop implementation. In fact, as of PyTorch/XLA 2.6, the train_decoder_only_base.py example runs slower with scan than with a for loop because of this overhead. We are working on improving tracing speed. This is less of a problem when your model is very large or has many layers, which are the situations where you would want to use scan anyway.

Compile time experiments

To demonstrate the compile time savings, we’ll train a simple decoder with many layers on a single TPU chip, first with a for loop and then with scan_layers.

  • Run the for loop implementation:

 python3 examples/train_decoder_only_base.py \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 02m57s694ms418.595us
  ValueRate: 02s112ms586.097us / second
  Rate: 0.054285 / second
  Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us; 99%=01m03s028ms571.841us
  • Run the scan_layers implementation:

 python3 examples/train_decoder_only_base.py \
    scan.decoder_with_scan.DecoderWithScan \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 29s996ms941.409us
  ValueRate: 02s529ms591.388us / second
  Rate: 0.158152 / second
  Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us; 99%=18s995ms301.667us

We can see that the maximum compile time dropped from 1m03s to 19s by switching to scan_layers.
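The duration strings in the metrics report are easy to misread, so here is a small Python helper that converts them to seconds and recomputes the comparison above. It is written for the format shown in this guide's output as an assumption (fields of minutes, seconds, milliseconds and microseconds) and is not part of PyTorch/XLA.

```python
import re

# Assumption: each metric duration is a run of <number><unit> fields with units
# m (minutes), s, ms and us, exactly as printed in the CompileTime metrics above.
_SECONDS_PER_UNIT = {"m": 60.0, "s": 1.0, "ms": 1e-3, "us": 1e-6}
# "ms" and "us" are tried before "m" and "s" so the two-letter units match first.
_FIELD = re.compile(r"(\d+(?:\.\d+)?)(ms|us|m|s)")


def parse_duration(text: str) -> float:
    """Convert a metric duration such as '01m03s028ms571.841us' to seconds."""
    fields = _FIELD.findall(text)
    # Reject strings containing characters that the pattern did not consume.
    if not fields or "".join(num + unit for num, unit in fields) != text:
        raise ValueError(f"unrecognized duration: {text!r}")
    return sum(float(num) * _SECONDS_PER_UNIT[unit] for num, unit in fields)


# Slowest compilation (the 80th to 99th percentile value) from each run above.
for_loop_max = parse_duration("01m03s028ms571.841us")
scan_layers_max = parse_duration("18s995ms301.667us")
print(f"max compile time: {for_loop_max:.1f}s -> {scan_layers_max:.1f}s "
      f"({for_loop_max / scan_layers_max:.1f}x less)")
```

With these two values, the slowest compile shrinks by a factor of roughly 3.3. The same helper can be applied to the Accumulator or other duration fields of the report.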

References

See https://github.com/pytorch/xla/issues/7253 for the design of scan and scan_layers themselves.

See the function doc comments of scan (https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py) and scan_layers (https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py) for details on how to use them.
