
DCGAN Tutorial

Author: Nathan Inkawhich

Introduction

This tutorial will give an introduction to DCGANs through an example. We will train a generative adversarial network (GAN) to generate new celebrities after showing it pictures of many real celebrities. Most of the code here is from the DCGAN implementation in pytorch/examples, and this document will give a thorough explanation of the implementation and shed light on how and why this model works. Don’t worry: no prior knowledge of GANs is required, although a first-timer may need to spend some time reasoning about what is actually happening under the hood. Also, for the sake of time it will help to have a GPU, or two. Let’s start from the beginning.

Generative Adversarial Networks

What is a GAN?

GANs are a framework for teaching a DL model to capture the training data’s distribution so we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They are made of two distinct models, a generator and a discriminator. The job of the generator is to spawn ‘fake’ images that look like the training images. The job of the discriminator is to look at an image and output whether or not it is a real training image or a fake image from the generator. During training, the generator is constantly trying to outsmart the discriminator by generating better and better fakes, while the discriminator is working to become a better detective and correctly classify the real and fake images. The equilibrium of this game is when the generator is generating perfect fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence that the generator output is real or fake.

Now, let’s define some notation to be used throughout the tutorial, starting with the discriminator. Let \(x\) be data representing an image. \(D(x)\) is the discriminator network, which outputs the (scalar) probability that \(x\) came from training data rather than the generator. Here, since we are dealing with images, the input to \(D(x)\) is an image of CHW size 3x64x64. Intuitively, \(D(x)\) should be HIGH when \(x\) comes from training data and LOW when \(x\) comes from the generator. \(D(x)\) can also be thought of as a traditional binary classifier.

For the generator’s notation, let \(z\) be a latent space vector sampled from a standard normal distribution. \(G(z)\) represents the generator function which maps the latent vector \(z\) to data-space. The goal of \(G\) is to estimate the distribution that the training data comes from (\(p_{data}\)) so it can generate fake samples from that estimated distribution (\(p_g\)).

So, \(D(G(z))\) is the probability (scalar) that the output of the generator \(G\) is a real image. As described in Goodfellow’s paper, \(D\) and \(G\) play a minimax game in which \(D\) tries to maximize the probability it correctly classifies reals and fakes (\(\log D(x)\)), and \(G\) tries to minimize the probability that \(D\) will predict its outputs are fake (\(\log(1-D(G(z)))\)). From the paper, the GAN loss function is

\[\min_{G} \max_{D} V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1-D(G(z)))\big]\]

In theory, the solution to this minimax game is where \(p_g = p_{data}\), and the discriminator guesses randomly if the inputs are real or fake. However, the convergence theory of GANs is still being actively researched and in reality models do not always train to this point.
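In practice, the two expectations in \(V(D,G)\) are estimated with batch averages of the discriminator’s outputs. The minimal pure-Python sketch below assumes we already have \(D(x)\) for a batch of real samples and \(D(G(z))\) for a batch of fakes; value_estimate is a hypothetical helper for illustration, not part of PyTorch or of the tutorial code.

```python
import math


def value_estimate(d_real, d_fake):
    """Batch estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples, each in (0, 1).
    d_fake: discriminator outputs D(G(z)) on generated samples, each in (0, 1).
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)        # E_x[log D(x)]
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)  # E_z[log(1 - D(G(z)))]
    return real_term + fake_term


# A confident, correct discriminator achieves a higher V (D maximizes V) ...
confident = value_estimate([0.9, 0.95, 0.85], [0.1, 0.05, 0.15])
# ... than one that outputs 0.5 for every input, as at the theoretical equilibrium.
guessing = value_estimate([0.5] * 3, [0.5] * 3)
print(confident, guessing)
```

The guessing discriminator gives \(2\log 0.5 = -\log 4\), the value of the game at the solution \(p_g = p_{data}\).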

What is a DCGAN?

A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of strided convolution layers, batch norm layers, and LeakyReLU activations. The input is a 3x64x64 image and the output is a scalar probability that the input is from the real data distribution. The generator is composed of convolutional-transpose layers, batch norm layers, and ReLU activations. The input is a latent vector, \(z\), that is drawn from a standard normal distribution and the output is a 3x64x64 RGB image. The strided conv-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give some tips about how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Out:

Random Seed:  999

Inputs

Let’s define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section.
  • workers - the number of worker processes for loading the data with the DataLoader
  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128
  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details
  • nc - number of color channels in the input images. For color images this is 3
  • nz - length of latent vector
  • ngf - relates to the depth of feature maps carried through the generator
  • ndf - sets the depth of feature maps propagated through the discriminator
  • num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take much longer
  • lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002
  • beta1 - beta1 hyperparameter for Adam optimizers. As described in the paper, this number should be 0.5
  • ngpu - number of GPUs available. If this is 0, code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transform.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

In this tutorial we will use the Celeb-A Faces dataset which can be downloaded at the linked site, or in Google Drive. The dataset will download as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input for this notebook to the celeba directory you just created. The resulting directory structure should be:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset’s root folder. Now, we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.
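Because ImageFolder treats every subdirectory of dataroot as one class, a quick layout check can save a confusing error later. The helper below is a hypothetical sketch using only the standard library (it is not part of torchvision), and the extension list is an assumption chosen to cover the CelebA JPEGs.

```python
import os

# Assumed set of image extensions; torchvision's own list is longer.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def find_class_dirs(root):
    """Return the subdirectories of `root` that contain at least one image file.

    ImageFolder maps each such subdirectory to one class index, so an empty
    result means the images were extracted directly into `root`.
    """
    class_dirs = []
    for entry in sorted(os.listdir(root)):
        path = os.path.join(root, entry)
        if not os.path.isdir(path):
            continue  # loose files in the root are not classes
        if any(name.lower().endswith(IMAGE_EXTENSIONS) for name in os.listdir(path)):
            class_dirs.append(entry)
    return class_dirs
```

With the layout above, find_class_dirs("data/celeba") should return ["img_align_celeba"], so every image receives the same label; the training loop only uses the images (data[0]) anyway.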

# We can use an image folder dataset the way we have it set up.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
[Figure: Training Images]

Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization

From the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a Normal distribution with mean=0, stdev=0.02. The weights_init function takes a module as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion (for batch norm layers, the scale weights are drawn from a Normal distribution with mean=1, stdev=0.02, and the biases are set to 0). The function is applied to each model immediately after construction via Module.apply, which calls it on every submodule.

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
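weights_init chooses a branch by searching the module’s class name for a substring, and Module.apply calls it on every submodule as well as on the model itself. The plain-Python sketch below mirrors those two checks to show which modules in these models each branch catches; init_branch is a hypothetical helper, not part of the tutorial code.

```python
def init_branch(classname):
    """Mirror the substring checks in weights_init for a module class name."""
    if classname.find("Conv") != -1:
        return "conv"        # weight ~ N(0, 0.02)
    elif classname.find("BatchNorm") != -1:
        return "batchnorm"   # weight ~ N(1, 0.02), bias = 0
    return "untouched"       # weights_init does nothing for this module


# Class names that appear in the Generator and Discriminator, plus the containers.
module_names = ["ConvTranspose2d", "Conv2d", "BatchNorm2d", "ReLU", "LeakyReLU",
                "Tanh", "Sigmoid", "Sequential", "Generator", "Discriminator"]
for name in module_names:
    print(f"{name:16s} -> {init_branch(name)}")
```

Because the match is by substring, a custom module whose name contains "Conv" but that has no weight attribute of its own (for example, a hypothetical ConvBlock container) would make weights_init raise an AttributeError. An isinstance check against nn.Conv2d, nn.ConvTranspose2d, and nn.BatchNorm2d is a stricter alternative.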

Generator

The generator, \(G\), is designed to map the latent space vector (\(z\)) to data-space. Since our data are images, converting \(z\) to data-space means ultimately creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two-dimensional convolutional-transpose layers, each (except the last) paired with a 2d batch norm layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of \([-1,1]\). It is worth noting the batch norm layers after the conv-transpose layers, as they are a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. An image of the generator from the DCGAN paper is shown below.

[Figure: DCGAN generator architecture, from the DCGAN paper]
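The \([-1,1]\) data range comes from the dataset transforms: ToTensor scales pixel values to \([0,1]\), and Normalize with mean 0.5 and std 0.5 per channel computes \((x - 0.5)/0.5\). The per-value sketch below replays that arithmetic in plain Python (it is an illustration, not the torchvision implementation) and shows why a tanh output layer matches it.

```python
import math


def normalize(x, mean=0.5, std=0.5):
    """Per-channel transform applied by transforms.Normalize((0.5,)*3, (0.5,)*3)."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, e.g. to turn a generator output back into [0, 1] pixels."""
    return y * std + mean


# ToTensor's output range [0, 1] maps onto [-1, 1] ...
print(normalize(0.0), normalize(0.5), normalize(1.0))
# ... the same range tanh produces, so G's outputs are comparable to real images.
print(math.tanh(-10.0), math.tanh(10.0))
```

For display, vutils.make_grid(..., normalize=True) rescales each grid by its min and max values rather than applying this exact inverse.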

Notice how the inputs we set in the input section (nz, ngf, and nc) influence the generator architecture in code. nz is the length of the z input vector, ngf relates to the size of the feature maps that are propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
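The "state size" comments in the Generator follow from the ConvTranspose2d output-size formula in the PyTorch documentation, H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. The plain-Python sketch below (the helper names are hypothetical) replays that arithmetic for the five layers using the default nz, ngf, and nc.

```python
def conv_transpose_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    """Spatial output size of nn.ConvTranspose2d for a square input."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def generator_shapes(nz=100, ngf=64, nc=3):
    """(channels, height, width) after each ConvTranspose2d of the Generator."""
    # (out_channels, kernel_size, stride, padding), copied from the layers above
    layers = [(ngf * 8, 4, 1, 0), (ngf * 4, 4, 2, 1), (ngf * 2, 4, 2, 1),
              (ngf, 4, 2, 1), (nc, 4, 2, 1)]
    shapes = [(nz, 1, 1)]  # z enters as an nz x 1 x 1 volume
    size = 1
    for out_channels, kernel, stride, padding in layers:
        size = conv_transpose_out(size, kernel, stride, padding)
        shapes.append((out_channels, size, size))
    return shapes


for shape in generator_shapes():
    print(shape)
```

The first layer (stride 1, no padding) turns the 1x1 input into a 4x4 map, and each later stride-2 layer doubles the spatial size up to 64x64.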

Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

Out:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator

As mentioned, the discriminator, \(D\), is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake). Here, \(D\) takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but there is significance to the use of the strided convolution, BatchNorm, and LeakyReLUs. The DCGAN paper mentions it is a good practice to use strided convolution rather than pooling to downsample because it lets the network learn its own pooling function. Also batch norm and leaky relu functions promote healthy gradient flow which is critical for the learning process of both \(G\) and \(D\).

Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
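The discriminator’s shapes can be checked the same way with the Conv2d output-size formula from the PyTorch documentation, H_out = floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1). This is a plain-Python sketch with hypothetical helper names.

```python
def conv_out(size, kernel, stride, padding, dilation=1):
    """Spatial output size of nn.Conv2d for a square input."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def discriminator_shapes(nc=3, ndf=64, image_size=64):
    """(channels, height, width) after each Conv2d of the Discriminator."""
    # (out_channels, kernel_size, stride, padding), copied from the layers above
    layers = [(ndf, 4, 2, 1), (ndf * 2, 4, 2, 1), (ndf * 4, 4, 2, 1),
              (ndf * 8, 4, 2, 1), (1, 4, 1, 0)]
    shapes = [(nc, image_size, image_size)]
    size = image_size
    for out_channels, kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        shapes.append((out_channels, size, size))
    return shapes


print(discriminator_shapes())                  # ends in (1, 1, 1): one probability per image
print(discriminator_shapes(image_size=128)[-1])
```

With 128x128 inputs the final map would be 1x5x5, so netD(...).view(-1) would no longer line up with a label tensor of length b_size. This is one concrete reason why a different image_size requires changing the structures of D and G.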

Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model’s structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)

Out:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Loss Functions and Optimizers

With \(D\) and \(G\) set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function, which is defined in PyTorch as:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] \]

Notice how this function provides the calculation of both log components in the objective function (i.e. \(\log(D(x))\) and \(\log(1-D(G(z)))\)). We can specify which part of the BCE equation to use with the \(y\) input. This is accomplished in the training loop, which is coming up soon, but it is important to understand how we can choose which component we wish to calculate just by changing \(y\) (i.e. the ground-truth labels).
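A minimal pure-Python version of the formula makes the label trick concrete. It averages over the batch, as nn.BCELoss does with its default reduction="mean"; it omits the clamping of the log terms at -100 that PyTorch applies, and bce_loss is a hypothetical name used only for this sketch.

```python
import math


def bce_loss(x, y):
    """Mean binary cross entropy over a batch, following the BCELoss formula above.

    x: predicted probabilities in (0, 1); y: targets (1 = real, 0 = fake).
    """
    terms = [-(yn * math.log(xn) + (1 - yn) * math.log(1 - xn)) for xn, yn in zip(x, y)]
    return sum(terms) / len(terms)


d_out = [0.8, 0.6]
# With y = 1 only the log(x) term survives: loss = -mean(log D(.))
loss_as_real = bce_loss(d_out, [1, 1])
# With y = 0 only the log(1 - x) term survives: loss = -mean(log(1 - D(.)))
loss_as_fake = bce_loss(d_out, [0, 0])
print(loss_as_real, loss_as_fake)
```

Minimizing this loss maximizes the corresponding log term, which is how the training loop turns the maximization objectives into standard minimization.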

Next, we define our real label as 1 and the fake label as 0. These labels will be used when calculating the losses of \(D\) and \(G\), and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for \(D\) and one for \(G\). As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. To keep track of the generator’s learning progression, we will generate a fixed batch of 64 latent vectors drawn from a standard normal distribution (i.e. fixed_noise). In the training loop, we will periodically input this fixed_noise into \(G\), and over the iterations we will see images form out of the noise.

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

Finally, now that we have all of the parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings can lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow’s paper, while abiding by some of the best practices shown in ganhacks. Namely, we will “construct different mini-batches for real and fake” images, and also adjust G’s objective function to maximize \(\log D(G(z))\). Training is split into two main parts: Part 1 updates the Discriminator and Part 2 updates the Generator.

Part 1 - Train the Discriminator

Recall that the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In Goodfellow’s terms, we wish to “update the discriminator by ascending its stochastic gradient”. Practically, we want to maximize \(\log(D(x)) + \log(1-D(G(z)))\). Following the separate mini-batch suggestion from ganhacks, we will calculate this in two steps. First, we will construct a batch of real samples from the training set, forward pass it through \(D\), calculate the loss for the \(\log(D(x))\) term, then calculate the gradients in a backward pass. Second, we will construct a batch of fake samples with the current generator, forward pass this batch through \(D\), calculate the loss for the \(\log(1-D(G(z)))\) term, and accumulate the gradients with a backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the Discriminator’s optimizer.
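That accumulating two backward passes is equivalent to differentiating the combined loss follows from the linearity of differentiation. The pure-Python sketch below assumes a hypothetical toy logistic discriminator D(x) = sigmoid(w*x + b) in place of netD, with scalar stand-ins for the real and fake samples. It computes the real-batch and fake-batch gradients separately and sums them, which is what successive .backward() calls do to each parameter’s .grad.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def batch_loss(w, b, xs, label):
    """Mean BCE of the toy discriminator on a batch whose targets all equal `label`."""
    total = 0.0
    for x in xs:
        p = sigmoid(w * x + b)
        total -= label * math.log(p) + (1 - label) * math.log(1 - p)
    return total / len(xs)


def batch_grad(w, b, xs, label):
    """Analytic gradient of batch_loss w.r.t. (w, b); for BCE on a sigmoid, dL/da = p - y."""
    gw = gb = 0.0
    for x in xs:
        err = sigmoid(w * x + b) - label
        gw += err * x
        gb += err
    return gw / len(xs), gb / len(xs)


w, b = 0.3, -0.1
real_xs = [1.2, 0.8, 1.5]    # stand-ins for a batch of real samples
fake_xs = [-0.4, 0.1, -1.0]  # stand-ins for a batch of G(z)

# Two "backward passes" summed into one accumulator, like .grad in PyTorch.
grad_w, grad_b = 0.0, 0.0
for xs, label in ((real_xs, 1.0), (fake_xs, 0.0)):
    gw, gb = batch_grad(w, b, xs, label)
    grad_w += gw
    grad_b += gb
print(grad_w, grad_b)
```

The accumulated gradient equals the gradient of errD_real + errD_fake, so calling optimizerD.step() after both backward passes updates D exactly as a single backward on the sum would. In the training loop, errD itself is computed only for reporting.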

Part 2 - Train the Generator

As stated in the original paper, we want to train the Generator by minimizing \(\log(1-D(G(z)))\) in an effort to generate better fakes. However, Goodfellow showed that this objective does not provide sufficient gradients, especially early in the learning process. As a fix, we instead maximize \(\log(D(G(z)))\). In the code we accomplish this by classifying the Generator output from Part 1 with the Discriminator, computing G’s loss using real labels as ground truth, computing G’s gradients in a backward pass, and finally updating G’s parameters with an optimizer step. It may seem counter-intuitive to use the real labels as ground-truth labels for the loss function, but this allows us to use the \(\log(x)\) part of the BCELoss (rather than the \(\log(1-x)\) part), which is exactly what we want.
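To see why the original objective saturates, write \(D(G(z)) = \mathrm{sigmoid}(a)\) and differentiate both generator objectives with respect to the discriminator’s pre-sigmoid output \(a\). This sketch only looks at that last step (an assumption for simplicity; the chain rule then carries the factor back into G’s parameters), and generator_grads is a hypothetical helper.

```python
def generator_grads(d_fake):
    """Gradient of each generator objective w.r.t. the logit a, where D(G(z)) = sigmoid(a).

    saturating:     minimize  log(1 - sigmoid(a))  ->  d/da = -sigmoid(a)     = -d_fake
    non-saturating: minimize -log(sigmoid(a))      ->  d/da = sigmoid(a) - 1  = d_fake - 1
    """
    return -d_fake, d_fake - 1.0


# Early in training D easily rejects fakes, so D(G(z)) is close to 0.
for d in (0.001, 0.01, 0.1, 0.5):
    sat, non_sat = generator_grads(d)
    print(f"D(G(z))={d:<6} saturating={sat:+.3f} non-saturating={non_sat:+.3f}")
```

When \(D(G(z))\) is near 0, the saturating gradient shrinks in proportion to \(D(G(z))\), while the non-saturating one stays near -1, which is the motivation for the real-label trick.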

Finally, we will do some statistics reporting, and every 500 iterations (plus once at the very last iteration) we will push our fixed_noise batch through the generator to visually track the progress of G’s training. The training statistics reported are:

  • Loss_D - discriminator loss calculated as the sum of the BCE losses for the all-real and all-fake batches (\(-[\log(D(x)) + \log(1 - D(G(z)))]\), each term averaged over its batch).
  • Loss_G - generator loss calculated as \(-\log(D(G(z)))\)
  • D(x) - the average output (across the batch) of the discriminator for the all real batch. This should start close to 1 then theoretically converge to 0.5 when G gets better. Think about why this is.
  • D(G(z)) - average discriminator outputs for the all fake batch. The first number is before D is updated and the second number is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.
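As a reference point for reading the log, the sketch below computes the reported losses at the theoretical equilibrium, where \(D\) outputs 0.5 for every input. reported_losses is a hypothetical helper; for simplicity it uses the same \(D(G(z))\) values for both losses, whereas the loop computes Loss_G from D’s output after D has been updated.

```python
import math


def _mean(values):
    return sum(values) / len(values)


def reported_losses(d_real, d_fake):
    """Loss_D and Loss_G as mean BCE values.

    d_real: D(x) on a real batch; d_fake: D(G(z)) on a fake batch.
    """
    loss_d = -_mean([math.log(p) for p in d_real]) - _mean([math.log(1 - p) for p in d_fake])
    loss_g = -_mean([math.log(p) for p in d_fake])
    return loss_d, loss_g


# At the theoretical equilibrium D outputs 0.5 for every input.
loss_d, loss_g = reported_losses([0.5] * 8, [0.5] * 8)
print(f"Loss_D={loss_d:.4f} Loss_G={loss_g:.4f}")
```

This prints Loss_D=1.3863 Loss_G=0.6931, i.e. \(2\log 2\) and \(\log 2\). In practice, training does not necessarily settle at this point.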

Note: This step might take a while, depending on how many epochs you run and if you removed some data from the dataset.
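How long the loop runs follows from the dataset size. The sketch below assumes the full aligned CelebA set of 202,599 images (an assumption; this tutorial does not state the count) and the DataLoader default drop_last=False. It also replays the loop’s condition for saving G(fixed_noise) snapshots. Both helper names are hypothetical.

```python
import math


def batches_per_epoch(num_images, batch_size, drop_last=False):
    """Length of a DataLoader over `num_images` samples."""
    if drop_last:
        return num_images // batch_size
    return math.ceil(num_images / batch_size)


def snapshot_iters(num_epochs, per_epoch, every=500):
    """Iterations at which the training loop appends G(fixed_noise) to img_list."""
    snaps = []
    for iters in range(num_epochs * per_epoch):
        epoch, i = divmod(iters, per_epoch)
        if iters % every == 0 or (epoch == num_epochs - 1 and i == per_epoch - 1):
            snaps.append(iters)
    return snaps


num_images = 202599  # assumed size of the aligned CelebA image set
per_epoch = batches_per_epoch(num_images, 128)
last_batch = num_images - (per_epoch - 1) * 128
snaps = snapshot_iters(5, per_epoch)
print(per_epoch, last_batch, len(snaps))
```

This gives 1583 batches per epoch (matching the [x/1583] counter in the log), a final batch of only 103 images (which is why the loop reads b_size from each batch instead of using batch_size), and 17 snapshots over 5 epochs.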

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Out:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 2.0036  Loss_G: 5.6241  D(x): 0.5882    D(G(z)): 0.6680 / 0.0059
[0/5][50/1583]  Loss_D: 0.1797  Loss_G: 12.4328 D(x): 0.9258    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.4256  Loss_G: 6.0486  D(x): 0.9191    D(G(z)): 0.1834 / 0.0068
[0/5][150/1583] Loss_D: 0.4998  Loss_G: 6.0498  D(x): 0.9026    D(G(z)): 0.2468 / 0.0080
[0/5][200/1583] Loss_D: 0.4657  Loss_G: 5.4430  D(x): 0.9495    D(G(z)): 0.2902 / 0.0081
[0/5][250/1583] Loss_D: 0.9682  Loss_G: 10.7051 D(x): 0.9145    D(G(z)): 0.4947 / 0.0001
[0/5][300/1583] Loss_D: 0.7617  Loss_G: 4.7044  D(x): 0.5641    D(G(z)): 0.0201 / 0.0246
[0/5][350/1583] Loss_D: 0.4248  Loss_G: 4.9143  D(x): 0.8529    D(G(z)): 0.1904 / 0.0137
[0/5][400/1583] Loss_D: 0.8801  Loss_G: 2.6868  D(x): 0.5446    D(G(z)): 0.0301 / 0.1184
[0/5][450/1583] Loss_D: 0.2093  Loss_G: 4.0865  D(x): 0.9036    D(G(z)): 0.0669 / 0.0364
[0/5][500/1583] Loss_D: 1.1730  Loss_G: 1.6556  D(x): 0.4904    D(G(z)): 0.1328 / 0.2834
[0/5][550/1583] Loss_D: 0.4866  Loss_G: 6.3008  D(x): 0.8779    D(G(z)): 0.2422 / 0.0054
[0/5][600/1583] Loss_D: 0.5533  Loss_G: 4.7421  D(x): 0.6743    D(G(z)): 0.0240 / 0.0212
[0/5][650/1583] Loss_D: 0.4075  Loss_G: 4.1660  D(x): 0.7950    D(G(z)): 0.1076 / 0.0393
[0/5][700/1583] Loss_D: 0.8426  Loss_G: 12.1846 D(x): 0.9453    D(G(z)): 0.4783 / 0.0001
[0/5][750/1583] Loss_D: 0.5913  Loss_G: 2.7238  D(x): 0.6840    D(G(z)): 0.0823 / 0.1105
[0/5][800/1583] Loss_D: 0.5156  Loss_G: 3.2862  D(x): 0.6760    D(G(z)): 0.0390 / 0.0691
[0/5][850/1583] Loss_D: 0.4172  Loss_G: 8.1879  D(x): 0.9791    D(G(z)): 0.2891 / 0.0009
[0/5][900/1583] Loss_D: 1.5890  Loss_G: 6.1767  D(x): 0.9303    D(G(z)): 0.6324 / 0.0085
[0/5][950/1583] Loss_D: 0.2808  Loss_G: 6.2957  D(x): 0.9032    D(G(z)): 0.1355 / 0.0056
[0/5][1000/1583]        Loss_D: 0.2561  Loss_G: 5.6035  D(x): 0.9548    D(G(z)): 0.1656 / 0.0074
[0/5][1050/1583]        Loss_D: 0.3673  Loss_G: 3.8559  D(x): 0.7903    D(G(z)): 0.0567 / 0.0394
[0/5][1100/1583]        Loss_D: 0.8489  Loss_G: 5.6767  D(x): 0.8093    D(G(z)): 0.3668 / 0.0084
[0/5][1150/1583]        Loss_D: 0.5824  Loss_G: 4.7859  D(x): 0.8767    D(G(z)): 0.2963 / 0.0158
[0/5][1200/1583]        Loss_D: 0.5125  Loss_G: 4.1138  D(x): 0.7164    D(G(z)): 0.0559 / 0.0376
[0/5][1250/1583]        Loss_D: 0.6284  Loss_G: 5.7596  D(x): 0.8330    D(G(z)): 0.2847 / 0.0070
[0/5][1300/1583]        Loss_D: 0.4273  Loss_G: 5.4463  D(x): 0.9060    D(G(z)): 0.2257 / 0.0101
[0/5][1350/1583]        Loss_D: 0.7720  Loss_G: 2.2033  D(x): 0.6093    D(G(z)): 0.0651 / 0.1719
[0/5][1400/1583]        Loss_D: 0.6578  Loss_G: 3.7349  D(x): 0.7482    D(G(z)): 0.2184 / 0.0387
[0/5][1450/1583]        Loss_D: 0.4868  Loss_G: 3.7652  D(x): 0.7639    D(G(z)): 0.1184 / 0.0413
[0/5][1500/1583]        Loss_D: 0.3773  Loss_G: 4.5281  D(x): 0.7796    D(G(z)): 0.0465 / 0.0255
[0/5][1550/1583]        Loss_D: 0.3768  Loss_G: 3.3804  D(x): 0.8373    D(G(z)): 0.1334 / 0.0549
[1/5][0/1583]   Loss_D: 0.6207  Loss_G: 3.3850  D(x): 0.7411    D(G(z)): 0.1859 / 0.0544
[1/5][50/1583]  Loss_D: 0.6158  Loss_G: 3.5835  D(x): 0.7484    D(G(z)): 0.1848 / 0.0534
[1/5][100/1583] Loss_D: 0.5375  Loss_G: 4.5179  D(x): 0.8992    D(G(z)): 0.3112 / 0.0171
[1/5][150/1583] Loss_D: 0.8392  Loss_G: 0.8935  D(x): 0.5613    D(G(z)): 0.0610 / 0.4846
[1/5][200/1583] Loss_D: 0.4394  Loss_G: 2.8901  D(x): 0.7625    D(G(z)): 0.0926 / 0.1052
[1/5][250/1583] Loss_D: 0.2638  Loss_G: 4.3459  D(x): 0.8909    D(G(z)): 0.1168 / 0.0270
[1/5][300/1583] Loss_D: 0.8095  Loss_G: 5.9970  D(x): 0.9175    D(G(z)): 0.4576 / 0.0055
[1/5][350/1583] Loss_D: 0.3966  Loss_G: 4.8993  D(x): 0.9123    D(G(z)): 0.2255 / 0.0134
[1/5][400/1583] Loss_D: 0.7719  Loss_G: 1.9912  D(x): 0.6141    D(G(z)): 0.1040 / 0.1868
[1/5][450/1583] Loss_D: 0.3461  Loss_G: 3.9002  D(x): 0.8587    D(G(z)): 0.1492 / 0.0334
[1/5][500/1583] Loss_D: 0.3796  Loss_G: 2.7830  D(x): 0.7828    D(G(z)): 0.0688 / 0.0932
[1/5][550/1583] Loss_D: 0.5169  Loss_G: 4.2899  D(x): 0.9182    D(G(z)): 0.3163 / 0.0208
[1/5][600/1583] Loss_D: 0.5274  Loss_G: 5.4805  D(x): 0.9095    D(G(z)): 0.2982 / 0.0074
[1/5][650/1583] Loss_D: 0.4975  Loss_G: 4.2649  D(x): 0.8526    D(G(z)): 0.2341 / 0.0230
[1/5][700/1583] Loss_D: 0.5619  Loss_G: 3.7362  D(x): 0.9121    D(G(z)): 0.3257 / 0.0391
[1/5][750/1583] Loss_D: 0.8639  Loss_G: 1.5225  D(x): 0.5237    D(G(z)): 0.0603 / 0.2767
[1/5][800/1583] Loss_D: 0.9482  Loss_G: 6.4839  D(x): 0.9094    D(G(z)): 0.4814 / 0.0039
[1/5][850/1583] Loss_D: 0.3307  Loss_G: 2.9769  D(x): 0.7966    D(G(z)): 0.0597 / 0.0731
[1/5][900/1583] Loss_D: 0.8658  Loss_G: 5.2396  D(x): 0.9707    D(G(z)): 0.4620 / 0.0152
[1/5][950/1583] Loss_D: 0.4835  Loss_G: 4.1915  D(x): 0.8662    D(G(z)): 0.2584 / 0.0219
[1/5][1000/1583]        Loss_D: 1.1729  Loss_G: 2.9327  D(x): 0.4329    D(G(z)): 0.0190 / 0.1071
[1/5][1050/1583]        Loss_D: 0.7550  Loss_G: 2.2552  D(x): 0.5589    D(G(z)): 0.0310 / 0.1438
[1/5][1100/1583]        Loss_D: 0.5929  Loss_G: 4.1725  D(x): 0.8967    D(G(z)): 0.3368 / 0.0265
[1/5][1150/1583]        Loss_D: 0.5670  Loss_G: 2.1012  D(x): 0.6633    D(G(z)): 0.0746 / 0.1587
[1/5][1200/1583]        Loss_D: 0.6381  Loss_G: 5.3362  D(x): 0.9125    D(G(z)): 0.3711 / 0.0073
[1/5][1250/1583]        Loss_D: 0.5557  Loss_G: 3.7354  D(x): 0.8108    D(G(z)): 0.2505 / 0.0360
[1/5][1300/1583]        Loss_D: 0.4166  Loss_G: 2.6690  D(x): 0.7816    D(G(z)): 0.1173 / 0.0941
[1/5][1350/1583]        Loss_D: 0.5063  Loss_G: 2.8805  D(x): 0.7929    D(G(z)): 0.1969 / 0.0809
[1/5][1400/1583]        Loss_D: 0.9375  Loss_G: 3.6970  D(x): 0.9106    D(G(z)): 0.4750 / 0.0511
[1/5][1450/1583]        Loss_D: 0.3749  Loss_G: 3.6325  D(x): 0.8832    D(G(z)): 0.1969 / 0.0385
[1/5][1500/1583]        Loss_D: 0.6583  Loss_G: 2.5942  D(x): 0.7845    D(G(z)): 0.2794 / 0.1039
[1/5][1550/1583]        Loss_D: 1.5500  Loss_G: 5.4623  D(x): 0.9664    D(G(z)): 0.6994 / 0.0080
[2/5][0/1583]   Loss_D: 0.3586  Loss_G: 3.2352  D(x): 0.8119    D(G(z)): 0.1101 / 0.0616
[2/5][50/1583]  Loss_D: 0.5593  Loss_G: 3.8964  D(x): 0.8711    D(G(z)): 0.2995 / 0.0318
[2/5][100/1583] Loss_D: 0.9296  Loss_G: 4.8743  D(x): 0.9362    D(G(z)): 0.5245 / 0.0125
[2/5][150/1583] Loss_D: 0.5999  Loss_G: 1.9848  D(x): 0.6603    D(G(z)): 0.0988 / 0.1884
[2/5][200/1583] Loss_D: 0.5168  Loss_G: 2.3519  D(x): 0.7875    D(G(z)): 0.1903 / 0.1293
[2/5][250/1583] Loss_D: 0.6811  Loss_G: 2.3089  D(x): 0.7496    D(G(z)): 0.2697 / 0.1339
[2/5][300/1583] Loss_D: 0.6645  Loss_G: 2.5562  D(x): 0.7672    D(G(z)): 0.2812 / 0.1029
[2/5][350/1583] Loss_D: 0.6678  Loss_G: 1.4111  D(x): 0.6518    D(G(z)): 0.1496 / 0.2927
[2/5][400/1583] Loss_D: 0.8037  Loss_G: 4.0916  D(x): 0.8787    D(G(z)): 0.4302 / 0.0264
[2/5][450/1583] Loss_D: 0.6399  Loss_G: 3.1521  D(x): 0.8343    D(G(z)): 0.3236 / 0.0616
[2/5][500/1583] Loss_D: 0.5165  Loss_G: 2.1854  D(x): 0.6904    D(G(z)): 0.0832 / 0.1535
[2/5][550/1583] Loss_D: 0.6232  Loss_G: 4.6803  D(x): 0.9312    D(G(z)): 0.3831 / 0.0153
[2/5][600/1583] Loss_D: 0.6268  Loss_G: 2.1716  D(x): 0.6753    D(G(z)): 0.1426 / 0.1571
[2/5][650/1583] Loss_D: 0.6157  Loss_G: 1.7105  D(x): 0.7040    D(G(z)): 0.1871 / 0.2270
[2/5][700/1583] Loss_D: 0.6188  Loss_G: 1.8903  D(x): 0.6810    D(G(z)): 0.1556 / 0.1789
[2/5][750/1583] Loss_D: 0.6617  Loss_G: 2.3828  D(x): 0.7774    D(G(z)): 0.2735 / 0.1377
[2/5][800/1583] Loss_D: 0.5570  Loss_G: 3.0554  D(x): 0.8652    D(G(z)): 0.3040 / 0.0625
[2/5][850/1583] Loss_D: 0.3772  Loss_G: 2.3553  D(x): 0.8150    D(G(z)): 0.1402 / 0.1238
[2/5][900/1583] Loss_D: 1.9095  Loss_G: 6.0321  D(x): 0.9360    D(G(z)): 0.7515 / 0.0060
[2/5][950/1583] Loss_D: 0.5147  Loss_G: 2.0359  D(x): 0.7689    D(G(z)): 0.1853 / 0.1635
[2/5][1000/1583]        Loss_D: 0.7283  Loss_G: 2.0252  D(x): 0.5804    D(G(z)): 0.0771 / 0.1752
[2/5][1050/1583]        Loss_D: 0.8640  Loss_G: 3.8369  D(x): 0.8103    D(G(z)): 0.4249 / 0.0327
[2/5][1100/1583]        Loss_D: 0.7113  Loss_G: 3.7684  D(x): 0.9077    D(G(z)): 0.4257 / 0.0307
[2/5][1150/1583]        Loss_D: 0.7378  Loss_G: 2.8348  D(x): 0.7623    D(G(z)): 0.3230 / 0.0702
[2/5][1200/1583]        Loss_D: 0.4492  Loss_G: 2.2788  D(x): 0.7369    D(G(z)): 0.0937 / 0.1345
[2/5][1250/1583]        Loss_D: 0.6655  Loss_G: 3.1927  D(x): 0.8753    D(G(z)): 0.3687 / 0.0523
[2/5][1300/1583]        Loss_D: 0.8506  Loss_G: 1.5262  D(x): 0.5269    D(G(z)): 0.0791 / 0.2679
[2/5][1350/1583]        Loss_D: 0.8017  Loss_G: 2.9167  D(x): 0.8723    D(G(z)): 0.4338 / 0.0796
[2/5][1400/1583]        Loss_D: 0.5613  Loss_G: 2.9290  D(x): 0.8254    D(G(z)): 0.2757 / 0.0693
[2/5][1450/1583]        Loss_D: 0.4857  Loss_G: 2.6285  D(x): 0.8080    D(G(z)): 0.2086 / 0.0918
[2/5][1500/1583]        Loss_D: 1.0892  Loss_G: 4.5390  D(x): 0.9515    D(G(z)): 0.5889 / 0.0184
[2/5][1550/1583]        Loss_D: 0.5287  Loss_G: 2.8621  D(x): 0.8721    D(G(z)): 0.2951 / 0.0752
[3/5][0/1583]   Loss_D: 0.7854  Loss_G: 1.2214  D(x): 0.5433    D(G(z)): 0.0867 / 0.3554
[3/5][50/1583]  Loss_D: 0.6804  Loss_G: 1.2959  D(x): 0.6162    D(G(z)): 0.1239 / 0.3190
[3/5][100/1583] Loss_D: 0.5261  Loss_G: 1.9995  D(x): 0.7277    D(G(z)): 0.1501 / 0.1652
[3/5][150/1583] Loss_D: 0.7594  Loss_G: 3.4410  D(x): 0.9255    D(G(z)): 0.4435 / 0.0447
[3/5][200/1583] Loss_D: 0.8398  Loss_G: 3.5978  D(x): 0.9173    D(G(z)): 0.4847 / 0.0361
[3/5][250/1583] Loss_D: 0.7407  Loss_G: 1.7323  D(x): 0.5702    D(G(z)): 0.0892 / 0.2244
[3/5][300/1583] Loss_D: 0.7091  Loss_G: 2.0202  D(x): 0.7278    D(G(z)): 0.2697 / 0.1701
[3/5][350/1583] Loss_D: 0.7962  Loss_G: 2.7491  D(x): 0.8309    D(G(z)): 0.4023 / 0.0916
[3/5][400/1583] Loss_D: 0.8029  Loss_G: 1.4822  D(x): 0.5253    D(G(z)): 0.0333 / 0.2968
[3/5][450/1583] Loss_D: 0.6752  Loss_G: 3.5632  D(x): 0.8454    D(G(z)): 0.3590 / 0.0400
[3/5][500/1583] Loss_D: 0.5680  Loss_G: 2.6532  D(x): 0.7980    D(G(z)): 0.2582 / 0.0916
[3/5][550/1583] Loss_D: 0.9276  Loss_G: 3.2313  D(x): 0.7829    D(G(z)): 0.4334 / 0.0622
[3/5][600/1583] Loss_D: 1.3222  Loss_G: 0.5569  D(x): 0.3844    D(G(z)): 0.1387 / 0.6091
[3/5][650/1583] Loss_D: 0.4432  Loss_G: 2.2500  D(x): 0.7975    D(G(z)): 0.1728 / 0.1364
[3/5][700/1583] Loss_D: 0.5341  Loss_G: 1.7943  D(x): 0.6707    D(G(z)): 0.0861 / 0.2031
[3/5][750/1583] Loss_D: 0.6346  Loss_G: 2.2814  D(x): 0.6629    D(G(z)): 0.1548 / 0.1370
[3/5][800/1583] Loss_D: 0.7141  Loss_G: 3.3863  D(x): 0.8912    D(G(z)): 0.4155 / 0.0493
[3/5][850/1583] Loss_D: 0.7136  Loss_G: 3.3627  D(x): 0.8936    D(G(z)): 0.4107 / 0.0498
[3/5][900/1583] Loss_D: 1.0759  Loss_G: 5.1828  D(x): 0.9595    D(G(z)): 0.5910 / 0.0077
[3/5][950/1583] Loss_D: 0.5146  Loss_G: 2.4333  D(x): 0.7650    D(G(z)): 0.1856 / 0.1130
[3/5][1000/1583]        Loss_D: 1.0340  Loss_G: 1.0067  D(x): 0.4361    D(G(z)): 0.0365 / 0.4171
[3/5][1050/1583]        Loss_D: 0.9027  Loss_G: 1.5421  D(x): 0.6292    D(G(z)): 0.2704 / 0.2686
[3/5][1100/1583]        Loss_D: 0.7330  Loss_G: 1.0975  D(x): 0.5590    D(G(z)): 0.0727 / 0.3822
[3/5][1150/1583]        Loss_D: 0.5788  Loss_G: 2.0302  D(x): 0.7579    D(G(z)): 0.2249 / 0.1619
[3/5][1200/1583]        Loss_D: 0.5699  Loss_G: 2.5464  D(x): 0.6647    D(G(z)): 0.0939 / 0.1174
[3/5][1250/1583]        Loss_D: 0.7822  Loss_G: 3.1897  D(x): 0.8214    D(G(z)): 0.3932 / 0.0600
[3/5][1300/1583]        Loss_D: 0.5714  Loss_G: 3.0932  D(x): 0.8428    D(G(z)): 0.2973 / 0.0599
[3/5][1350/1583]        Loss_D: 0.9439  Loss_G: 1.6137  D(x): 0.4569    D(G(z)): 0.0554 / 0.2486
[3/5][1400/1583]        Loss_D: 0.9508  Loss_G: 0.7944  D(x): 0.4616    D(G(z)): 0.0664 / 0.4980
[3/5][1450/1583]        Loss_D: 0.5473  Loss_G: 2.0525  D(x): 0.7264    D(G(z)): 0.1681 / 0.1564
[3/5][1500/1583]        Loss_D: 1.0865  Loss_G: 4.6609  D(x): 0.9428    D(G(z)): 0.5838 / 0.0134
[3/5][1550/1583]        Loss_D: 0.5538  Loss_G: 2.7744  D(x): 0.8313    D(G(z)): 0.2780 / 0.0804
[4/5][0/1583]   Loss_D: 0.5295  Loss_G: 1.7297  D(x): 0.7407    D(G(z)): 0.1699 / 0.2110
[4/5][50/1583]  Loss_D: 0.8767  Loss_G: 3.9897  D(x): 0.8984    D(G(z)): 0.4903 / 0.0265
[4/5][100/1583] Loss_D: 0.8595  Loss_G: 1.7889  D(x): 0.5127    D(G(z)): 0.0766 / 0.2100
[4/5][150/1583] Loss_D: 0.6354  Loss_G: 1.9135  D(x): 0.6224    D(G(z)): 0.0967 / 0.1885
[4/5][200/1583] Loss_D: 0.6871  Loss_G: 2.6436  D(x): 0.8059    D(G(z)): 0.3389 / 0.0905
[4/5][250/1583] Loss_D: 0.6317  Loss_G: 2.1678  D(x): 0.7464    D(G(z)): 0.2520 / 0.1415
[4/5][300/1583] Loss_D: 0.6106  Loss_G: 2.3683  D(x): 0.7178    D(G(z)): 0.1997 / 0.1194
[4/5][350/1583] Loss_D: 1.3298  Loss_G: 3.7512  D(x): 0.9102    D(G(z)): 0.6503 / 0.0364
[4/5][400/1583] Loss_D: 0.6001  Loss_G: 2.2635  D(x): 0.7665    D(G(z)): 0.2382 / 0.1347
[4/5][450/1583] Loss_D: 0.6210  Loss_G: 1.5374  D(x): 0.6271    D(G(z)): 0.0913 / 0.2607
[4/5][500/1583] Loss_D: 0.6877  Loss_G: 3.1567  D(x): 0.8000    D(G(z)): 0.3260 / 0.0566
[4/5][550/1583] Loss_D: 0.9178  Loss_G: 0.4890  D(x): 0.4992    D(G(z)): 0.0657 / 0.6494
[4/5][600/1583] Loss_D: 0.7856  Loss_G: 3.4874  D(x): 0.7539    D(G(z)): 0.3523 / 0.0413
[4/5][650/1583] Loss_D: 0.4173  Loss_G: 2.4808  D(x): 0.7910    D(G(z)): 0.1448 / 0.1112
[4/5][700/1583] Loss_D: 0.5501  Loss_G: 2.3553  D(x): 0.7228    D(G(z)): 0.1559 / 0.1238
[4/5][750/1583] Loss_D: 0.9603  Loss_G: 1.0889  D(x): 0.4580    D(G(z)): 0.0516 / 0.3987
[4/5][800/1583] Loss_D: 1.0103  Loss_G: 1.0796  D(x): 0.4518    D(G(z)): 0.0700 / 0.3848
[4/5][850/1583] Loss_D: 0.9407  Loss_G: 1.1503  D(x): 0.5190    D(G(z)): 0.1183 / 0.3930
[4/5][900/1583] Loss_D: 0.6042  Loss_G: 3.2949  D(x): 0.8617    D(G(z)): 0.3398 / 0.0485
[4/5][950/1583] Loss_D: 0.8044  Loss_G: 4.0215  D(x): 0.9126    D(G(z)): 0.4697 / 0.0273
[4/5][1000/1583]        Loss_D: 0.5038  Loss_G: 2.0548  D(x): 0.8618    D(G(z)): 0.2604 / 0.1715
[4/5][1050/1583]        Loss_D: 0.5536  Loss_G: 3.5237  D(x): 0.8584    D(G(z)): 0.3012 / 0.0407
[4/5][1100/1583]        Loss_D: 0.6723  Loss_G: 2.6173  D(x): 0.7372    D(G(z)): 0.2591 / 0.1056
[4/5][1150/1583]        Loss_D: 0.5664  Loss_G: 2.2516  D(x): 0.7457    D(G(z)): 0.1978 / 0.1436
[4/5][1200/1583]        Loss_D: 0.4951  Loss_G: 1.9954  D(x): 0.7180    D(G(z)): 0.1165 / 0.1719
[4/5][1250/1583]        Loss_D: 0.8713  Loss_G: 3.4892  D(x): 0.9282    D(G(z)): 0.4974 / 0.0435
[4/5][1300/1583]        Loss_D: 0.4074  Loss_G: 2.8002  D(x): 0.8255    D(G(z)): 0.1672 / 0.0835
[4/5][1350/1583]        Loss_D: 0.4612  Loss_G: 2.9137  D(x): 0.7708    D(G(z)): 0.1513 / 0.0754
[4/5][1400/1583]        Loss_D: 0.5805  Loss_G: 2.3746  D(x): 0.6993    D(G(z)): 0.1554 / 0.1186
[4/5][1450/1583]        Loss_D: 0.7233  Loss_G: 1.2625  D(x): 0.6116    D(G(z)): 0.1500 / 0.3253
[4/5][1500/1583]        Loss_D: 0.6058  Loss_G: 1.9010  D(x): 0.7431    D(G(z)): 0.2253 / 0.1869
[4/5][1550/1583]        Loss_D: 1.0142  Loss_G: 0.7959  D(x): 0.4979    D(G(z)): 0.1639 / 0.4867

Results

Finally, let's check out how we did. We will look at three different results. First, we will see how D's and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch at the checkpoints saved during training. Third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D & G’s losses versus training iterations.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
[Figure: Generator and Discriminator Loss During Training]
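As a rough reference point for reading this plot: at the theoretical equilibrium described at the start of the tutorial, D outputs 0.5 for every input. The sketch below computes the losses at that point, assuming Loss_D is the sum of the binary cross-entropy terms on real and fake batches and Loss_G is the binary cross-entropy of D(G(z)) against the "real" label, as in the training loop. The helper names discriminator_loss and generator_loss are illustrative only.

```python
import math


# Assumption: Loss_D = -log D(x) - log(1 - D(G(z))), i.e. the real-batch and
# fake-batch BCE terms added together, evaluated here for a single probability.
def discriminator_loss(d_real: float, d_fake: float) -> float:
    return -math.log(d_real) - math.log(1.0 - d_fake)


# Assumption: Loss_G = -log D(G(z)), the BCE of D's output on fakes vs. label 1.
def generator_loss(d_fake: float) -> float:
    return -math.log(d_fake)


# Idealized equilibrium: D outputs 0.5 for real and fake images alike.
loss_d_eq = discriminator_loss(0.5, 0.5)  # equals 2 * ln 2
loss_g_eq = generator_loss(0.5)           # equals ln 2

print(f"Loss_D at equilibrium: {loss_d_eq:.4f}")  # Loss_D at equilibrium: 1.3863
print(f"Loss_G at equilibrium: {loss_g_eq:.4f}")  # Loss_G at equilibrium: 0.6931
```

In practice the curves oscillate rather than settling on these values; in the log above, Loss_D ranges from roughly 0.33 to 1.91 as D and G take turns gaining the upper hand.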

Visualization of G’s progression

Recall that during training we periodically saved the generator's output on the fixed_noise batch in img_list. Now we can visualize G's training progression with an animation. Press the play button to start it.

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
[Figure: Animation of G's output on fixed_noise over the course of training]
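The list comprehension above calls np.transpose(i,(1,2,0)) on each entry of img_list because PyTorch image tensors (including the grids built by vutils.make_grid) use a (channels, height, width) layout, while matplotlib's imshow expects (height, width, channels). The sketch below shows the axis reordering on a small NumPy array; grid_chw is a hypothetical stand-in for one img_list entry, with distinct axis sizes so it is easy to see which axis moves where.

```python
import numpy as np

# Hypothetical stand-in for one img_list entry: a (channels, height, width) array.
grid_chw = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)

# imshow wants (height, width, channels): move axis 0 (channels) to the end
# while keeping height (axis 1) and width (axis 2) in their original order.
grid_hwc = np.transpose(grid_chw, (1, 2, 0))

print(grid_chw.shape, "->", grid_hwc.shape)  # (3, 4, 5) -> (4, 5, 3)

# The pixel at row 2, column 3 carries the same three channel values in both layouts.
assert np.array_equal(grid_chw[:, 2, 3], grid_hwc[2, 3, :])
```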

Real Images vs. Fake Images

Lastly, let's look at some real images and fake images side by side.

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
[Figure: Real Images (left) vs. Fake Images (right)]
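The real batch is passed through vutils.make_grid(..., normalize=True) before plotting. With no value_range given, torchvision documents that this shifts the grid into [0, 1] using the min and max computed from the tensor itself (over the whole batch, unless scale_each=True). The sketch below reproduces that arithmetic in NumPy; it is an illustration, not torchvision's code, and minmax_rescale is a hypothetical helper name. The sample values assume pixels in [-1, 1], the range of a tanh output.

```python
import numpy as np


def minmax_rescale(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Linearly map x so its minimum becomes 0 and its maximum becomes 1.

    Sketch of the rescaling make_grid performs when normalize=True and no
    value_range is given; eps avoids dividing by zero for a constant image.
    """
    low, high = float(x.min()), float(x.max())
    clipped = np.clip(x, low, high)
    return (clipped - low) / max(high - low, eps)


# Hypothetical pixel values spanning [-1, 1].
pixels = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
scaled = minmax_rescale(pixels)  # -1 maps to 0.0, 0 maps to 0.5, 1 maps to 1.0
```

Because the min and max come from the data, a batch that does not span the full [-1, 1] range is stretched to fill [0, 1], so displayed contrast can differ slightly from the raw values.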

Where to Go Next

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how good the results get
  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture
  • Check out some other cool GAN projects here
  • Create GANs that generate music
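If you change the image size, the number of upsampling layers in G has to change with it. The sketch below traces the feature-map size through a DCGAN-style generator using the output-size formula from the nn.ConvTranspose2d documentation. It assumes the layer settings used earlier in this tutorial: a first layer with kernel 4, stride 1, padding 0 that projects the 1x1 latent vector to 4x4, followed by layers with kernel 4, stride 2, padding 1, each of which doubles the spatial size. The function names are hypothetical, and a matching change to D (one more strided convolution) would likely also be needed.

```python
def conv_transpose2d_out(size_in: int, kernel_size: int, stride: int = 1,
                         padding: int = 0, output_padding: int = 0,
                         dilation: int = 1) -> int:
    """Spatial output size of nn.ConvTranspose2d, per the PyTorch docs formula."""
    return ((size_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def generator_sizes(num_upsampling_layers: int) -> list:
    """Feature-map sizes through an assumed DCGAN-style generator layout."""
    # First layer: 1x1 latent vector -> 4x4 (kernel 4, stride 1, padding 0).
    sizes = [conv_transpose2d_out(1, kernel_size=4, stride=1, padding=0)]
    # Each further layer (kernel 4, stride 2, padding 1) doubles the size.
    for _ in range(num_upsampling_layers):
        sizes.append(conv_transpose2d_out(sizes[-1], kernel_size=4,
                                          stride=2, padding=1))
    return sizes


print(generator_sizes(4))  # [4, 8, 16, 32, 64]
print(generator_sizes(5))  # [4, 8, 16, 32, 64, 128]
```

With four upsampling layers the output is 64x64, matching this tutorial's images; one more layer of the same kind would produce 128x128.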

Total running time of the script: (28 minutes 32.897 seconds)
