# Optimize memory utilization using `while_loop`

## `while_loop`

`while_loop` replaces the pure Python `while` loop. PyTorch supports `while_loop` through [torch.\_higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66). PyTorch/XLA provides experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`.

### Usage:

``` python
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
```

- `cond_fn`: User-defined condition function.
- `body_fn`: User-defined loop body function.
- `init`: Initial values (tuple or list).

### Simple example with `while_loop`:

``` python
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))
```

#### Control group test case

To make the difference between a pure Python `while` loop and `while_loop` easier to compare, the following test case uses a pure Python `while` loop with similar logic: it adds 1 on each iteration, here for 50 iterations.

#### Control group example with pure Python `while` loop

``` python
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')
```

PyTorch/XLA plans to include `while_loop` support, with a test case, in release 2.4; support for `fori_loop` is to be added after 2.4. For `while_loop`, the current restriction is that `body_fn` must return values with the same shapes as the inputs it receives.
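The shape constraint on `body_fn` can be illustrated without a TPU. The block below is a minimal pure-Python sketch of the semantics of `while_loop(cond_fn, body_fn, init)`; it is not how PyTorch or PyTorch/XLA implement the operator. The helper names `reference_while_loop` and `shape_of` are hypothetical, and plain Python numbers and nested lists stand in for tensors so that the sketch runs with the standard library alone.

``` python
def shape_of(value):
    """Return the shape of a scalar or rectangular nested list as a tuple."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def reference_while_loop(cond_fn, body_fn, init):
    """Pure-Python model of while_loop (hypothetical, for illustration only).

    Carries the values in `init` through `body_fn` until `cond_fn` is false,
    and raises ValueError if `body_fn` changes the number or shape of the
    carried values.
    """
    carried = tuple(init)
    expected = [shape_of(v) for v in carried]
    while cond_fn(*carried):
        result = tuple(body_fn(*carried))
        if [shape_of(v) for v in result] != expected:
            raise ValueError("body_fn must return the same shapes it receives")
        carried = result
    return carried


def cond_fn(iteri, x):
    return iteri > 0


def body_fn(iteri, x):
    # Same structure in and out: (scalar, scalar) -> (scalar, scalar).
    return iteri - 1, x + 1


def bad_body_fn(iteri, x):
    # Returns a one-element list where a scalar was passed in.
    return iteri - 1, [x + 1]


_, res = reference_while_loop(cond_fn, body_fn, (10, 3))
print(res)  # 13

try:
    reference_while_loop(cond_fn, bad_body_fn, (10, 3))
except ValueError as err:
    print("rejected:", err)
```

The first call mirrors the simple example above (start at 3, add 1 ten times). The second call shows the kind of `body_fn` the restriction rules out: its second return value no longer matches the shape of the corresponding input.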