CudaGraphModule
- class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None)
A cudagraph wrapper for PyTorch callables.

Warning

CudaGraphModule is a prototype feature and its API restrictions are likely to change in the future.

This class provides a user-friendly interface to cudagraphs, allowing for fast, CPU-overhead-free execution of operations on GPU. It runs essential checks on the inputs to the function, and gives an nn.Module-like API to run it.

Warning

This module requires the wrapped function to meet a few requirements. It is the user's responsibility to make sure that all of these are fulfilled:
- The function cannot have dynamic control flow. For instance, the following code snippet will fail to be wrapped in CudaGraphModule:

  >>> def func(x):
  ...     if x.norm() > 1:
  ...         return x + 1
  ...     else:
  ...         return x - 1
  Fortunately, PyTorch offers solutions in most cases:

  >>> def func(x):
  ...     return torch.where(x.norm() > 1, x + 1, x - 1)
- The function must execute code that can be exactly re-run using the same buffers. This means that dynamic shapes (a shape that changes in the input or during the code execution) are not supported. In other words, the input must have a constant shape; a common workaround is sketched below.
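  For illustration only: one way to satisfy this constraint is to pad variable-length inputs to a fixed shape before the captured function sees them. The pad_to_max helper and MAX_LEN constant below are assumptions for this sketch, not part of the API:

  >>> import torch
  >>> import torch.nn.functional as F
  >>> MAX_LEN = 128  # fixed size chosen ahead of capture (assumption for this sketch)
  >>> def pad_to_max(x):
  ...     # hypothetical helper: right-pad the last dimension with zeros up to MAX_LEN
  ...     return F.pad(x, (0, MAX_LEN - x.shape[-1]))
  >>> # the captured function then always receives tensors of shape (..., MAX_LEN)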
- The output of the function must be detached. If a call to the optimizer is required, put it inside the wrapped function. For instance, the following function is a valid operator:
  >>> def func(x, y):
  ...     optim.zero_grad()
  ...     loss_val = loss_fn(x, y)
  ...     loss_val.backward()
  ...     optim.step()
  ...     return loss_val.detach()
- The input should not be differentiable. If you need to use nn.Parameter instances (or differentiable tensors in general), just write a function that uses them as global values rather than passing them as input:
  >>> x = nn.Parameter(torch.randn(()))
  >>> optim = Adam([x], lr=1)
  >>> def func():  # right
  ...     optim.zero_grad()
  ...     (x + 1).backward()
  ...     optim.step()
  >>> def func(x):  # wrong
  ...     optim.zero_grad()
  ...     (x + 1).backward()
  ...     optim.step()
- Args and kwargs that are tensors or tensordicts may change from call to call (provided that their device and shape match), but non-tensor args and kwargs should not change. For instance, if the function receives a string input and that input is changed at any point, the module will silently execute the code with the string used during the capture of the cudagraph, as sketched below. The only supported keyword argument is tensordict_out, in case the input is a tensordict.
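  The following sketch illustrates this rule with a hypothetical function (not from the library) that takes a string flag alongside a tensor:

  >>> def func(x, mode):
  ...     # "mode" is a non-tensor argument: its value is frozen at capture time
  ...     return x * 2 if mode == "train" else x
  >>> gfunc = CudaGraphModule(func)
  >>> y0 = gfunc(torch.rand(3, device='cuda'), "train")  # warmup/capture with mode="train"
  >>> y1 = gfunc(torch.rand(3, device='cuda'), "train")  # fine: new tensor, same shape and device
  >>> y2 = gfunc(torch.rand(3, device='cuda'), "eval")   # "eval" is silently ignored after capture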
- If the module is a TensorDictModuleBase instance and the output id matches the input id, then this identity will be preserved during a call to CudaGraphModule. In all other cases, the output will be cloned, irrespective of whether its elements match or do not match one of the inputs.
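  As a minimal sketch of this behaviour: a TensorDictModule that writes its outputs into the input tensordict returns that same tensordict object:

  >>> mod = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
  >>> gmod = CudaGraphModule(mod)
  >>> td = TensorDict(x=torch.rand((), device='cuda'))
  >>> td_out = gmod(td)
  >>> assert td_out is td  # output id matches input id, so identity is preserved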
Warning

CudaGraphModule is not an nn.Module by design, to discourage gathering the parameters of the input module and passing them to an optimizer.

- Parameters:
  - module (Callable) – a function that receives tensors (or a tensordict) as input and outputs a PyTree-compatible collection of tensors. If a tensordict is provided, the module can be run with keyword arguments too (see example below).
  - warmup (int, optional) – the number of warmup steps in case the module is compiled (compiled modules should be run a couple of times before being captured by cudagraphs). Defaults to 2 for all modules. See the sketch after this list.
  - in_keys (list of NestedKeys, optional) – the input keys, if the module takes a TensorDict as input. Defaults to module.in_keys if this value exists, otherwise None.

    Note

    If in_keys is provided but empty, the module is assumed to receive a tensordict as input. This is sufficient to make CudaGraphModule aware that the function should be treated as a TensorDictModule, but keyword arguments will not be dispatched. See below for some examples.

  - out_keys (list of NestedKeys, optional) – the output keys, if the module takes a TensorDict as input and returns one as output. Defaults to module.out_keys if this value exists, otherwise None.
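For instance (a minimal sketch; the warmup value of 10 is an arbitrary assumption), a compiled callable can be given extra warmup steps like this:

>>> func = torch.compile(lambda x: x + 1)
>>> gfunc = CudaGraphModule(func, warmup=10)  # extra warmup runs before cudagraph capture
>>> y = gfunc(torch.rand((), device='cuda'))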
Examples
>>> # Wrap a simple function
>>> def func(x):
...     return x + 1
>>> func = CudaGraphModule(func)
>>> x = torch.rand((), device='cuda')
>>> out = func(x)
>>> assert isinstance(out, torch.Tensor)
>>> assert out == x + 1
>>> # Wrap a tensordict module
>>> func = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
>>> func = CudaGraphModule(func)
>>> # This can be called either with a TensorDict or regular keyword arguments alike
>>> y = func(x=x)
>>> td = TensorDict(x=x)
>>> td = func(td)
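As a further sketch of the empty in_keys behaviour mentioned in the note above (assuming a function that reads from and writes to a tensordict directly):

>>> def func(td):
...     td["y"] = td["x"] + 1
...     return td
>>> gfunc = CudaGraphModule(func, in_keys=[], out_keys=["y"])
>>> td = TensorDict(x=torch.rand((), device='cuda'))
>>> td = gfunc(td)  # the tensordict is passed as-is; keyword arguments are not dispatched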