PyTorch: Custom nn Modules¶
A fully-connected ReLU network with one hidden layer, trained to predict y from x by minimizing squared Euclidean distance.
This implementation defines the model as a custom Module subclass. Whenever you want a model more complex than a simple sequence of existing Modules you will need to define your model this way.
import torch from torch.autograd import Variable class TwoLayerNet(torch.nn.Module): def __init__(self, D_in, H, D_out): """ In the constructor we instantiate two nn.Linear modules and assign them as member variables. """ super(TwoLayerNet, self).__init__() self.linear1 = torch.nn.Linear(D_in, H) self.linear2 = torch.nn.Linear(H, D_out) def forward(self, x): """ In the forward function we accept a Variable of input data and we must return a Variable of output data. We can use Modules defined in the constructor as well as arbitrary operators on Variables. """ h_relu = self.linear1(x).clamp(min=0) y_pred = self.linear2(h_relu) return y_pred # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 # Create random Tensors to hold inputs and outputs, and wrap them in Variables x = Variable(torch.randn(N, D_in)) y = Variable(torch.randn(N, D_out), requires_grad=False) # Construct our model by instantiating the class defined above model = TwoLayerNet(D_in, H, D_out) # Construct our loss function and an Optimizer. The call to model.parameters() # in the SGD constructor will contain the learnable parameters of the two # nn.Linear modules which are members of the model. criterion = torch.nn.MSELoss(size_average=False) optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) for t in range(500): # Forward pass: Compute predicted y by passing x to the model y_pred = model(x) # Compute and print loss loss = criterion(y_pred, y) print(t, loss.data) # Zero gradients, perform a backward pass, and update the weights. optimizer.zero_grad() loss.backward() optimizer.step()
Total running time of the script: ( 0 minutes 0.000 seconds)