
Tracing TensorDictModule

We support tracing the execution of a TensorDictModule to create FX graphs. To do so, import symbolic_trace from tensordict.prototype.fx instead of from torch.fx.

Note

Support for torch.fx is highly experimental and subject to change. Use with caution, and raise an issue if you try it out and encounter problems.
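The idea behind symbolic tracing can be shown with a toy, pure-Python tracer: the function is called with a proxy object instead of real data, and each operation applied to the proxy is recorded as a graph node rather than computed. The Tracer and Proxy classes below are hypothetical and far simpler than torch.fx; they only illustrate why a lookup such as tensordict["input"] ends up as a getitem node in a traced graph.

```python
class Tracer:
    """Records every operation applied to Proxy objects as a graph node."""

    def __init__(self):
        self.nodes = []  # list of (name, op, args) tuples, in execution order

    def record(self, op, args):
        name = f"{op}_{len(self.nodes)}"
        self.nodes.append((name, op, args))
        return Proxy(self, name)


class Proxy:
    """Placeholder value: operations on it are recorded, not computed."""

    def __init__(self, tracer, name):
        self.tracer = tracer
        self.name = name

    def __getitem__(self, key):
        # Indexing the proxy (e.g. td["input"]) becomes a getitem node.
        return self.tracer.record("getitem", (self.name, key))

    def __mul__(self, other):
        other = other.name if isinstance(other, Proxy) else other
        return self.tracer.record("mul", (self.name, other))


def trace(fn):
    """Call fn on a placeholder proxy and return the recorded nodes."""
    tracer = Tracer()
    placeholder = tracer.record("placeholder", ())
    result = fn(placeholder)
    return tracer.nodes, result.name


def forward(td):
    # A function written against a dict-like input (hypothetical example).
    return td["input"] * td["mask"]


nodes, output = trace(forward)
for node in nodes:
    print(node)
print("output:", output)
```

Real FX tracing additionally handles modules, keyword arguments and code generation; this sketch keeps only the recording step.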

Tracing a TensorDictModule

We’ll illustrate with an example from the overview. We create a TensorDictModule, trace it, and inspect the graph and generated code.

Tracing a TensorDictModule
>>> import torch
>>> import torch.nn as nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.prototype.fx import symbolic_trace

>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.LazyLinear(1)
...
...     def forward(self, x):
...         logits = self.linear(x)
...         return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
...     Net(),
...     in_keys=["input"],
...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> graph_module = symbolic_trace(module)
>>> print(graph_module.graph)
graph():
    %tensordict : [#users=1] = placeholder[target=tensordict]
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%tensordict, input), kwargs = {})
    %linear : [#users=2] = call_module[target=linear](args = (%getitem,), kwargs = {})
    %sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
    return (linear, sigmoid)
>>> print(graph_module.code)

def forward(self, tensordict):
    getitem = tensordict['input'];  tensordict = None
    linear = self.linear(getitem);  getitem = None
    sigmoid = torch.sigmoid(linear)
    return (linear, sigmoid)

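The generated forward returns a plain tuple (linear, sigmoid); the nested entries ("outputs", "logits") and ("outputs", "probabilities") only appear once those values are paired, in order, with out_keys. The following pure-Python sketch illustrates that pairing under stated assumptions: plain nested dicts stand in for TensorDict, floats stand in for tensors, and traced_forward, set_nested and write_outputs are hypothetical helpers, not tensordict's implementation.

```python
import math


def traced_forward(td):
    # Hypothetical stand-in for the generated forward above:
    # read one entry, apply a "linear" step and a sigmoid, return a tuple.
    x = td["input"]
    linear = 2.0 * x + 1.0  # placeholder for self.linear
    sigmoid = 1.0 / (1.0 + math.exp(-linear))
    return (linear, sigmoid)


def set_nested(d, key, value):
    """Write value at a string or tuple key, creating sub-dicts as needed."""
    if isinstance(key, str):
        key = (key,)
    for part in key[:-1]:
        d = d.setdefault(part, {})
    d[key[-1]] = value


def write_outputs(outputs, out_keys, td_out):
    """Pair each returned value with its out_key, in order."""
    if len(outputs) != len(out_keys):
        raise ValueError("number of outputs does not match out_keys")
    for key, value in zip(out_keys, outputs):
        set_nested(td_out, key, value)
    return td_out


out_keys = [("outputs", "logits"), ("outputs", "probabilities")]
result = write_outputs(traced_forward({"input": 0.0}), out_keys, {})
print(result["outputs"]["logits"])  # 1.0
```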
We can check that a forward pass with each module results in the same outputs.

>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
...     module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
>>> assert (
...     module_out["outputs", "probabilities"]
...     == graph_module_out["outputs", "probabilities"]
... ).all()

Tracing a TensorDictSequential

We can also trace TensorDictSequential. In this case the entire execution of the module is traced into a single graph, eliminating intermediate reads and writes on the input TensorDict.
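To make the saving concrete, here is a pure-Python sketch that counts reads and writes on the input mapping for a step-by-step pipeline and for a fused one. It assumes a dict subclass, CountingDict, in place of TensorDict, and simple float arithmetic in place of the Net and Masker modules from the example below; all names are hypothetical.

```python
class CountingDict(dict):
    """dict that counts item reads and writes (stand-in for a TensorDict)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # dict.__init__ does not call __setitem__
        self.reads = 0
        self.writes = 0

    def __getitem__(self, key):
        self.reads += 1
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        self.writes += 1
        super().__setitem__(key, value)


def net(x):  # placeholder for the Net module
    return max(x, 0.0) * 3.0


def masker(x, mask):  # placeholder for the Masker module
    return x * mask


def run_eager(td):
    # Each stage reads its inputs from td and writes its output back into it.
    td[("intermediate", "x")] = net(td[("input", "x")])
    td[("output", "probabilities")] = masker(td[("intermediate", "x")], td[("input", "mask")])
    return td


def run_fused(td):
    # One function: the intermediate value stays a local variable.
    hidden = net(td[("input", "x")])
    return hidden, masker(hidden, td[("input", "mask")])


eager_td = CountingDict({("input", "x"): 2.0, ("input", "mask"): 0.5})
run_eager(eager_td)

fused_td = CountingDict({("input", "x"): 2.0, ("input", "mask"): 0.5})
hidden, probs = run_fused(fused_td)
# Final values are written once, to a separate output mapping.
out = {("intermediate", "x"): hidden, ("output", "probabilities"): probs}

print(eager_td.reads, eager_td.writes)  # 3 2
print(fused_td.reads, fused_td.writes)  # 2 0
```

In the step-by-step version, ("intermediate", "x") is written to the input mapping and then read back; in the fused version it never leaves the function, much as the traced code passes the output of fc2 directly to the multiplication.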

We demonstrate by tracing the sequential example from the overview.

Tracing TensorDictSequential
>>> import torch
>>> import torch.nn as nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from tensordict.prototype.fx import symbolic_trace

>>> class Net(nn.Module):
...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
...         super().__init__()
...         self.fc1 = nn.Linear(input_size, hidden_size)
...         self.fc2 = nn.Linear(hidden_size, output_size)
...
...     def forward(self, x):
...         x = torch.relu(self.fc1(x))
...         return self.fc2(x)
...
>>> class Masker(nn.Module):
...     def forward(self, x, mask):
...         return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
...     Masker(),
...     in_keys=[("intermediate", "x"), ("input", "mask")],
...     out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> graph_module = symbolic_trace(module)
>>> print(graph_module.code)

def forward(self, tensordict):
    getitem = tensordict[('input', 'x')]
    _0_fc1 = getattr(self, "0").module.fc1(getitem);  getitem = None
    relu = torch.relu(_0_fc1);  _0_fc1 = None
    _0_fc2 = getattr(self, "0").module.fc2(relu);  relu = None
    getitem_1 = tensordict[('input', 'mask')];  tensordict = None
    mul = _0_fc2 * getitem_1;  getitem_1 = None
    softmax = torch.softmax(mul, dim = 1);  mul = None
    return (_0_fc2, softmax)

In this case the generated graph and code are a bit more complicated. We can visualise the graph as follows (this requires pydot):

Visualising the graph
>>> from torch.fx.passes.graph_drawer import FxGraphDrawer
>>> g = FxGraphDrawer(graph_module, "sequential")
>>> with open("graph.svg", "wb") as f:
...     _ = f.write(g.get_dot_graph().create_svg())

This produces the following visualisation:

Visualization of the traced graph.
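If pydot is not installed, a graph can still be exported as Graphviz DOT text by hand. The sketch below uses a hypothetical helper, to_dot, that works on a plain list of (name, op, target, inputs) tuples mirroring the graph printed for the first TensorDictModule example. With torch.fx, a similar list could be built from graph_module.graph.nodes using each node's name, op, target (converted to a string) and all_input_nodes attributes; that conversion is an assumption and is not shown here.

```python
# Hypothetical node list mirroring the first printed graph:
# (name, op, target, names of input nodes)
NODES = [
    ("tensordict", "placeholder", "tensordict", []),
    ("getitem", "call_function", "operator.getitem", ["tensordict"]),
    ("linear", "call_module", "linear", ["getitem"]),
    ("sigmoid", "call_function", "torch.sigmoid", ["linear"]),
    ("output", "output", "output", ["linear", "sigmoid"]),
]


def to_dot(nodes, name="graph"):
    """Return Graphviz DOT text with one box per node and one edge per input."""
    lines = [f'digraph "{name}" {{']
    for node_name, op, target, _ in nodes:
        # "\\n" emits a literal \n, which Graphviz renders as a line break.
        lines.append(f'  "{node_name}" [shape=box, label="{node_name}\\n{op}: {target}"];')
    for node_name, _, _, inputs in nodes:
        for src in inputs:
            lines.append(f'  "{src}" -> "{node_name}";')
    lines.append("}")
    return "\n".join(lines)


dot = to_dot(NODES, "traced")
print(dot)
```

Saved to a file such as graph.dot, the text can be rendered with the Graphviz command line, for example dot -Tsvg graph.dot -o graph.svg.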
