  import torch

  @torch.jit.script
  def RNN(h, x, W_h, U_h, W_y, b_h, b_y):
    y = []
    for t in range(x.size(0)):
      h = torch.tanh(x[t] @ W_h + h @ U_h + b_h)
      y += [torch.tanh(h @ W_y + b_y)]
      if t % 10 == 0:
        print("stats: ", h.mean(), h.var())
    return torch.stack(y), h

Hybrid Front-End

A new hybrid front-end provides ease of use and flexibility in eager mode, while seamlessly transitioning to graph mode for speed, optimization, and functionality in C++ runtime environments.
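As a minimal sketch of the eager-to-graph transition, the snippet below compiles an ordinary Python function with `torch.jit.script` and checks that both versions agree. The function `scale_and_clip` and its inputs are hypothetical names chosen for illustration, not part of PyTorch.

```python
import torch

def scale_and_clip(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Hypothetical eager-mode function with data-dependent control flow.
    y = x * factor
    if bool(y.sum() > 0):
        y = y.clamp(min=0.0)
    return y

# Compile the same function to TorchScript (graph mode).
scripted = torch.jit.script(scale_and_clip)

x = torch.tensor([-1.0, 2.0, 3.0])
eager_out = scale_and_clip(x, 2.0)   # runs op by op in Python
graph_out = scripted(x, 2.0)         # runs the compiled graph

print(torch.equal(eager_out, graph_out))
print(scripted.code)  # TorchScript source recovered from the compiled graph
```

The function body is unchanged between the two modes; only the call to `torch.jit.script` differs, which is what makes it practical to prototype in eager mode first.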

Distributed Training

Optimize performance in both research and production by taking advantage of native support for asynchronous execution of collective operations and peer-to-peer communication that is accessible from Python and C++.
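To illustrate asynchronous collectives, here is a hedged sketch that launches an `all_reduce` with `async_op=True` and waits on the returned handle. It assumes a single-process group (`world_size=1`) with the `gloo` backend and a temporary file for rendezvous so that it can run without a launcher; a real job would start one process per rank.

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Assumption: a one-process group, initialized through a temporary rendezvous file.
init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file}",
    rank=0,
    world_size=1,
)

t = torch.ones(4)
# async_op=True returns a work handle immediately instead of blocking.
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
# ... independent computation could overlap with communication here ...
work.wait()  # block until the reduction has completed

print(t)
dist.destroy_process_group()
```

With a single rank the sum leaves the tensor unchanged; with N ranks each element would become N.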

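  import torch.distributed as dist
  from torch.nn.parallel import DistributedDataParallel

  # Assumes the launcher has set MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE,
  # and that `model` is an existing torch.nn.Module.
  dist.init_process_group(backend='gloo')
  model = DistributedDataParallel(model)

  import torch
  import numpy as np

  a = np.ones(5)
  b = torch.from_numpy(a)  # b shares memory with a
  np.add(a, 1, out=a)      # the in-place update of a is visible through b
  print(a)
  print(b)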

Python-First

PyTorch is not a Python binding into a monolithic C++ framework. It’s built to be deeply integrated into Python so it can be used with popular libraries and packages such as Cython and Numba.

Tools & Libraries

An active community of researchers and developers has built a rich ecosystem of tools and libraries for extending PyTorch and supporting development in areas from computer vision to reinforcement learning.

  import torchvision.models as models

  # Newer torchvision releases prefer the `weights=` argument over `pretrained=True`.
  resnet18 = models.resnet18(pretrained=True)
  alexnet = models.alexnet(pretrained=True)
  squeezenet = models.squeezenet1_0(pretrained=True)
  vgg16 = models.vgg16(pretrained=True)
  densenet = models.densenet161(pretrained=True)
  inception = models.inception_v3(pretrained=True)

  import torch.onnx
  import torchvision

  dummy_input = torch.randn(1, 3, 224, 224)
  model = torchvision.models.alexnet(pretrained=True)
  torch.onnx.export(model, dummy_input, "alexnet.onnx")

Native ONNX Support

Export models in the standard ONNX (Open Neural Network Exchange) format for direct access to ONNX-compatible platforms, runtimes, visualizers, and more.

C++ Front-End

The C++ frontend is a pure C++ interface to PyTorch that follows the design and architecture of the established Python frontend. It is intended to enable research in high-performance, low-latency, and bare-metal C++ applications.

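  #include <torch/torch.h>

  // `num_features`, `dataset`, and `loss_function` are assumed to be defined elsewhere;
  // `dataset` is assumed to yield stacked batches, e.g. via
  // .map(torch::data::transforms::Stack<>()).
  torch::nn::Linear model(num_features, 1);
  torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.1);
  auto data_loader = torch::data::make_data_loader(std::move(dataset));

  for (size_t epoch = 0; epoch < 10; ++epoch) {
    for (auto& batch : *data_loader) {
      optimizer.zero_grad();  // clear gradients left over from the previous step
      auto prediction = model->forward(batch.data);
      auto loss = loss_function(prediction, batch.target);
      loss.backward();
      optimizer.step();
    }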
  }

  export IMAGE_FAMILY="pytorch-latest-cpu"
  export ZONE="us-west1-b"
  export INSTANCE_NAME="my-instance"

  gcloud compute instances create "$INSTANCE_NAME" \
    --zone="$ZONE" \
    --image-family="$IMAGE_FAMILY" \
    --image-project=deeplearning-platform-release

Cloud Partners

PyTorch is well supported on major cloud platforms, providing frictionless development and easy scaling through prebuilt images, large-scale training on GPUs, the ability to run models in a production-scale environment, and more.