
Exporting tensordict modules

Author: Vincent Moens

Prerequisites

Reading the TensorDictModule tutorial beforehand is recommended to fully benefit from this one.

Once a module has been written using tensordict.nn, it is often useful to isolate its computational graph and export it. The goal may be to execute the model on hardware (e.g., robots, drones, edge devices) or to eliminate the dependency on tensordict altogether.

PyTorch provides multiple methods for exporting modules, including onnx and torch.export, both of which are compatible with tensordict.

In this short tutorial, we will see how one can use torch.export to isolate the computational graph of a model. torch.onnx support follows the same logic.

Key learnings

  • Executing a tensordict.nn module without TensorDict inputs;

  • Selecting the output(s) of a model;

  • Handling stochastic models;

  • Exporting such a model using torch.export;

  • Saving the model to a file;

  • Isolating the PyTorch model.

import time

import torch
from tensordict.nn import (
    InteractionType,
    NormalParamExtractor,
    ProbabilisticTensorDictModule as Prob,
    set_interaction_type,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)
from torch import distributions as dists, nn

Designing the model

In many applications, it is useful to work with stochastic models, i.e., models whose output is not deterministically defined but is sampled according to a parametric distribution. For instance, generative AI models will often produce different outputs when the same input is provided, because they sample the output from a distribution whose parameters are defined by the input.

The tensordict library deals with this through the ProbabilisticTensorDictModule class. This primitive is built using a distribution class (Normal in our case) and an indicator of the input keys that will be used at execution time to build that distribution.

The network we are building is therefore going to be the combination of three main components:

  • A network mapping the input to a latent parameter;

  • A tensordict.nn.NormalParamExtractor module splitting the input into a location ("loc") and a scale ("scale") parameter to be passed to the Normal distribution;

  • A distribution constructor module.

model = Seq(
    # 1. A small network for embedding
    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
    # 2. Extracting params
    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    # 3. Probabilistic module
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=dists.Normal,
    ),
)

Let us run this model and see what the output looks like:

x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.0000, 0.2604, 0.0000, 0.0000]], grad_fn=<ReluBackward0>), tensor([[-0.1580, -0.5222, -0.3319,  0.5519]], grad_fn=<AddmmBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>), tensor([[0.8046, 1.3804]], grad_fn=<ClampMinBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>))

As expected, running the model with a tensor input returns as many tensors as the module has output keys! For large models, this can be quite annoying and wasteful. Later, we will see how we can limit the number of outputs of the model to deal with this issue.
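
The same model can, of course, also be called with a TensorDict input, in which case every intermediate value is written into the output tensordict. A minimal sketch, assuming the standard TensorDict API:

>>> from tensordict import TensorDict
>>> td = TensorDict({"x": torch.randn(1, 3)}, batch_size=[1])
>>> model(td)  # the result contains "hidden", "latent", "loc", "scale" and "sample"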

Using torch.export with a TensorDictModule

Now that we have successfully built our model, we would like to extract its computational graph in a single object that is independent of tensordict. torch.export is a PyTorch module dedicated to isolating the graph of a module and representing it in a standardized way. Its main entry point is export(), which returns an ExportedProgram object. In turn, this object has several attributes of interest that we will explore below: a graph_module, which represents the FX graph captured by export; a graph_signature describing the inputs, outputs, etc. of the graph; and finally a module() method that returns a callable that can be used in place of the original module.

Although our module accepts both args and kwargs, we will focus on its usage with kwargs as this is clearer.

from torch.export import export

model_export = export(model, args=(), kwargs={"x": x})

Let us look at the module:

print("module:", model_export.module())
module: GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0]
    getitem_3 = broadcast_tensors[1];  broadcast_tensors = None
    return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

This module can be run exactly like our original module (with a lower overhead):

t0 = time.time()
model(x=x)
print(f"Time for TDModule: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
exported = model_export.module()

# Exported version
t0 = time.time()
exported(x=x)
print(f"Time for exported module: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
Time for TDModule:  469.45 micro-seconds
Time for exported module:  340.70 micro-seconds
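
Note that timing a single call is noisy. For a more reliable comparison, one could average over many runs, for instance with torch.utils.benchmark (an optional sketch, not part of the measurements above):

>>> from torch.utils.benchmark import Timer
>>> Timer(stmt="model(x=x)", globals={"model": model, "x": x}).timeit(1000)
>>> Timer(stmt="exported(x=x)", globals={"exported": exported, "x": x}).timeit(1000)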

and the FX graph:

print("fx graph:", model_export.graph_module.print_readable())
fx graph: class GraphModule(torch.nn.Module):
    def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
         # File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
        linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias);  x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
        relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None
        linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias);  p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
        split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
        getitem: "f32[1, 2]" = split[0]
        getitem_1: "f32[1, 2]" = split[1];  split = None

         # File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
        add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
        add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
        clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None

         # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
        broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
        getitem_2: "f32[1, 2]" = broadcast_tensors[0]
        getitem_3: "f32[1, 2]" = broadcast_tensors[1];  broadcast_tensors = None
        return (relu, linear_1, getitem_2, getitem_3, getitem_2)

Working with nested keys

Nested keys are a core feature of the tensordict library, and being able to export modules that read and write nested entries is therefore an important feature to support. Because keyword arguments must be regular strings, it is not possible for dispatch to work directly with them. Instead, dispatch will unpack nested keys joined with a regular underscore ("_"), as the following example shows.

model_nested = Seq(
    Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
    Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))

model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()



def forward(self, some_key):
    some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
    add = torch.ops.aten.add.Tensor(some_key, 1);  some_key = None
    sub = torch.ops.aten.sub.Tensor(add, 1);  add = None
    return pytree.tree_unflatten((sub,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
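
The exported callable is invoked with the flattened ("some_key") name, mirroring the forward signature shown above. A short usage sketch:

>>> model_nested_export.module()(some_key=x)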

Note that the callable returned by module() is a pure Python callable that can, in turn, be compiled using compile().
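
For instance, one could pass it to torch.compile. This is only a sketch; whether compilation brings a speed-up depends on the model and the backend:

>>> compiled = torch.compile(model_export.module())
>>> compiled(x=x)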

Saving the exported module

torch.export has its own serialization protocol, save() and load(). Conventionally, the ".pt2" extension is used:

>>> torch.export.save(model_export, "model.pt2")
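
The saved program can then be loaded back with load() and used as before, for example:

>>> loaded = torch.export.load("model.pt2")
>>> loaded.module()(x=x)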

Selecting the outputs

Recall that, by default, tensordict.nn keeps every intermediate value in the output, unless the user specifically asks for a subset of values. During training, this can be very useful: one can easily log intermediate values of the graph, or use them for other purposes (e.g., reconstruct a distribution based on its saved parameters, rather than saving the Distribution object itself). One could also argue that, during training, the memory impact of registering intermediate values is negligible since they are part of the computational graph used by torch.autograd to compute the parameter gradients.

During inference, though, we are most likely only interested in the final sample of the model. Because we want to extract the model for usages that are independent of the tensordict library, it makes sense to isolate the single output we desire. To do this, we have several options:

  1. Build the TensorDictSequential() with the selected_out_keys keyword argument, which will induce the selection of the desired entries during calls to the module;

  2. Use the select_out_keys() method, which will modify the out_keys attribute in-place (this can be reverted through reset_out_keys());

  3. Wrap the existing instance in a TensorDictSequential() that will filter out the unwanted keys:

    >>> module_filtered = Seq(module, selected_out_keys=["sample"])
    

Let us test the model after selecting its output keys. When an x input is provided, we expect our model to output a single tensor corresponding to a sample of the distribution:

model.select_out_keys("sample")
print(model(x=x))
tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>)

We see that the output is now a single tensor, corresponding to the sample of the distribution. We can create a new exported graph from this. Its computational graph should be simplified:

model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    return pytree.tree_unflatten((getitem_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
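
If the full set of outputs is needed again later (e.g., to resume training), the selection can be reverted with reset_out_keys(), as mentioned above. A short sketch:

>>> model.reset_out_keys()           # restores all the original out_keys
>>> model.select_out_keys("sample")  # re-apply the selection if needed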

Controlling the Sampling Strategy

We have not yet discussed how the ProbabilisticTensorDictModule samples from the distribution. By sampling, we mean obtaining a value within the space defined by the distribution according to a specific strategy. For instance, one may desire to get stochastic samples during training but deterministic samples (e.g., the mean or the mode) at inference time. To address this, tensordict utilizes the set_interaction_type decorator and context manager, which accepts InteractionType Enum inputs:

>>> with set_interaction_type(InteractionType.MEAN):
...     output = module(input)  # takes the mean of the distribution, if a ProbabilisticTensorDictModule is invoked

The default InteractionType is InteractionType.DETERMINISTIC, which, if not implemented directly, is either the mean of distributions with a real domain, or the mode of distributions with a discrete domain. This default value can be changed using the default_interaction_type keyword argument of ProbabilisticTensorDictModule.

Let us recap: to control the sampling strategy of our network, we can either define a default sampling strategy in the constructor, or override it at runtime through the set_interaction_type context manager.
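
For completeness, here is a sketch of how a default strategy could be set in the constructor of the probabilistic module built above, using the default_interaction_type keyword argument:

>>> prob_module = Prob(
...     in_keys=["loc", "scale"],
...     out_keys=["sample"],
...     distribution_class=dists.Normal,
...     default_interaction_type=InteractionType.MEAN,
... )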

As we can see from the following example, torch.export responds correctly to the usage of the context manager: if we ask for a random sample, the exported graph is different from the one we obtain when asking for the mean:

with set_interaction_type(InteractionType.RANDOM):
    model_export = export(model, args=(), kwargs={"x": x})
    print(model_export.module())

with set_interaction_type(InteractionType.MEAN):
    model_export = export(model, args=(), kwargs={"x": x})
    print(model_export.module())
GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0]
    getitem_3 = broadcast_tensors[1];  broadcast_tensors = None
    empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    normal_functional = torch.ops.aten.normal_functional.default(empty);  empty = None
    mul = torch.ops.aten.mul.Tensor(normal_functional, getitem_3);  normal_functional = getitem_3 = None
    add_2 = torch.ops.aten.add.Tensor(getitem_2, mul);  getitem_2 = mul = None
    return pytree.tree_unflatten((add_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    return pytree.tree_unflatten((getitem_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

This is all you need to know to use torch.export. Please refer to the official documentation for more info.

Next steps and further reading

  • Check the torch.export tutorial, available here;

  • ONNX support: check the ONNX tutorials to learn more about this feature. Exporting to ONNX is very similar to torch.export explained here.

  • For deployment of PyTorch code on servers without a Python environment, check the AOTInductor documentation.
