Shortcuts

Warm-up: numpy

A third order polynomial, trained to predict \(y=\sin(x)\) from \(-\pi\) to \(pi\) by minimizing squared Euclidean distance.

This implementation uses numpy to manually compute the forward pass, loss, and backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about deep learning or gradients or computational graphs, and is just a way to perform generic numeric computations.

99 1692.8254041360824
199 1124.617482499394
299 748.2322850621296
399 498.88208491732337
499 333.6698695017706
599 224.19019997645455
699 151.63182392147033
799 103.53587667418915
899 71.6498351382341
999 50.506741620361254
1099 36.484510426654026
1199 27.18304154390989
1299 21.01173359353676
1399 16.916299741141053
1499 14.197822536522715
1599 12.392890094244738
1699 11.194184828093768
1799 10.397865676016934
1899 9.868698825708297
1999 9.516946788756416
Result: y = -0.008723872265741645 + 0.8323058050743684 x + 0.00150501313170153 x^2 + -0.08985471766736489 x^3

import numpy as np
import math

# Create random input and output data
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)

# Randomly initialize weights
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    # y = a + b x + c x^2 + d x^3
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')

Total running time of the script: ( 0 minutes 0.746 seconds)

Gallery generated by Sphinx-Gallery

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources