Tracing TensorDictModule
========================

We support tracing execution of :obj:`TensorDictModule` to create FX graphs. Simply
import :obj:`symbolic_trace` from ``tensordict.prototype.fx`` instead of ``torch.fx``.

.. note::

    Support for ``torch.fx`` is highly experimental and subject to change. Use with
    caution, and raise an issue if you try it out and encounter problems.

Tracing a :obj:`TensorDictModule`
---------------------------------

We'll illustrate with an example from the overview. We create a
:obj:`TensorDictModule`, trace it, and inspect the graph and generated code.

.. code-block::
    :caption: Tracing a TensorDictModule

    >>> import torch
    >>> import torch.nn as nn
    >>> from tensordict import TensorDict
    >>> from tensordict.nn import TensorDictModule
    >>> from tensordict.prototype.fx import symbolic_trace

    >>> class Net(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.linear = nn.LazyLinear(1)
    ...
    ...     def forward(self, x):
    ...         logits = self.linear(x)
    ...         return logits, torch.sigmoid(logits)

    >>> module = TensorDictModule(
    ...     Net(),
    ...     in_keys=["input"],
    ...     out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
    ... )

    >>> graph_module = symbolic_trace(module)
    >>> print(graph_module.graph)
    graph():
        %tensordict : [#users=1] = placeholder[target=tensordict]
        %getitem : [#users=1] = call_function[target=operator.getitem](args = (%tensordict, input), kwargs = {})
        %linear : [#users=2] = call_module[target=linear](args = (%getitem,), kwargs = {})
        %sigmoid : [#users=1] = call_function[target=torch.sigmoid](args = (%linear,), kwargs = {})
        return (linear, sigmoid)
    >>> print(graph_module.code)
    def forward(self, tensordict):
        getitem = tensordict['input']; tensordict = None
        linear = self.linear(getitem); getitem = None
        sigmoid = torch.sigmoid(linear)
        return (linear, sigmoid)

We can check that a forward pass with each module results in the same outputs.

>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
...     module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
>>> assert (
...     module_out["outputs", "probabilities"]
...     == graph_module_out["outputs", "probabilities"]
... ).all()
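For a quick overview of the traced operations without reading the full graph
printout, ``torch.fx`` can also render the nodes as a table. Below is a minimal
sketch reusing ``graph_module`` from above; note that
:obj:`torch.fx.Graph.print_tabular` requires the ``tabulate`` package to be
installed.

.. code-block::
    :caption: Tabular view of the traced graph (sketch)

    >>> # prints one row per node: opcode, name, target, args, kwargs
    >>> # (requires the ``tabulate`` package)
    >>> graph_module.graph.print_tabular()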
Tracing a :obj:`TensorDictSequential`
-------------------------------------

We can also trace :obj:`TensorDictSequential`. In this case the entire execution of
the module is traced into a single graph, eliminating intermediate reads and writes
on the input :obj:`TensorDict`. We demonstrate by tracing the sequential example
from the overview.

.. code-block::
    :caption: Tracing TensorDictSequential

    >>> import torch
    >>> import torch.nn as nn
    >>> from tensordict import TensorDict
    >>> from tensordict.nn import TensorDictModule, TensorDictSequential
    >>> from tensordict.prototype.fx import symbolic_trace

    >>> class Net(nn.Module):
    ...     def __init__(self, input_size=100, hidden_size=50, output_size=10):
    ...         super().__init__()
    ...         self.fc1 = nn.Linear(input_size, hidden_size)
    ...         self.fc2 = nn.Linear(hidden_size, output_size)
    ...
    ...     def forward(self, x):
    ...         x = torch.relu(self.fc1(x))
    ...         return self.fc2(x)
    ...
    ... class Masker(nn.Module):
    ...     def forward(self, x, mask):
    ...         return torch.softmax(x * mask, dim=1)

    >>> net = TensorDictModule(
    ...     Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
    ... )
    >>> masker = TensorDictModule(
    ...     Masker(),
    ...     in_keys=[("intermediate", "x"), ("input", "mask")],
    ...     out_keys=[("output", "probabilities")],
    ... )
    >>> module = TensorDictSequential(net, masker)

    >>> graph_module = symbolic_trace(module)
    >>> print(graph_module.code)
    def forward(self, tensordict):
        getitem = tensordict[('input', 'x')]
        _0_fc1 = getattr(self, "0").module.fc1(getitem); getitem = None
        relu = torch.relu(_0_fc1); _0_fc1 = None
        _0_fc2 = getattr(self, "0").module.fc2(relu); relu = None
        getitem_1 = tensordict[('input', 'mask')]; tensordict = None
        mul = _0_fc2 * getitem_1; getitem_1 = None
        softmax = torch.softmax(mul, dim = 1); mul = None
        return (_0_fc2, softmax)

In this case the generated graph and code are a bit more complicated. We can
visualize the graph as follows (requires ``pydot``).

.. code-block::
    :caption: Visualizing the graph

    >>> from torch.fx.passes.graph_drawer import FxGraphDrawer
    >>> g = FxGraphDrawer(graph_module, "sequential")
    >>> with open("graph.svg", "wb") as f:
    ...     f.write(g.get_dot_graph().create_svg())

This results in the following visualization.

.. image:: _static/img/graph.svg
    :alt: Visualization of the traced graph.
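As a final sanity check, we can verify that the traced sequential module matches
the original, following the same pattern as in the first example. This is a minimal
sketch that reuses ``module`` and ``graph_module`` from the sequential example; the
input shapes are assumptions based on ``Net``'s defaults (``input_size=100``,
``output_size=10``).

.. code-block::
    :caption: Checking the sequential outputs match (sketch)

    >>> # hypothetical input: the mask shape (32, 10) matches Net's default
    >>> # output_size, since Masker multiplies the mask with Net's output
    >>> tensordict = TensorDict(
    ...     {
    ...         "input": TensorDict(
    ...             {"x": torch.randn(32, 100), "mask": torch.randint(2, (32, 10))},
    ...             [32],
    ...         )
    ...     },
    ...     [32],
    ... )
    >>> module_out = module(tensordict, tensordict_out=TensorDict())
    >>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
    >>> assert (
    ...     module_out["output", "probabilities"]
    ...     == graph_module_out["output", "probabilities"]
    ... ).all()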