
CudaGraphModule

class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None)

A cudagraph wrapper for PyTorch callables.

CudaGraphModule is a wrapper class that provides a user-friendly interface to CUDA graphs for PyTorch callables.

Warning

CudaGraphModule is a prototype feature and its API restrictions are likely to change in the future.

This class provides a user-friendly interface to cudagraphs, allowing for fast execution of operations on GPU that is free of CPU overhead. It runs essential checks on the inputs to the function and gives an nn.Module-like API to run it.

Warning

This module requires the wrapped function to meet a few requirements. It is the user's responsibility to make sure that all of them are fulfilled.

  • The function cannot have dynamic control flow. For instance, the following code snippet will fail to be wrapped in CudaGraphModule:

    >>> def func(x):
    ...     if x.norm() > 1:
    ...         return x + 1
    ...     else:
    ...         return x - 1
    

    Fortunately, PyTorch offers solutions in most cases:

    >>> def func(x):
    ...     return torch.where(x.norm() > 1, x + 1, x - 1)
    
  • The function must execute code that can be re-run exactly using the same buffers. This means that dynamic shapes (shapes that change in the input or during code execution) are not supported. In other words, the input must have a constant shape.

  • The output of the function must be detached. If a call to the optimizer is required, put it inside the wrapped function. For instance, the following function is a valid operator:

    >>> def func(x, y):
    ...     optim.zero_grad()
    ...     loss_val = loss_fn(x, y)
    ...     loss_val.backward()
    ...     optim.step()
    ...     return loss_val.detach()
    
  • The input should not be differentiable. If you need to use nn.Parameter instances (or differentiable tensors in general), write a function that uses them as global variables rather than passing them as inputs:

    >>> x = nn.Parameter(torch.randn(()))
    >>> optim = Adam([x], lr=1)
    >>> def func(): # right
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    >>> def func(x): # wrong
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    
  • Args and kwargs that are tensors or tensordicts may change (provided that their device and shape match), but non-tensor args and kwargs should not change. For instance, if the function receives a string input and that input is changed at any point, the module will silently execute the code with the string used during the capture of the cudagraph. The only supported keyword argument is tensordict_out, and only when the input is a tensordict.

  • If the module is a TensorDictModuleBase instance and the output id matches the input id, this identity will be preserved during a call to CudaGraphModule. In all other cases, the output will be cloned, whether or not its elements match one of the inputs.
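The branch-free rewrite from the first requirement can be checked on CPU before anything is wrapped. The sketch below uses only plain PyTorch (the function names branchy and branch_free are illustrative) and confirms that the torch.where version returns the same result as the branching version on both sides of the threshold:

```python
import torch


def branchy(x: torch.Tensor) -> torch.Tensor:
    # Python-level branch: on a CUDA tensor, bool(x.norm() > 1) forces a
    # device-to-host sync, which cannot be recorded in a CUDA graph.
    if x.norm() > 1:
        return x + 1
    else:
        return x - 1


def branch_free(x: torch.Tensor) -> torch.Tensor:
    # Both candidates are computed and torch.where selects one on-device,
    # so the same sequence of kernels runs on every call.
    return torch.where(x.norm() > 1, x + 1, x - 1)


small = torch.full((3,), 0.1)  # norm < 1, takes the "x - 1" path
large = torch.full((3,), 2.0)  # norm > 1, takes the "x + 1" path
for x in (small, large):
    assert torch.equal(branchy(x), branch_free(x))
```

Note that torch.where evaluates both x + 1 and x - 1 on every call; that extra work is the price of a fixed kernel sequence.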

Warning

CudaGraphModule is not an nn.Module by design, to discourage gathering the parameters of the wrapped module and passing them to an optimizer.

Parameters:
  • module (Callable) – a function that receives tensors (or tensordict) as input and outputs a PyTreeable collection of tensors. If a tensordict is provided, the module can be run with keyword arguments too (see example below).

  • warmup (int, optional) – the number of warmup steps in case the module is compiled (compiled modules should be run a couple of times before being captured by cudagraphs). Defaults to 2 for all modules.

  • in_keys (list of NestedKeys) –

    the input keys, if the module takes a TensorDict as input. Defaults to module.in_keys if this value exists, otherwise None.

    Note

    If in_keys is provided but empty, the module is assumed to receive a tensordict as input. This is sufficient to make CudaGraphModule aware that the function should be treated as a TensorDictModule, but keyword arguments will not be dispatched. See below for some examples.

  • out_keys (list of NestedKeys) – the output keys, if the module takes and returns a TensorDict. Defaults to module.out_keys if this value exists, otherwise None.
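For intuition about what warmup and capture involve, the following sketch uses PyTorch's public CUDA Graphs API (torch.cuda.CUDAGraph and the torch.cuda.graph context manager) directly. It follows the side-stream warmup pattern from the PyTorch CUDA Graphs documentation and is not a description of how CudaGraphModule is implemented internally; the capture helper is hypothetical. It illustrates why inputs must keep a constant shape and device (new data is copied into the buffer that was captured) and why outputs are copied (every replay overwrites the same output buffer).

```python
import torch


def capture(fn, example_input: torch.Tensor, warmup: int = 2):
    """Hypothetical helper: warm up fn on a side stream, capture one call
    in a CUDA graph, and return a callable that replays it."""
    # Static input buffer: the graph will always read from this memory.
    static_input = example_input.clone()

    # Warmup runs on a side stream so lazy initialization (allocator,
    # library handles, compilation) happens outside the capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup):
            fn(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # Record the kernels launched by one call to fn.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)

    def replay(x: torch.Tensor) -> torch.Tensor:
        # New data must fit the captured buffer exactly.
        if x.shape != static_input.shape or x.device != static_input.device:
            raise ValueError("input must match the captured shape and device")
        static_input.copy_(x)
        graph.replay()
        # The next replay overwrites static_output, so hand back a copy.
        return static_output.clone()

    return replay


if torch.cuda.is_available():
    step = capture(lambda t: t * 2 + 1, torch.zeros(4, device="cuda"))
    out = step(torch.arange(4.0, device="cuda"))  # same values as eager t * 2 + 1
```

Anything that is not a tensor input, such as a Python scalar closed over by fn, is fixed at capture time, which is the same reason non-tensor arguments must not change between calls to CudaGraphModule.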

Examples

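>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import CudaGraphModule, TensorDictModule
>>> # Wrap a simple function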
>>> def func(x):
...     return x + 1
>>> func = CudaGraphModule(func)
>>> x = torch.rand((), device='cuda')
>>> out = func(x)
>>> assert isinstance(out, torch.Tensor)
>>> assert out == x+1
>>> # Wrap a tensordict module
>>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"])
>>> func = CudaGraphModule(func)
>>> # This can be called either with a TensorDict or regular keyword arguments alike
>>> y = func(x=x)
>>> td = TensorDict(x=x)
>>> td = func(td)
