# torch.jit.trace_module

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch._C.CompilationUnit object>)

Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation. When a module is passed to torch.jit.trace, only the forward method is run and traced. With trace_module, you can pass a dictionary that maps method names to example inputs, so that several methods are traced (see the inputs argument below).

See torch.jit.trace for more information on tracing.

Parameters
• mod (torch.nn.Module) – A torch.nn.Module containing methods whose names are specified in inputs. The given methods will be compiled as a part of a single ScriptModule.

• inputs (dict) – A dict containing sample inputs indexed by method names in mod. While tracing, each value is passed to the method whose name matches its key. For example: {'forward': example_forward_input, 'method2': example_method2_input}
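
To make the shape of the inputs dict concrete, here is a minimal sketch. The Scaler module and its scaled_sum method are hypothetical names invented for illustration. The sketch assumes that a method taking several positional arguments receives its example inputs as a tuple, in the same way torch.jit.trace accepts a tuple of example inputs.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def scaled_sum(self, a, b):
        # A second method that takes two positional arguments.
        return self.linear(a) + 2.0 * b


mod = Scaler()

# Keys are method names on mod; values are the example inputs for that method.
# A single tensor is enough for a one-argument method; a tuple supplies
# several positional arguments (assumption stated in the lead-in).
inputs = {
    "forward": torch.rand(2, 4),
    "scaled_sum": (torch.rand(2, 4), torch.rand(2, 4)),
}

traced = torch.jit.trace_module(mod, inputs)

# Both traced methods are callable on the returned ScriptModule.
a, b = torch.rand(2, 4), torch.rand(2, 4)
out_forward = traced(a)
out_scaled = traced.scaled_sum(a, b)
```

Only the methods named in the dict are traced; a method that is left out is not expected to be available on the returned module.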

Keyword Arguments
• check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.

• check_inputs (list of dicts, optional) – A list of dicts of input arguments that should be used to check the trace against what is expected. Each dict is equivalent to a set of input arguments that would be specified in inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original inputs are used for checking.

• check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
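
The checker arguments above can be combined as in the following sketch, which reuses the Net module from the example on this page. The specific shapes and the tolerance value are arbitrary choices for illustration. The sketch also assumes that every dict in check_inputs should name each method being traced, mirroring the structure of inputs.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()

# Example inputs used to record the trace.
inputs = {
    "forward": torch.rand(1, 1, 3, 3),
    "weighted_kernel_sum": torch.rand(1, 1, 3, 3),
}

# Extra inputs used only by the checker. Each dict has the same keys as
# `inputs`; the forward inputs vary in batch size and spatial size.
check_inputs = [
    {"forward": torch.rand(2, 1, 5, 5), "weighted_kernel_sum": torch.rand(1, 1, 3, 3)},
    {"forward": torch.rand(4, 1, 8, 8), "weighted_kernel_sum": torch.rand(1, 1, 3, 3)},
]

module = torch.jit.trace_module(
    n,
    inputs,
    check_inputs=check_inputs,
    check_tolerance=1e-4,  # arbitrary, slightly looser than the default 1e-05
)

# The traced forward method also accepts shapes not seen while tracing.
out = module(torch.rand(3, 1, 6, 6))
```

If the checker reports a mismatch that you have already explained (for example, a known numerical difference), passing check_trace=False skips the check entirely.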

Returns

A ScriptModule object containing the traced code for each method named in inputs. The returned ScriptModule has the same set of sub-modules and parameters as mod.

Example (tracing a module with multiple methods):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct ScriptModule with
# a single forward method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces forward) and construct a
# ScriptModule with a single forward method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in inputs), constructs
# a ScriptModule with forward and weighted_kernel_sum methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)
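
As a follow-up to the example above, the sketch below shows one way to use the module returned by trace_module and to confirm that it behaves like the original. It repeats the Net definition so that it runs on its own. The comparison against eager execution is an illustrative sanity check, not part of the API.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
inputs = {
    "forward": torch.rand(1, 1, 3, 3),
    "weighted_kernel_sum": torch.rand(1, 1, 3, 3),
}
module = torch.jit.trace_module(n, inputs)

# Call both traced methods with fresh inputs.
x = torch.rand(1, 1, 3, 3)
w = torch.rand(1, 1, 3, 3)
traced_forward = module(x)
traced_sum = module.weighted_kernel_sum(w)

# Compare against the original (eager) module.
forward_matches = torch.allclose(traced_forward, n(x))
sum_matches = torch.allclose(traced_sum, n.weighted_kernel_sum(w))

# The traced module keeps the sub-modules and parameters of the original.
traced_keys = set(module.state_dict().keys())
eager_keys = set(n.state_dict().keys())
print(forward_matches, sum_matches, traced_keys == eager_keys)
```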