
Optimize memory utilization using while_loop

while_loop

while_loop replaces a pure Python while loop. PyTorch supports it through torch._higher_order_ops.while_loop, and PyTorch/XLA provides experimental XLA backend support for torch._higher_order_ops.while_loop via XLA::While.

Usage:

import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
  • cond_fn: User-defined condition function.

  • body_fn: User-defined loop body function.

  • init: Initial values (tuple or list).
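The roles of the three arguments can be sketched in plain Python. The function name while_loop_reference is hypothetical and used only for illustration. This is not how the XLA backend runs the loop (that goes through XLA::While); it is only an eager model of which values are passed to cond_fn and body_fn, with plain integers standing in for tensors.

```python
def while_loop_reference(cond_fn, body_fn, init):
    """Eager, pure-Python model of while_loop(cond_fn, body_fn, init).

    Hypothetical helper for illustration; not part of torch or torch_xla.
    """
    carried = tuple(init)
    # cond_fn receives the carried values unpacked and decides whether to continue.
    while cond_fn(*carried):
        # body_fn receives the same values and returns their replacements.
        carried = tuple(body_fn(*carried))
    return carried


def cond_fn(iteri, x):
    return iteri > 0


def body_fn(iteri, x):
    return iteri - 1, x + 1


# Plain integers stand in for tensors here.
final_iteri, res = while_loop_reference(cond_fn, body_fn, (10, 3))
print(final_iteri, res)  # 0 13
```

Both functions take the carried values as positional arguments in the same order as init, and body_fn returns one new value per carried value.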

A simple example with while_loop:

# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))

Control group test case

To make the difference between a pure Python while loop and while_loop easier to compare, the following control-group test case uses a pure Python while loop with similar logic: it repeatedly adds 1 to a value (50 times here).

Control group example with a pure Python while loop

# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')

PyTorch/XLA plans to include while_loop support, with a test case, in release 2.4; support for fori_loop is planned for after 2.4. For now, while_loop has one restriction: body_fn must return values with the same shapes as its input arguments.
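The shape restriction can be checked before calling while_loop by running body_fn once and comparing shapes. The sketch below is an assumption-laden illustration: check_body_fn_shapes is a hypothetical helper, not part of torch or torch_xla, and it only assumes that each carried value exposes a .shape attribute, as torch tensors do.

```python
def check_body_fn_shapes(body_fn, init):
    """Run body_fn once and verify each output has the shape of its input.

    Hypothetical helper for illustration; not part of torch or torch_xla.
    Works on any values exposing a .shape attribute.
    """
    outputs = body_fn(*init)
    # Allow a single returned value as well as a tuple or list.
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)
    in_shapes = [tuple(v.shape) for v in init]
    out_shapes = [tuple(v.shape) for v in outputs]
    # A differing count or any differing shape violates the restriction.
    if in_shapes != out_shapes:
        raise ValueError(
            f"body_fn must preserve shapes: inputs {in_shapes}, outputs {out_shapes}"
        )
    return out_shapes
```

With the tensors from the while_loop example, check_body_fn_shapes(body_fn, (iteri, init_val)) would be expected to return [(), ()], since both carried values are 0-dimensional.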
