Shortcuts

Warm-up: numpy

A third order polynomial, trained to predict \(y=\sin(x)\) from \(-\pi\) to \(pi\) by minimizing squared Euclidean distance.

This implementation uses numpy to manually compute the forward pass, loss, and backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about deep learning or gradients or computational graphs, and is just a way to perform generic numeric computations.

99 1546.0365145892183
199 1095.8992907857576
299 777.6096767345516
399 552.5352611528602
499 393.3692169177716
599 280.80664270842703
699 201.19883404382793
799 144.89548639133625
899 105.07298180275444
999 76.90616893555887
1099 56.982896477615675
1199 42.89011969702015
1299 32.9212810307606
1399 25.869417918780947
1499 20.880874497663395
1599 17.35185928109817
1699 14.855296144964393
1799 13.08909475122299
1899 11.839566794279794
1999 10.955552494109163
Result: y = 0.04887728059779718 + 0.854335392917212 x + -0.008432144224579809 x^2 + -0.09298823470267067 x^3

import numpy as np
import math

# Create random input and output data
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

# Randomly initialize weights
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    # y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')

Total running time of the script: ( 0 minutes 0.345 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources