Source code for ts.torch_handler.unit_tests.models.base_model

# pylint: disable=W0622, W0223
# "input" is a built-in, but it's the name used by torch
"""
Simple feed-forward model used only to test BaseHandler
"""

import torch


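class ArgmaxModel(torch.nn.Module):
    def forward(self, *input):
        # Only the first positional input is used: return the index of the
        # largest value along dimension 1 for each row of the batch.
        return torch.argmax(input[0], 1)


def save_pt_file(filepath="base_model.pt"):
    # ArgmaxModel has no parameters, so the saved state_dict is empty.
    model = ArgmaxModel()
    torch.save(model.state_dict(), filepath)


if __name__ == "__main__":
    save_pt_file()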
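As a rough illustration of what the saved file contains and how the model behaves, the following sketch saves the state_dict, loads it into a fresh ArgmaxModel, and runs one batch through it. It assumes only that PyTorch is installed. The temporary directory, the round_trip helper, and the example tensor are hypothetical and not part of the module above; this is not how BaseHandler itself loads the model.

```python
import os
import tempfile

import torch


class ArgmaxModel(torch.nn.Module):
    # Same parameter-free model as in the module: returns the index of the
    # largest value along dimension 1 of the first input.
    def forward(self, *input):
        return torch.argmax(input[0], 1)


def round_trip(filepath):
    # Hypothetical helper: save the (empty) state_dict, then load it back
    # into a fresh instance and switch it to evaluation mode.
    torch.save(ArgmaxModel().state_dict(), filepath)
    model = ArgmaxModel()
    model.load_state_dict(torch.load(filepath))
    model.eval()
    return model


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "base_model.pt")
    model = round_trip(path)
    # Two rows of example scores; each row's argmax is its predicted class.
    batch = torch.tensor([[0.1, 0.7, 0.2],
                          [0.9, 0.05, 0.05]])
    with torch.no_grad():
        preds = model(batch)

print(preds.tolist())  # [1, 0]
```

Because the model has no parameters, load_state_dict succeeds with an empty dictionary, and the output depends only on the input tensor. That makes the file a cheap, deterministic fixture for handler tests.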