Source code for torch_xla.core.functions

import torch
import torch_xla
import torch_xla.core.xla_model as xm


class AllReduce(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, reduce_type, scale, groups):
    ctx.reduce_type = reduce_type
    ctx.scale = scale
    output = xm.all_reduce(reduce_type, input, scale=scale, groups=groups)
    ctx.save_for_backward(input, output)
    return output

  @staticmethod
  def backward(ctx, grad_output):
    input, output = ctx.saved_tensors
    grad = grad_output * ctx.scale if ctx.scale != 1.0 else grad_output
    if ctx.reduce_type == xm.REDUCE_SUM:
      return grad, None, None, None
    if ctx.reduce_type == xm.REDUCE_MUL:
      # MUL is not supported by TPU
      grad_scaler = torch.where(input != 0, output / input,
                                torch.zeros_like(input))
      return grad * grad_scaler, None, None, None
    if ctx.reduce_type == xm.REDUCE_MIN or ctx.reduce_type == xm.REDUCE_MAX:
      return torch.where(input == output, grad,
                         torch.zeros_like(grad)), None, None, None
    raise RuntimeError('Unsupported reduce type: {}'.format(ctx.reduce_type))
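
The REDUCE_MUL and REDUCE_MIN/REDUCE_MAX branches above encode standard gradient identities: the cross-replica product has derivative output / input with respect to each replica's contribution, and min/max route gradient only to the winning elements. The following single-process snippet is an illustrative sanity check of the product identity using plain torch; it is not part of the torch_xla API.

import torch

# If y = prod_i(x_i), then dy/dx_i = y / x_i, which is exactly what the
# `output / input` grad_scaler above computes (per element, per replica).
x = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
y = torch.prod(x)
y.backward()
assert torch.allclose(x.grad, (y / x).detach())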


def all_reduce(reduce_type, value, scale=1.0, groups=None):
  """Performs an inplace reduce operation on the input tensor.

  This is the same as `xm.all_reduce()` but supports autograd differentiation.

  Args:
    reduce_type (string): One of ``REDUCE_SUM``, ``REDUCE_MUL``, ``REDUCE_AND``,
      ``REDUCE_OR``, ``REDUCE_MIN`` and ``REDUCE_MAX``.
    value (torch.Tensor): The tensor on which to perform the all reduce op.
    scale (float): A default scaling value to be applied after the reduce.
      Default: 1.0
    groups (list, optional): A list of lists, representing the replica groups
      for the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
      defines two groups, one with the `[0, 1, 2, 3]` replicas and one with the
      `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all
      the replicas in it.

  Returns:
    The reduced value across the selected replicas.
  """
  return AllReduce.apply(value, reduce_type, scale, groups)

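A minimal usage sketch (not part of the source above): inside a replica function launched with `torch_xla.distributed.xla_multiprocessing.spawn`, the differentiable wrapper can sit anywhere in an autograd graph. The device setup, shapes and loss below are illustrative assumptions.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf


def _replica_fn(index):
  device = xm.xla_device()
  w = torch.randn(4, 4, device=device, requires_grad=True)
  # Differentiable cross-replica mean: sum followed by a 1/world_size scale.
  averaged = xf.all_reduce(xm.REDUCE_SUM, w, scale=1.0 / xm.xrt_world_size())
  averaged.sum().backward()  # gradients flow back into `w`
  xm.mark_step()
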
class AllGather(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, dim):
    ctx.dim = dim
    ctx.ordinal = xm.get_ordinal()
    ctx.world_size = xm.xrt_world_size()
    return xm.all_gather(input, dim=dim)

  @staticmethod
  def backward(ctx, grad_output):
    slice_size = grad_output.size(ctx.dim) // ctx.world_size
    return torch.narrow(grad_output.clone(), ctx.dim, ctx.ordinal * slice_size,
                        slice_size), None

def all_gather(value, dim=0):
  """Performs an all-gather operation along a given dimension.

  This is the same as `xm.all_gather()` but supports autograd differentiation.

  Args:
    value (torch.Tensor): The input tensor.
    dim (int): The gather dimension. Default: 0

  Returns:
    A tensor which has, in the ``dim`` dimension, all the values from the
    participating replicas.
  """
  return AllGather.apply(value, dim)

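Similarly, a small illustrative sketch of the autograd-aware gather; the shapes and device handling are assumptions rather than part of the module.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf


def _replica_fn(index):
  device = xm.xla_device()
  x = torch.randn(2, 3, device=device, requires_grad=True)
  gathered = xf.all_gather(x, dim=0)  # shape: (2 * world_size, 3)
  gathered.pow(2).sum().backward()
  # AllGather.backward returns to each replica only the gradient slice that
  # corresponds to its own contribution.
  xm.mark_step()
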
def distributed_mm(w, x, split=1):
  """Performs a matrix multiplication with sharded weight.

  Args:
    w (torch.Tensor): The sharded weight, RHS of the matrix multiplication
      operation. The weight shape is `N x Ko` where `Ko` is the shard
      dimension size. Each ordinal will have its own copy of the weight.
    x (torch.Tensor): The input tensor, LHS of the matrix multiplication
      operation. The input shape is `WG x M` where `WG = Ko * WORLD_SIZE`.
    split (int): The number of splits for the `M` dimension of `x`. Since
      there is an `all_gather()` on such dimension, if `M` is big, a split
      might be required in order to fit device memory.
      Default: 1

  Returns:
    The result of the distributed matrix multiplication operation.
  """
  ordinal = xm.get_ordinal()
  # w = N x Ko
  # WG = Ko * WORLD_SIZE
  # x = WG x M
  assert x.size(0) // xm.xrt_world_size() == w.size(1)
  splits = []
  if split != 1:
    size = x.size(1)
    assert size % split == 0
    split_size = size // split
    splits = torch.split(x, split_size, dim=1)
  else:
    splits.append(x)
  results = []
  for xs in splits:
    # xg = WG x (M * WORLD_SIZE)
    xg = all_gather(xs, dim=1)
    # xgn = Ko x (M * WORLD_SIZE)
    xgn = torch.narrow(xg, 0, ordinal * w.size(1), w.size(1))
    # wxg = N x (M * WORLD_SIZE)
    wxg = w @ xgn
    # rwxg = N x (M * WORLD_SIZE)
    rwxg = all_reduce(xm.REDUCE_SUM, wxg)
    # wx = N x M
    wx = torch.narrow(rwxg, 1, ordinal * xs.size(1), xs.size(1))
    results.append(wx)
  return torch.cat(results, dim=1) if len(results) > 1 else results[0]


class SyncBatchNorm(torch.nn.Module):

  def __init__(
      self,
      num_features: int,
      eps: float = 1e-5,
      momentum: float = 0.1,
  ):
    super().__init__()
    self.num_features = num_features
    self.eps = eps
    self.momentum = momentum
    self.weight = torch.nn.Parameter(torch.ones(num_features))
    self.bias = torch.nn.Parameter(torch.zeros(num_features))
    self.register_buffer('running_mean', torch.zeros(num_features))
    self.register_buffer('running_var', torch.ones(num_features))

  def forward(self, batch: torch.Tensor) -> torch.Tensor:
    assert 2 <= batch.ndim <= 5 and batch.shape[1] == self.num_features
    reduce_dims = list(range(batch.ndim))
    reduce_dims.pop(1)  # channel dim

    if self.training:
      local_mean = torch.mean(batch, dim=reduce_dims)
      local_sqr_mean = torch.mean(batch * batch, dim=reduce_dims)
      scale = 1.0 / xm.xrt_world_size()
      mean = AllReduceSumLayer.apply(local_mean) * scale
      sqr_mean = AllReduceSumLayer.apply(local_sqr_mean) * scale
      var = sqr_mean - mean.pow(2)
      self.running_mean = (
          1 - self.momentum) * self.running_mean + self.momentum * mean
      self.running_var = (
          1 - self.momentum) * self.running_var + self.momentum * var
    else:
      mean = self.running_mean
      var = self.running_var

    res = torch.empty_like(batch)
    for c in range(self.num_features):
      if batch.ndim == 2:
        res = ((batch - mean) /
               torch.sqrt(var + self.eps)) * self.weight + self.bias
      else:
        res[:, c, ...] = (
            (batch[:, c, ...] - mean[c]) /
            torch.sqrt(var[c] + self.eps)) * self.weight[c] + self.bias[c]
    return res

  def extra_repr(self) -> str:
    return f'{self.num_features}, eps={self.eps}'


class AllReduceSumLayer(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x):
    return xm.all_reduce(xm.REDUCE_SUM, x)

  @staticmethod
  def backward(ctx, grad_output):
    return xm.all_reduce(xm.REDUCE_SUM, grad_output)

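To close, a hedged sketch of how `distributed_mm` and `SyncBatchNorm` might be wired into a replica function; the world size, shard size and tensor shapes below are illustrative assumptions, not values prescribed by the module.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf


def _replica_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  ko = 8  # assumed per-ordinal shard of the contraction dimension

  # Sharded matmul: weight is N x Ko on every ordinal, input is (Ko * WS) x M.
  w = torch.randn(16, ko, device=device)
  x = torch.randn(ko * world_size, 32, device=device)
  y = xf.distributed_mm(w, x, split=2)  # result is N x M, here 16 x 32

  # Cross-replica batch norm over the channel dimension (C = 16).
  bn = xf.SyncBatchNorm(num_features=16).to(device)
  out = bn(torch.randn(4, 16, 8, 8, device=device))
  xm.mark_step()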
