Optimize memory utilization using while_loop
while_loop
while_loop replaces a pure Python while loop; PyTorch supports while_loop through torch._higher_order_ops.while_loop. PyTorch/XLA provides experimental XLA backend support for torch._higher_order_ops.while_loop via XLA::While.
Usage:
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
cond_fn: User-defined condition function.
body_fn: User-defined loop body function.
init: Initial values (tuple or list).
Simple example with while_loop:
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=tensor(13, device='xla:0'))
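The carried values are not limited to scalars. As a minimal sketch (not from the original documentation; it assumes the same cond_fn/body_fn contract as the example above), the loop can also carry a non-scalar tensor, as long as its shape stays the same on every iteration:
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   # x keeps shape (3,) on every iteration
...   return iteri - 1, x * 2
...
>>> x = torch.ones(3, device=device)
>>> iteri = torch.tensor(5, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, x))
After five doublings, res should hold a tensor of shape (3,) with every entry equal to 32.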
Control group test case
To better compare the difference between a pure Python while loop and while_loop, there is a control test case that uses a pure Python while loop with similar logic: cumulatively adding 1 on each iteration:
Control group example with pure Python while loop
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')
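The memory difference comes from how the two versions are traced: the pure Python loop appends each iteration's ops to the lazy tensor graph (and the iteri > 0 check forces a host sync on every pass), while while_loop lowers the entire loop into a single XLA::While computation. As a hedged sketch (the IR-dumping helper below comes from PyTorch/XLA's debugging utilities, and what it prints here is an assumption, not taken from this section), you can inspect the traced graph of the pure Python version:
>>> import torch_xla
>>> # Expect the lazy IR to grow with the number of Python-level iterations,
>>> # roughly one add node per pass that has not yet been materialized.
>>> print(torch_xla._XLAC._get_xla_tensors_text([init_val]))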
PyTorch/XLA includes while_loop support in the 2.4 release with a test case; support for fori_loop will be added after 2.4. For while_loop, we currently require body_fn to be defined with the same input and output (returned args) shapes.
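As a hedged illustration of that constraint (this snippet is not from the original documentation), a body_fn that preserves the shapes of its inputs is fine, while one that changes a carried tensor's shape between input and output is not supported:
>>> def good_body_fn(iteri, x):
...   # OK: iteri and x are returned with the same shapes they were passed in with
...   return iteri - 1, x + 1
...
>>> def bad_body_fn(iteri, x):
...   # Not supported: torch.cat doubles the length of x, so the output
...   # shape no longer matches the input shape
...   return iteri - 1, torch.cat([x, x])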