
ORTModule

The ORTModule class uses ONNX Runtime to accelerate PyTorch model training.

ORTModule wraps a torch.nn.Module and offloads the forward and backward passes of a PyTorch training loop to ONNX Runtime. ONNX Runtime executes these parts of the training loop using its optimized computation graph and memory management, which can make them run faster and use less memory.

The following code example illustrates the use of ORTModule in the simple case where the entire model is trained using ONNX Runtime:

import torch
from torch_ort import ORTModule

# Original PyTorch model
class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        ...

    def forward(self, x):
        ...

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model)  # Only change to the original PyTorch script (besides the import)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Training loop is unchanged.
# data_loader is an existing iterable of (data, target) batches, e.g. a torch.utils.data.DataLoader.
for data, target in data_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

ONNX Runtime can also be used to train only part of a model, by wrapping internal torch.nn.Module submodules with ORTModule.
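As a rough illustration of partial wrapping, the sketch below uses stdlib-only stand-ins rather than PyTorch: the hypothetical classes Scale, Shift, and Model play the role of torch.nn.Module submodules and their parent, and the hypothetical RecordingWrapper plays the role of ORTModule by delegating each call to the submodule it wraps. Only the wrapped encoder is routed through the wrapper; the head runs as before. With torch_ort installed, the corresponding step would be an assignment along the lines of model.encoder = ORTModule(model.encoder), where encoder is a hypothetical submodule name.

```python
# Toy stand-ins (hypothetical, stdlib only): plain callables play the role of
# torch.nn.Module submodules, and RecordingWrapper plays the role of ORTModule.


class Scale:
    """Stand-in submodule: multiplies every element by a factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, xs):
        return [self.factor * x for x in xs]


class Shift:
    """Stand-in submodule: adds an offset to every element."""

    def __init__(self, offset):
        self.offset = offset

    def __call__(self, xs):
        return [x + self.offset for x in xs]


class RecordingWrapper:
    """Hypothetical wrapper: delegates each call to the wrapped submodule and
    counts the calls, standing in for 'execution routed through the wrapper'."""

    def __init__(self, module):
        self.module = module
        self.calls = 0

    def __call__(self, xs):
        self.calls += 1
        return self.module(xs)


class Model:
    """Parent model that composes an encoder and a head."""

    def __init__(self):
        self.encoder = Scale(2)
        self.head = Shift(1)

    def __call__(self, xs):
        return self.head(self.encoder(xs))


model = Model()
# Wrap only the encoder; the head keeps running unwrapped.
model.encoder = RecordingWrapper(model.encoder)
out = model([1, 2, 3])
```

Here out is [3, 5, 7], and model.encoder.calls is 1, while model.head is still the plain Shift instance.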

ORTModule API

class torch_ort.ORTModule(module, debug_options=None)

Extends the user’s torch.nn.Module model to use the high-performance ONNX Runtime training engine.

ORTModule specializes the user’s torch.nn.Module model, providing forward() and backward() along with all other torch.nn.Module APIs.

Parameters:
  • module (torch.nn.Module) – The user’s PyTorch module that ORTModule specializes.

  • debug_options (DebugOptions, optional) – Debugging options for ORTModule.

forward(*inputs, **kwargs)

Delegate the forward() pass of PyTorch training to ONNX Runtime.

The first call to forward() performs setup and checking steps, during which ORTModule determines whether the module can be trained with ONNX Runtime. For this reason, the first forward() call takes longer than subsequent calls. If ONNX Runtime cannot process the model for training, execution is interrupted.
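Why only the first call is slow can be sketched with a stdlib-only toy: a hypothetical LazySetupWrapper runs a one-time setup and check on its first call, caches the result, and reuses it afterwards. The setup here just sleeps and checks input types; it does not reproduce what ORTModule actually does during its setup and checking steps, and all names in the block are hypothetical.

```python
import time


class LazySetupWrapper:
    """Hypothetical wrapper: runs a one-time setup/check on the first call,
    caches the result, and reuses it on every later call."""

    def __init__(self, fn, setup_cost_s=0.05):
        self.fn = fn
        self.setup_cost_s = setup_cost_s
        self._prepared = None  # filled in on the first call

    def _setup(self, args):
        # Stand-in for setup and checking steps: simulate the cost, then stop
        # with an error if the first call's inputs are unsupported.
        time.sleep(self.setup_cost_s)
        if not all(isinstance(a, (int, float)) for a in args):
            raise TypeError("unsupported input type; cannot prepare fn")
        return self.fn  # the "prepared" callable

    def __call__(self, *args):
        if self._prepared is None:
            self._prepared = self._setup(args)
        return self._prepared(*args)


def timed(f, *args):
    """Call f(*args) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = f(*args)
    return result, time.perf_counter() - start


square = LazySetupWrapper(lambda x: x * x)
r1, t1 = timed(square, 3)  # first call: pays the setup cost
r2, t2 = timed(square, 4)  # later call: reuses the cached result
```

The design choice mirrored here is that the expensive work happens once per wrapper instance, so its cost is amortized over the many forward calls of a training loop.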

Parameters:
  • *inputs (variable positional arguments) – Positional arguments defined by the forward method of the user’s PyTorch module.

  • **kwargs (variable keyword arguments) – Keyword arguments defined by the forward method of the user’s PyTorch module.

Values can be torch tensors and primitive types.

Returns:

The output as expected from the forward method defined by the user’s PyTorch module. Output values supported include tensors, nested sequences of tensors and nested dictionaries of tensor values.
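The supported output shapes listed above can be expressed as a small recursive check. This is a sketch of a strict reading of that list (tensors, nested sequences, nested dicts), not ORTModule's internal validation; is_supported_output and FakeTensor are hypothetical names, and with PyTorch one could pass torch.is_tensor as the is_tensor predicate.

```python
from collections.abc import Mapping, Sequence


def is_supported_output(value, is_tensor):
    """Return True if value is a tensor, a (possibly nested) sequence of
    supported values, or a (possibly nested) dict of supported values.

    is_tensor is a caller-supplied predicate, e.g. torch.is_tensor."""
    if is_tensor(value):
        return True
    if isinstance(value, Mapping):
        return all(is_supported_output(v, is_tensor) for v in value.values())
    # str and bytes are Sequences too, but they are not sequences of tensors.
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return all(is_supported_output(v, is_tensor) for v in value)
    return False


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


def is_fake_tensor(v):
    return isinstance(v, FakeTensor)


t = FakeTensor()
nested_ok = {"logits": t, "extra": {"hidden": [t, (t, t)]}}
nested_bad = {"logits": t, "note": "text"}
```

Here is_supported_output(nested_ok, is_fake_tensor) is True, while nested_bad fails because the string "note" value is neither a tensor nor a container of tensors.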

add_module(name: str, module: Optional[Module]) -> None

Raises an ORTModuleTorchModelException, since ORTModule does not support adding modules to it.

property module

The original torch.nn.Module that this module wraps.

This property provides access to methods and properties on the original module.
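The two behaviors described above, rejecting add_module() and exposing the original object through the module property, can be mimicked with a stdlib-only toy. Wrapper, Original, and WrapperAddModuleError are hypothetical stand-ins, not torch_ort classes; with the real library, the equivalent access to a custom method would be written as ort_model.module.some_method(), where some_method is a hypothetical method defined on the user's module.

```python
class WrapperAddModuleError(Exception):
    """Hypothetical stand-in for ORTModuleTorchModelException."""


class Original:
    """Stand-in for the user's torch.nn.Module with a custom method."""

    def describe(self):
        return "original module"


class Wrapper:
    """Hypothetical wrapper mirroring two documented behaviors:
    - add_module() raises instead of modifying the wrapped model;
    - the `module` property returns the original wrapped object."""

    def __init__(self, module):
        self._original = module

    @property
    def module(self):
        return self._original

    def add_module(self, name, module):
        raise WrapperAddModuleError(
            f"cannot add module {name!r}: adding modules is not supported"
        )


original = Original()
wrapped = Wrapper(original)
# Custom methods of the original are reached through the `module` property.
description = wrapped.module.describe()
```

Here description is "original module", and wrapped.module is the very same object that was passed in, not a copy.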

apply(fn: Callable[[Module], None]) -> T

Override apply() to delegate execution to ONNX Runtime

train(mode: bool = True) -> T

Override train() to delegate execution to ONNX Runtime

state_dict(destination=None, prefix='', keep_vars=False)

Override state_dict() to delegate execution to ONNX Runtime

load_state_dict(state_dict: OrderedDict[str, Tensor], strict: bool = True)

Override load_state_dict() to delegate execution to ONNX Runtime

register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) -> None

Override register_buffer()

register_parameter(name: str, param: Optional[Parameter]) -> None

Override register_parameter()

get_parameter(target: str) -> Parameter

Override get_parameter()

get_buffer(target: str) -> Tensor

Override get_buffer()

parameters(recurse: bool = True) -> Iterator[Parameter]

Override parameters()

named_parameters(prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]

Override named_parameters()

buffers(recurse: bool = True) -> Iterator[Tensor]

Override buffers()

named_buffers(prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]

Override named_buffers()

named_children() -> Iterator[Tuple[str, Module]]

Override named_children()

modules() -> Iterator[Module]

Override modules()

named_modules(*args, **kwargs)

Override named_modules()
