ORTModule
The ORTModule class uses ONNX Runtime to accelerate PyTorch model training.
ORTModule wraps a torch.nn.Module and offloads the forward and backward passes of a PyTorch training loop to ONNX Runtime. ONNX Runtime executes these passes with an optimized computation graph and optimized memory usage, so they run faster and consume less memory.
The following code example illustrates the use of ORTModule in the simple case where the entire model is trained using ONNX Runtime:
import torch
from torch_ort import ORTModule

# Original PyTorch model
class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        ...

    def forward(self, x):
        ...

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model)  # Only change to original PyTorch script
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Training loop is unchanged; data_loader is assumed to be defined elsewhere
for data, target in data_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
ONNX Runtime can also be used to train parts of the model, by wrapping internal torch.nn.Modules with ORTModule.
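Partial wrapping might look like the following minimal sketch. The Encoder and Classifier classes and their layer sizes are hypothetical and chosen only for illustration. The import fallback is also an assumption of this sketch: it lets the snippet run in plain PyTorch, without ONNX Runtime, when the torch-ort package is not installed.

```python
import torch

try:
    from torch_ort import ORTModule  # provided by the torch-ort package
except ImportError:
    # Assumption for this sketch only: without torch-ort, return the module
    # unchanged so the example still runs (with no ONNX Runtime acceleration).
    def ORTModule(module):
        return module


class Encoder(torch.nn.Module):  # hypothetical submodule
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc = torch.nn.Linear(input_size, hidden_size)

    def forward(self, x):
        return torch.relu(self.fc(x))


class Classifier(torch.nn.Module):  # hypothetical top-level model
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Only the encoder is wrapped, so only this part is handed to ONNX Runtime.
        self.encoder = ORTModule(Encoder(input_size, hidden_size))
        # The classification head stays a plain PyTorch module.
        self.head = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.head(self.encoder(x))


model = Classifier(input_size=784, hidden_size=500, num_classes=10)
```

The optimizer and training loop stay the same as in the full-model example, because the wrapped submodule is still registered as a child of the top-level model.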
ORTModule API
- class torch_ort.ORTModule(module, debug_options=None)
Extends the user’s torch.nn.Module model to leverage the ONNX Runtime training engine.
ORTModule specializes the user’s torch.nn.Module model, providing forward() and backward() along with all other torch.nn.Module APIs.
- Parameters:
module (torch.nn.Module) – User’s PyTorch module that ORTModule specializes.
debug_options (DebugOptions, optional) – Debugging options for ORTModule.
- forward(*inputs, **kwargs)
Delegate the forward() pass of PyTorch training to ONNX Runtime.
The first call to forward performs setup and checking steps. During this call, ORTModule determines whether the module can be trained with ONNX Runtime. For this reason, the first forward call takes longer than subsequent calls. Execution is interrupted if ONNX Runtime cannot process the model for training.
- Parameters:
*inputs, **kwargs – The positional, variable positional, keyword, and variable keyword arguments defined in the user’s PyTorch module’s forward method. Values can be torch tensors and primitive types.
- Returns:
The output as expected from the forward method defined by the user’s PyTorch module. Supported output values include tensors, nested sequences of tensors, and nested dictionaries of tensor values.
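Because the first forward call includes one-time setup, timing measurements are easier to interpret when that call is reported separately from the later ones. The following is a minimal sketch using only the standard library. The time_calls helper and the LazyStep class are hypothetical: LazyStep simulates a callable with a one-time setup cost and stands in for something like lambda: model(data).

```python
import time


def time_calls(fn, n_calls):
    """Return the wall-clock duration in seconds of each of n_calls calls to fn."""
    durations = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


class LazyStep:
    """Stand-in for a callable that performs one-time setup on its first call."""

    def __init__(self):
        self.ready = False

    def __call__(self):
        if not self.ready:
            time.sleep(0.05)  # simulated one-time setup cost
            self.ready = True


durations = time_calls(LazyStep(), n_calls=5)
first, later = durations[0], durations[1:]
print(f"first call: {first:.4f}s, mean of later calls: {sum(later) / len(later):.6f}s")
```

Reporting the first duration on its own keeps the setup cost from skewing the average of the later calls.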
- add_module(name: str, module: Optional[Module]) → None
Raises an ORTModuleTorchModelException, since ORTModule does not support adding modules to it.
- property module
The original torch.nn.Module that this module wraps.
This property provides access to methods and properties on the original module.
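The two behaviors described above, refusing add_module and exposing the original object through a module property, can be pictured with a toy wrapper. This is a conceptual sketch in plain Python and not the torch_ort implementation: Wrapper, Inner, and WrapperError are hypothetical names, with WrapperError standing in for ORTModuleTorchModelException.

```python
class WrapperError(Exception):
    """Hypothetical stand-in for ORTModuleTorchModelException."""


class Inner:
    """Hypothetical stand-in for the user's original module."""

    def __init__(self):
        self.dropout_rate = 0.1

    def describe(self):
        return f"Inner(dropout_rate={self.dropout_rate})"


class Wrapper:
    """Toy wrapper: keeps the original object reachable and rejects new children."""

    def __init__(self, module):
        self._original = module

    @property
    def module(self):
        # Access point for the methods and attributes of the wrapped object.
        return self._original

    def add_module(self, name, module):
        raise WrapperError("adding modules to the wrapper is not supported")


wrapped = Wrapper(Inner())
wrapped.module.dropout_rate = 0.2    # reach an attribute of the original object
summary = wrapped.module.describe()  # call a method of the original object

try:
    wrapped.add_module("extra", Inner())
    rejected = False
except WrapperError:
    rejected = True
```

With the real class, the analogous pattern would be to build the complete torch.nn.Module first and wrap it afterwards, then reach custom methods through model.module.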
- apply(fn: Callable[[Module], None]) → T
Override apply() to delegate execution to ONNX Runtime.
- state_dict(destination=None, prefix='', keep_vars=False)
Override state_dict() to delegate execution to ONNX Runtime.
- load_state_dict(state_dict: OrderedDict[str, Tensor], strict: bool = True)
Override load_state_dict() to delegate execution to ONNX Runtime.
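Since state_dict() and load_state_dict() keep their torch.nn.Module signatures, a checkpoint round trip should look the same as for an unwrapped module. The following is a minimal sketch under stated assumptions: the torch.nn.Linear model and the in-memory buffer are illustrative choices, and the import fallback (an assumption of this sketch) lets the snippet run in plain PyTorch when torch-ort is not installed.

```python
import io

import torch

try:
    from torch_ort import ORTModule  # provided by the torch-ort package
except ImportError:
    # Assumption for this sketch only: without torch-ort, return the module
    # unchanged so the example still runs (with no ONNX Runtime acceleration).
    def ORTModule(module):
        return module


model = ORTModule(torch.nn.Linear(4, 2))

# Save the parameters to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Load the saved parameters into a freshly initialized model.
restored = ORTModule(torch.nn.Linear(4, 2))
restored.load_state_dict(torch.load(buffer))

# Compare the two sets of parameters tensor by tensor.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
```

The comparison iterates over values rather than hard-coding key names, so the sketch does not assume a particular key layout for the wrapped module.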
- register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) → None
Override register_buffer()
- register_parameter(name: str, param: Optional[Parameter]) → None
Override register_parameter()
- get_parameter(target: str) → Parameter
Override get_parameter()
- get_buffer(target: str) → Tensor
Override get_buffer()
- named_parameters(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, Parameter]]
Override named_parameters()
- named_buffers(prefix: str = '', recurse: bool = True) → Iterator[Tuple[str, Tensor]]
Override named_buffers()
- named_modules(*args, **kwargs)
Override named_modules()