- static Function.vmap(info, in_dims, *args)
Defines a rule for the behavior of this autograd.Function underneath torch.vmap().
For a torch.autograd.Function to support torch.vmap(), you must either override
this staticmethod or set generate_vmap_rule to True (you may not do both).
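Here is a minimal sketch of the generate_vmap_rule option. It assumes PyTorch 2.0 or later and the separate forward/setup_context style that torch.func transforms expect. The Function below is named Sin for illustration only; it is not part of PyTorch. It sets the flag and does not define vmap, so PyTorch generates the batching rule itself. That approach is meant for Functions whose staticmethods use only PyTorch operations.

```python
import torch

# Illustrative Function (not part of PyTorch): elementwise sine.
class Sin(torch.autograd.Function):
    # Ask PyTorch to generate the vmap rule instead of overriding vmap.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx sin(x) = cos(x)
        return grad_output * torch.cos(x)


def my_sin(x):
    return Sin.apply(x)


x = torch.randn(4, 3)
# Vmap over dim 0: each per-sample input has shape (3,).
y = torch.vmap(my_sin)(x)
```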
If you choose to override this staticmethod, it must accept:
  - an info object as the first argument. info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option passed to torch.vmap().
  - an in_dims tuple as the second argument. For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
  - *args, which is the same as the args to forward().
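A hand-written rule follows the argument list above. This sketch makes the same assumptions as before (PyTorch 2.0 or later, forward/setup_context style). SumDim is an illustrative name, not a PyTorch class. It sums a Tensor over a non-Tensor dim argument, so in_dims holds one Optional[int] for x and None for dim. The rule moves the vmapped dimension to the front, shifts dim by one to skip it, and reports the output's vmapped dimension at index 0.

```python
import torch

# Illustrative Function (not part of PyTorch): sums x over dimension `dim`.
class SumDim(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        return x.sum(dim)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        ctx.dim = dim
        ctx.x_shape = x.shape

    @staticmethod
    def backward(ctx, grad_output):
        # Broadcast the gradient back along the reduced dimension;
        # dim is a Python int, so it receives no gradient.
        grad_x = grad_output.unsqueeze(ctx.dim).expand(ctx.x_shape)
        return grad_x, None

    @staticmethod
    def vmap(info, in_dims, x, dim):
        # One Optional[int] per forward arg; dim is not a Tensor, so its entry is None.
        x_bdim, _ = in_dims
        if x_bdim is None:
            # x is not being vmapped over: compute as usual; the output is unbatched.
            return SumDim.apply(x, dim), None
        # Move the vmapped dimension to the front of x.
        x = x.movedim(x_bdim, 0)
        # Normalize a negative per-sample dim (x.dim() now includes the batch dim).
        if dim < 0:
            dim += x.dim() - 1
        # Shift by one to skip the leading batch dim; the output keeps it at index 0.
        return SumDim.apply(x, dim + 1), 0


x = torch.randn(5, 3, 4)
# Each per-sample slice has shape (3, 4); summing over its dim 1 gives shape (3,).
y = torch.vmap(lambda t: SumDim.apply(t, 1))(x)
```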
The vmap staticmethod returns a tuple (output, out_dims). Similar to in_dims,
out_dims should have the same structure as output and contain one out_dim per
output, specifying whether that output has the vmapped dimension and, if so, at
which index.
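When a Function has several outputs, out_dims is a tuple with one entry per output. This sketch makes the same assumptions as the earlier ones; SinCos is an illustrative name, not a PyTorch class. Its forward returns two Tensors computed elementwise. Because of that, the vmap rule can leave the vmapped dimension where it is and report x_bdim for both outputs. The assertion inside the rule only illustrates that info.batch_size equals the size of that dimension.

```python
import torch

# Illustrative Function (not part of PyTorch) with two outputs.
class SinCos(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.sin(x), torch.cos(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_sin, grad_cos):
        (x,) = ctx.saved_tensors
        # d/dx sin(x) = cos(x) and d/dx cos(x) = -sin(x).
        return grad_sin * torch.cos(x) - grad_cos * torch.sin(x)

    @staticmethod
    def vmap(info, in_dims, x):
        (x_bdim,) = in_dims
        # info.batch_size is the size of the dimension being vmapped over.
        assert x.shape[x_bdim] == info.batch_size
        # sin and cos are elementwise, so the vmapped dimension stays at x_bdim.
        sin_x, cos_x = SinCos.apply(x)
        # out_dims mirrors the structure of output: one entry per output.
        return (sin_x, cos_x), (x_bdim, x_bdim)


x = torch.randn(3, 5)
# Vmap over dim 1 of x; torch.vmap stacks the results along dim 0 by default.
s, c = torch.vmap(lambda t: SinCos.apply(t), in_dims=1)(x)
```

Returning x_bdim avoids a movedim call. torch.vmap then moves each output's vmapped dimension to its own out_dims, which here is the default 0.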
Please see Extending torch.func with autograd.Function for more details.