torch.cuda.make_graphed_callables

torch.cuda.make_graphed_callables(callables, sample_args)

Accepts callables (functions or nn.Modules) and returns graphed versions.

Each graphed callable’s forward pass runs its source callable’s forward CUDA work as a CUDA graph inside a single autograd node.

The graphed callable’s forward pass also appends a backward node to the autograd graph. During backward, this node runs the callable’s backward work as a CUDA graph.

Therefore, each graphed callable should be a drop-in replacement for its source callable in an autograd-enabled training loop.
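As a minimal sketch of the drop-in behavior described above, the following graphs a small Linear layer and uses the graphed version in an ordinary training loop. The helper name train_graphed, the layer sizes, and the random data are illustrative assumptions, not part of the API. The function returns None when no CUDA device is available.

```python
import torch


def train_graphed(steps=3):
    """Train a small layer through its graphed version.

    Hypothetical helper for illustration. Returns a list of loss values,
    or None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None

    model = torch.nn.Linear(16, 4).cuda()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Sample input used for capture. It does not require grad, matching
    # the real inputs fed to the model in the loop below.
    sample_x = torch.randn(8, 16, device="cuda")
    graphed_model = torch.cuda.make_graphed_callables(model, (sample_x,))

    losses = []
    for _ in range(steps):
        # Real inputs have the same shape and requires_grad state as sample_x.
        x = torch.randn(8, 16, device="cuda")
        y = torch.randn(8, 4, device="cuda")
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(graphed_model(x), y)  # forward replays a CUDA graph
        loss.backward()                      # backward replays a CUDA graph
        optimizer.step()
        losses.append(loss.item())
    return losses
```

The loss function and optimizer step run eagerly here; only the Linear layer's forward and backward work is graphed.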

See Partial-network capture for detailed use and constraints.

If you pass a tuple of several callables, their captures will use the same memory pool. See Graph memory management for when this is appropriate.

Parameters
  • callables (torch.nn.Module or Python function, or tuple of these) – Callable or callables to graph. See Graph memory management for when passing a tuple of callables is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order they’ll run in the live workload.

  • sample_args (tuple of Tensors, or tuple of tuples of Tensors) – Sample args for each callable. If a single callable was passed, sample_args must be a single tuple of argument Tensors. If a tuple of callables was passed, sample_args must be a tuple of tuples of argument Tensors, one inner tuple per callable.

Note

The requires_grad state of each Tensor in sample_args must match the state that’s expected for the corresponding real input in the training loop.
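The tuple form and the requires_grad note above can be sketched as follows. Two layers are passed in the order they run, and sample_args holds one inner tuple per callable. The helper name graph_two_layers and the sizes are assumptions made for illustration. The second sample input is created with requires_grad=True because, in the live workload, the second layer receives an activation that requires grad.

```python
import torch


def graph_two_layers():
    """Graph two layers together (hypothetical helper for illustration).

    Returns the output shape as a tuple, or None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None

    layer1 = torch.nn.Linear(16, 8).cuda()
    layer2 = torch.nn.Linear(8, 4).cuda()

    # Sample input for layer1: plain data, so requires_grad stays False.
    x = torch.randn(4, 16, device="cuda")
    # Sample input for layer2: stands in for layer1's output, which
    # requires grad during training.
    h = torch.randn(4, 8, device="cuda", requires_grad=True)

    # Tuple of callables in run order, with a tuple of tuples of sample args.
    graphed1, graphed2 = torch.cuda.make_graphed_callables(
        (layer1, layer2), ((x,), (h,))
    )

    real_x = torch.randn(4, 16, device="cuda")
    out = graphed2(graphed1(real_x))
    out.sum().backward()
    return tuple(out.shape)
```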

Warning

This API is in beta and may change in future releases.

Warning

sample_args for each callable must be a tuple of Tensors. Other types and keyword args are not allowed.

Warning

Returned callables do not support higher order differentiation (e.g., double backward).

Warning

In any Module passed to make_graphed_callables(), only parameters may be trainable. Buffers must have requires_grad=False.

Warning

After you pass a torch.nn.Module through make_graphed_callables(), you may not add or remove any of that Module’s parameters or buffers.

Warning

torch.nn.Modules passed to make_graphed_callables() must not have module hooks registered on them at the time they are passed. However, registering hooks on modules after passing them through make_graphed_callables() is allowed.

Warning

When running a graphed callable, you must pass its arguments in the same order and format they appeared in that callable’s sample_args.
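A minimal sketch of the argument-order rule, using a plain Python function with two Tensor arguments. The names scaled_add and check_scaled_add are hypothetical. Both sample inputs require grad so that the output requires grad, and the real call passes its arguments positionally in the same order as in sample_args.

```python
import torch


def scaled_add(a, b):
    # Plain function with two positional Tensor arguments.
    return 2.0 * a + b


def check_scaled_add():
    """Graph scaled_add and compare against eager execution.

    Hypothetical helper for illustration. Returns True when the results
    match, or None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None

    sample_a = torch.randn(32, device="cuda", requires_grad=True)
    sample_b = torch.randn(32, device="cuda", requires_grad=True)
    graphed = torch.cuda.make_graphed_callables(scaled_add, (sample_a, sample_b))

    a = torch.randn(32, device="cuda", requires_grad=True)
    b = torch.randn(32, device="cuda", requires_grad=True)

    # Same order and format as sample_args: (a, b), positional, no kwargs.
    out = graphed(a, b)
    expected = scaled_add(a.detach(), b.detach())
    return bool(torch.allclose(out.detach(), expected))
```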

Warning

All Tensor outputs of graphed callables must require grad.
