functorch.vmap
functorch.vmap(func, in_dims=0, out_dims=0)

vmap is the vectorizing map; vmap(func) returns a new function that maps func over some dimension of the inputs. Semantically, vmap pushes the map into the PyTorch operations called by func, effectively vectorizing those operations.

vmap is useful for handling batch dimensions: one can write a function func that runs on a single example and then lift it to a function that can take batches of examples with vmap(func). vmap can also be used to compute batched gradients when composed with autograd.

- Parameters
func (function) – A Python function that takes one or more arguments. Must return one or more Tensors.
in_dims (int or nested structure) – Specifies which dimension of the inputs should be mapped over. in_dims should have a structure like the inputs. If the in_dims entry for a particular input is None, that input has no mapped dimension and the same value is used for every example. Default: 0.
out_dims (int or Tuple[int]) – Specifies where the mapped dimension should appear in the outputs. If out_dims is a Tuple, then it should have one element per output. Default: 0.
- Returns
Returns a new “batched” function. It takes the same inputs as func, except each input has an extra dimension at the index specified by in_dims (inputs whose entry is None are passed as-is). It returns the same outputs as func, except each output has an extra dimension at the index specified by out_dims.
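The Returns description can be read as a loop: slice each input along its in_dims dimension, call func once per slice, and stack the results along out_dims. The sketch below checks that reading for the default in_dims=0 and out_dims=0. It assumes a PyTorch install where `functorch` is importable (recent releases also expose the same transform as `torch.vmap`); `score` is a hypothetical per-example function.

```python
import torch
import functorch

# Hypothetical per-example function: a and b are single vectors of shape [D]
def score(a, b):
    return torch.dot(a, b).tanh()

N, D = 4, 3
A, B = torch.randn(N, D), torch.randn(N, D)

# Reference semantics: call score once per example (slicing dim 0, as in_dims=0 does)
# and stack the per-example results along dim 0 (as out_dims=0 does)
looped = torch.stack([score(A[i], B[i]) for i in range(N)], dim=0)

# vmap computes the same values with batched operations instead of a Python loop
vmapped = functorch.vmap(score)(A, B)

print(vmapped.shape)  # torch.Size([4])
print(torch.allclose(looped, vmapped, atol=1e-6))
```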
One example of using vmap() is to compute batched dot products. PyTorch doesn’t provide a batched torch.dot API; instead of unsuccessfully rummaging through docs, use vmap() to construct a new function.

>>> import torch
>>> import functorch
>>> torch.dot  # [D], [D] -> []
>>> batched_dot = functorch.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
vmap() can be helpful in hiding batch dimensions, leading to a simpler model authoring experience.

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> def model(feature_vec):
...     # Very simple linear model with activation
...     return feature_vec.dot(weights).relu()
...
>>> examples = torch.randn(batch_size, feature_size)
>>> result = functorch.vmap(model)(examples)
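Composed with autograd, the same pattern yields per-example gradients of this toy model. The sketch below assumes `functorch.grad`, which differentiates a scalar-valued function with respect to its first argument by default. Unlike the example above, it rewrites `model` to take the weights as an explicit first argument so that grad can differentiate with respect to them.

```python
import torch
import functorch

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)

# Same toy model, but the weights are an explicit first argument so grad can see them
def model(w, feature_vec):
    return feature_vec.dot(w).relu()

# grad(model) computes d model / d w for ONE example;
# vmap maps it over the rows of examples while sharing w (in_dims entry None)
per_example_grads = functorch.vmap(functorch.grad(model), in_dims=(None, 0))(weights, examples)
print(per_example_grads.shape)  # torch.Size([3, 5])
```

Each row is `feature_vec` where the pre-activation is positive and zero elsewhere, which is the gradient of `relu(feature_vec · w)` with respect to `w`.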
vmap() can also help vectorize computations that were previously difficult or impossible to batch. One example is higher-order gradient computation. The PyTorch autograd engine computes vjps (vector-Jacobian products). Computing a full Jacobian matrix for some function f: R^N -> R^N usually requires N calls to autograd.grad, one per Jacobian row. Using vmap(), we can vectorize the whole computation, computing the Jacobian in a single call to autograd.grad.

>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
...                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
...     return torch.autograd.grad(y, x, v)
...
>>> jacobian = functorch.vmap(get_vjp)(I_N)
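The same vectorized Jacobian can also be written with functorch's own vjp transform instead of torch.autograd.grad. This is a sketch assuming `functorch.vjp(func, *primals)`, which returns the output together with a function mapping a cotangent v to vᵀJ; mapping that function over the rows of the identity produces the Jacobian row by row.

```python
import torch
import functorch

N = 5
f = lambda x: x ** 2
x = torch.randn(N)
I_N = torch.eye(N)

# vjp returns f(x) and a function computing v^T J for a cotangent v
_, vjp_fn = functorch.vjp(f, x)

# Mapping vjp_fn over the rows e_i of the identity gives the rows of J.
# vjp_fn returns a tuple (one gradient per primal), so unpack the single entry.
(jacobian,) = functorch.vmap(vjp_fn)(I_N)
print(jacobian.shape)  # torch.Size([5, 5])
```

For f(x) = x ** 2 applied elementwise, the Jacobian is the diagonal matrix with entries 2 * x.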
vmap() can also be nested, producing an output with multiple batched dimensions:

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = functorch.vmap(functorch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y)  # tensor of size [2, 3]
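Nesting becomes more useful when the inner and outer maps run over different inputs. The sketch below pairs every row of `x` with every row of `y` by giving each level its own in_dims, with None marking the input that level does not map over. The names `inner`, `pairwise`, and `G` are illustrative only.

```python
import torch
import functorch

x = torch.randn(3, 5)  # 3 vectors of size 5
y = torch.randn(4, 5)  # 4 vectors of size 5

# Inner vmap: one row of x against every row of y -> [4]
inner = functorch.vmap(torch.dot, in_dims=(None, 0))
# Outer vmap: every row of x against all of y -> [3, 4]
pairwise = functorch.vmap(inner, in_dims=(0, None))

G = pairwise(x, y)
print(G.shape)  # torch.Size([3, 4])
```

The result matches the matrix product `x @ y.T`, which is a convenient way to check the nesting order.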
If the inputs are not batched along the first dimension, in_dims specifies the dimension along which each input is batched:

>>> torch.dot  # [N], [N] -> []
>>> batched_dot = functorch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)  # output is [5] instead of [2] if batched along the 0th dimension
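Mapping over dimension 1 gives the same result as moving that dimension to the front yourself and using the default in_dims=0. In a tuple, entries may also differ per input. The sketch below, with illustrative variable names, checks both forms against each other.

```python
import torch
import functorch

x, y = torch.randn(2, 5), torch.randn(2, 5)

# Map over dim 1 of both inputs: 5 dot products of length-2 vectors
along_dim1 = functorch.vmap(torch.dot, in_dims=1)(x, y)

# Equivalent: move the batch dimension to the front, then use the default in_dims=0
moved_front = functorch.vmap(torch.dot)(x.T, y.T)

# Entries of in_dims may also differ per input: x batched on dim 1, y.T batched on dim 0
mixed = functorch.vmap(torch.dot, in_dims=(1, 0))(x, y.T)
print(along_dim1.shape)  # torch.Size([5])
```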
If there are multiple inputs that are batched along different dimensions (or not batched at all), in_dims must be a tuple with one entry per input, giving that input's batch dimension or None:

>>> torch.dot  # [D], [D] -> []
>>> batched_dot = functorch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y)  # second arg doesn't have a batch dim because in_dims[1] was None
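A None entry in in_dims combines naturally with a tuple out_dims when func returns several outputs, one out_dims entry per output. In this sketch, `project_and_norm` is a hypothetical per-example function that projects one example through shared weights `w` and also returns the projection's norm.

```python
import torch
import functorch

# Hypothetical per-example function with two outputs
def project_and_norm(x, w):
    # x: [D] (one example), w: [D, K] (shared by every example)
    h = x @ w           # [K]
    return h, h.norm()  # ([K], [])

N, D, K = 6, 4, 3
xs, w = torch.randn(N, D), torch.randn(D, K)

# in_dims=(0, None): map over dim 0 of xs, pass w unchanged to every call.
# out_dims=(1, 0): one entry per output; the batch dim goes to index 1 of h
# and to index 0 of the norms.
h, norms = functorch.vmap(project_and_norm, in_dims=(0, None), out_dims=(1, 0))(xs, w)
print(h.shape)      # torch.Size([3, 6])
print(norms.shape)  # torch.Size([6])
```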
If an input is a Python structure (such as a dict), in_dims must be a tuple containing a structure matching the shape of that input:

>>> f = lambda d: torch.dot(d['x'], d['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = functorch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)
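The same rule applies to tuples and other nested containers. The sketch below, with a hypothetical function `f` that takes one tuple argument, spells out the per-leaf in_dims explicitly. It also relies on an int in_dims being broadcast to every tensor inside the structure, which is how functorch treats the default in_dims=0.

```python
import torch
import functorch

# Hypothetical function taking a single tuple argument (a, b) of 1-D tensors
def f(pair):
    a, b = pair
    return torch.dot(a, b)

x, y = torch.randn(2, 5), torch.randn(2, 5)

# One in_dims entry per argument; that entry mirrors the tuple's structure
explicit = functorch.vmap(f, in_dims=((0, 0),))((x, y))
# A single int is applied to every tensor inside the structure
broadcast = functorch.vmap(f, in_dims=0)((x, y))
print(explicit.shape)  # torch.Size([2])
```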
By default, the output is batched along the first dimension. However, it can be batched along any dimension by using out_dims:

>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = functorch.vmap(f, out_dims=1)
>>> batched_pow(x)  # [5, 2]
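When each per-example output has more than one dimension, out_dims=k places the mapped dimension exactly where `torch.stack(..., dim=k)` would. This sketch checks that with an outer product, which maps a [5] vector to a [5, 5] matrix.

```python
import torch
import functorch

x = torch.randn(2, 5)
outer = lambda v: torch.outer(v, v)  # [5] -> [5, 5]

# out_dims=1 inserts the mapped dimension at index 1 of the output
out = functorch.vmap(outer, out_dims=1)(x)
print(out.shape)  # torch.Size([5, 2, 5])

# Loop reference: stack per-example results along the same dimension
reference = torch.stack([outer(v) for v in x], dim=1)
```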
For any function that uses kwargs, the returned function accepts kwargs but does not batch them:

>>> x = torch.randn([2, 5])
>>> def f(x, scale=4.):
...     return x * scale
...
>>> batched_scale = functorch.vmap(f)
>>> assert torch.allclose(batched_scale(x), x * 4)
>>> batched_scale(x, scale=x)  # scale is not batched, output has shape [2, 2, 5]
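Because keyword arguments are passed through unbatched, a per-example value of `scale` has to be supplied positionally, with its own entry in in_dims. The sketch below shows that workaround; `scales` is an illustrative tensor holding one scale per example.

```python
import torch
import functorch

def f(x, scale=4.):
    return x * scale

x = torch.randn(2, 5)
scales = torch.tensor([1., 10.])  # one scale per example

# Keyword arguments are not mapped over, so to batch `scale`,
# pass it positionally and give it its own entry in in_dims.
out = functorch.vmap(f, in_dims=(0, 0))(x, scales)
print(out.shape)  # torch.Size([2, 5])
```

Each example receives a 0-dimensional slice of `scales`, so row i of the output is `x[i] * scales[i]`.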
Note
vmap does not provide general autobatching or handle variable-length sequences out of the box.