.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "advanced/neural_style_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_advanced_neural_style_tutorial.py:


Neural Transfer Using PyTorch
=============================


**Author**: `Alexis Jacq `_

**Edited by**: `Winston Herring `_

Introduction
------------

This tutorial explains how to implement the `Neural-Style algorithm `__
developed by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.
Neural-Style, or Neural-Transfer, allows you to take an image and
reproduce it with a new artistic style. The algorithm takes three images
(an input image, a content image, and a style image) and changes the input
to resemble the content of the content image and the artistic style of the
style image.


.. figure:: /_static/img/neural-style/neuralstyle.png
   :alt: content1

.. GENERATED FROM PYTHON SOURCE LINES 26-49

Underlying Principle
--------------------

The principle is simple: we define two distances, one for the content
(:math:`D_C`) and one for the style (:math:`D_S`). :math:`D_C` measures how
different the content is between two images, while :math:`D_S` measures how
different the style is between two images. Then, we take a third image, the
input, and transform it to minimize both its content distance to the
content image and its style distance to the style image. Now we can import
the necessary packages and begin the neural transfer.

Importing Packages and Selecting a Device
-----------------------------------------

Below is a list of the packages needed to implement the neural transfer.

-  ``torch``, ``torch.nn``, ``numpy`` (indispensable packages for
   neural networks with PyTorch)
-  ``torch.optim`` (efficient gradient descent)
-  ``PIL``, ``PIL.Image``, ``matplotlib.pyplot`` (load and display
   images)
-  ``torchvision.transforms`` (transform PIL images into tensors)
-  ``torchvision.models`` (train or load pretrained models)
-  ``copy`` (to deep copy the models; Python standard library)

.. GENERATED FROM PYTHON SOURCE LINES 49-64

.. code-block:: default


    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import torch.optim as optim

    from PIL import Image
    import matplotlib.pyplot as plt

    import torchvision.transforms as transforms
    from torchvision.models import vgg19, VGG19_Weights

    import copy


.. GENERATED FROM PYTHON SOURCE LINES 65-71

Next, we need to choose which device to run the network on and import the
content and style images. Running the neural transfer algorithm on large
images takes longer, but it runs much faster on a GPU. We can use
``torch.cuda.is_available()`` to detect whether a GPU is available. We then
set the ``torch.device`` used throughout the tutorial and make it the
default device for newly created tensors with ``torch.set_default_device``
(available in PyTorch 2.0 and later). The ``.to(device)`` method moves
tensors or modules to the desired device.

.. GENERATED FROM PYTHON SOURCE LINES 71-75

.. code-block:: default


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)

.. GENERATED FROM PYTHON SOURCE LINES 76-96

Loading the Images
------------------

Now we will import the style and content images. The original PIL images
have values between 0 and 255, but when transformed into torch tensors,
their values are converted to be between 0 and 1. The images also need to
be resized to have the same dimensions. An important detail to note is that
neural networks from the torch library are trained with tensor values
ranging from 0 to 1.
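
As a small illustration of the 0 to 255 versus 0 to 1 ranges described in this section, the sketch below builds a tiny hypothetical RGB image with NumPy (the pixel values are made up for the example) and passes it through ``transforms.ToTensor()``, the same transform the image loader in this tutorial uses. The result is a float tensor of shape ``(C, H, W)`` whose values are the original pixel values divided by 255.

.. code-block:: python

    import numpy as np
    import torch
    from PIL import Image
    import torchvision.transforms as transforms

    # Hypothetical 2x2 RGB image with 8-bit pixel values in [0, 255]
    pixels = np.array([[[0, 128, 255], [64, 64, 64]],
                       [[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)
    img = Image.fromarray(pixels)  # an (H, W, 3) uint8 array gives an RGB image

    # ToTensor reorders the axes to (C, H, W) and rescales the values to [0, 1]
    tensor = transforms.ToTensor()(img)

    print(tensor.shape)                              # torch.Size([3, 2, 2])
    print(tensor.dtype == torch.float32)             # True
    print(tensor.min().item(), tensor.max().item())  # 0.0 1.0
    print(round(tensor[1, 0, 0].item(), 4))          # 0.502 (= 128 / 255)
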
If PyTorch's pretrained networks are fed 0 to 255 tensor images instead,
the activated feature maps will be unable to sense the intended content
and style. Pretrained networks from the Caffe library, by contrast, are
trained with 0 to 255 tensor images.


.. note::
    Here are links to download the images required to run the tutorial:
    `picasso.jpg `__ and
    `dancing.jpg `__.
    Download these two images and add them to a directory named
    ``data/images/neural-style`` in your current working directory, which is
    where the code below loads them from.

.. GENERATED FROM PYTHON SOURCE LINES 96-119

.. code-block:: default


    # desired size of the output image
    imsize = 512 if torch.cuda.is_available() else 128  # use small size if no GPU

    loader = transforms.Compose([
        # scale the smaller edge to imsize, keeping the aspect ratio; the two
        # images must share an aspect ratio for their sizes to match
        transforms.Resize(imsize),
        transforms.ToTensor()])  # transform it into a torch tensor


    def image_loader(image_name):
        image = Image.open(image_name)
        # fake batch dimension required to fit network's input dimensions
        image = loader(image).unsqueeze(0)
        return image.to(device, torch.float)


    style_img = image_loader("./data/images/neural-style/picasso.jpg")
    content_img = image_loader("./data/images/neural-style/dancing.jpg")

    assert style_img.size() == content_img.size(), \
        "we need to import style and content images of the same size"


.. GENERATED FROM PYTHON SOURCE LINES 120-124

Now, let's create a function that displays an image by reconverting a
copy of it to PIL format and displaying the copy using
``plt.imshow``. We will display the content and style images to ensure
they were imported correctly.

.. GENERATED FROM PYTHON SOURCE LINES 124-145

.. code-block:: default


    unloader = transforms.ToPILImage()  # reconvert into PIL image

    plt.ion()

    def imshow(tensor, title=None):
        image = tensor.cpu().clone()  # clone the tensor so the original is not modified
        image = image.squeeze(0)      # remove the fake batch dimension
        image = unloader(image)
        plt.imshow(image)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated


    plt.figure()
    imshow(style_img, title='Style Image')

    plt.figure()
    imshow(content_img, title='Content Image')



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /advanced/images/sphx_glr_neural_style_tutorial_001.png
         :alt: Style Image
         :srcset: /advanced/images/sphx_glr_neural_style_tutorial_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /advanced/images/sphx_glr_neural_style_tutorial_002.png
         :alt: Content Image
         :srcset: /advanced/images/sphx_glr_neural_style_tutorial_002.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 146-170

Loss Functions
--------------

Content Loss
~~~~~~~~~~~~

The content loss is a function that represents a weighted version of the
content distance for an individual layer. The function takes the feature
maps :math:`F_{XL}` of a layer :math:`L` in a network processing input
:math:`X` and returns the weighted content distance
:math:`w_{CL} \cdot D_C^L(X,C)` between the image :math:`X` and the content
image :math:`C`. The feature maps of the content image (:math:`F_{CL}`) must
be known by the function in order to calculate the content distance. We
implement this function as a torch module with a constructor that takes
:math:`F_{CL}` as an input. The distance is the mean squared error between
the two sets of feature maps, that is, the squared norm
:math:`\|F_{XL} - F_{CL}\|^2` divided by the number of elements, and can be
computed using ``nn.MSELoss`` (the code below uses its functional form,
``F.mse_loss``).

We will add this content loss module directly after the convolution
layer(s) that are being used to compute the content distance.
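
The relationship between the squared norm and the mean squared error stated in this section can be checked numerically. In the sketch below, random tensors stand in for real VGG feature maps (an assumption for illustration only); it compares a hand-written mean of squared differences with both ``nn.MSELoss`` and ``F.mse_loss``.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)

    # Hypothetical feature maps for one layer: batch of 1, 4 channels, 8x8 grid.
    # In the tutorial these would come from a VGG convolution layer.
    feat_input = torch.rand(1, 4, 8, 8)    # plays the role of F_XL
    feat_content = torch.rand(1, 4, 8, 8)  # plays the role of F_CL

    # Squared Euclidean norm of the difference, summed over every element
    sq_norm = ((feat_input - feat_content) ** 2).sum()

    # MSE averages instead of summing: squared norm / number of elements
    manual_mse = sq_norm / feat_input.numel()
    module_mse = nn.MSELoss()(feat_input, feat_content)    # module form
    functional_mse = F.mse_loss(feat_input, feat_content)  # functional form

    print(torch.allclose(manual_mse, module_mse))      # True
    print(torch.allclose(manual_mse, functional_mse))  # True
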
Because the content loss modules sit inside the network, every time the
network is fed an input image the content losses are computed at the
desired layers, and autograd computes all the gradients.

To make the content loss layer transparent, we define a ``forward`` method
that computes the content loss and then returns the layer's input
unchanged. The computed loss is saved as an attribute of the module
(``self.loss``).

.. GENERATED FROM PYTHON SOURCE LINES 170-185

.. code-block:: default


    class ContentLoss(nn.Module):

        def __init__(self, target):
            super(ContentLoss, self).__init__()
            # we 'detach' the target content from the graph used to
            # dynamically compute the gradient: it is a fixed value, not a
            # variable to optimize. Otherwise, gradients would also be
            # propagated into the graph that produced the target.
            self.target = target.detach()

        def forward(self, input):
            self.loss = F.mse_loss(input, self.target)
            return input


.. GENERATED FROM PYTHON SOURCE LINES 186-192

.. note::
   **Important detail**: although this module is named ``ContentLoss``, it
   is not a true PyTorch Loss function. If you want to define your content
   loss as a PyTorch Loss function, you have to create a PyTorch autograd
   function to recompute/implement the gradient manually in the ``backward``
   method.

.. GENERATED FROM PYTHON SOURCE LINES 194-216

Style Loss
~~~~~~~~~~

The style loss module is implemented similarly to the content loss
module. It will act as a transparent layer in a network that computes the
style loss of that layer. In order to calculate the style loss, we need to
compute the gram matrix :math:`G_{XL}`. A gram matrix is the result of
multiplying a given matrix by its transpose. In this application the given
matrix is a reshaped version of the feature maps :math:`F_{XL}` of a layer
:math:`L`. :math:`F_{XL}` is reshaped to form :math:`\hat{F}_{XL}`, a
:math:`K \times N` matrix, where :math:`K` is the number of feature maps at
layer :math:`L` and :math:`N` is the length of any vectorized feature map
:math:`F_{XL}^k`.
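
The reshape and product can be checked on a small example. In the sketch below, random tensors stand in for real VGG activations (an assumption for illustration only). ``f_hat`` plays the role of :math:`\hat{F}_{XL}` with :math:`K = 3` and :math:`N = 4 \cdot 5 = 20`. The resulting gram matrix is :math:`K \times K`, and entry :math:`(i, j)` equals the dot product of the :math:`i`-th and :math:`j`-th vectorized feature maps. Normalization is deliberately left out here.

.. code-block:: python

    import torch

    torch.manual_seed(0)

    # Hypothetical feature maps: batch of 1, K=3 feature maps, each 4x5 (N=20)
    feats = torch.rand(1, 3, 4, 5)
    a, b, c, d = feats.size()

    f_hat = feats.view(a * b, c * d)  # K x N matrix, one vectorized map per row
    gram = torch.mm(f_hat, f_hat.t())  # K x K matrix of pairwise dot products

    print(f_hat.shape, gram.shape)  # torch.Size([3, 20]) torch.Size([3, 3])

    # Entry (i, j) equals the dot product of vectorized maps i and j
    i, j = 0, 2
    expected = torch.dot(feats[0, i].flatten(), feats[0, j].flatten())
    print(torch.allclose(gram[i, j], expected))  # True
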
Row :math:`k` of :math:`\hat{F}_{XL}` is the vectorized feature map
:math:`F_{XL}^k`; for example, the first row corresponds to
:math:`F_{XL}^1`.

Finally, the gram matrix must be normalized by dividing each element by
the total number of elements in the feature maps (:math:`K \cdot N` for a
batch of one). This normalization counteracts the fact that
:math:`\hat{F}_{XL}` matrices with a large :math:`N` dimension yield larger
values in the gram matrix. These larger values would cause the first layers
(before pooling layers) to have a larger impact during the gradient
descent. Style features tend to be in the deeper layers of the network, so
this normalization step is crucial.

.. GENERATED FROM PYTHON SOURCE LINES 216-231

.. code-block:: default


    def gram_matrix(input):
        a, b, c, d = input.size()  # a=batch size(=1)
        # b=number of feature maps
        # (c,d)=dimensions of a f. map (N=c*d)

        features = input.view(a * b, c * d)  # reshape F_XL into \hat F_XL (K x N)

        G = torch.mm(features, features.t())  # compute the gram product

        # we 'normalize' the values of the gram matrix
        # by dividing by the total number of elements in the feature maps
        return G.div(a * b * c * d)


.. GENERATED FROM PYTHON SOURCE LINES 232-236

Now the style loss module looks almost exactly like the content loss
module. The style distance is also computed using the mean squared error
between :math:`G_{XL}` and :math:`G_{SL}`.

.. GENERATED FROM PYTHON SOURCE LINES 236-249

.. code-block:: default


    class StyleLoss(nn.Module):

        def __init__(self, target_feature):
            super(StyleLoss, self).__init__()
            self.target = gram_matrix(target_feature).detach()

        def forward(self, input):
            G = gram_matrix(input)
            self.loss = F.mse_loss(G, self.target)
            return input


.. GENERATED FROM PYTHON SOURCE LINES 250-264

Importing the Model
-------------------

Now we need to import a pretrained neural network. We will use a 19-layer
VGG network like the one used in the paper.
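
The "19" in VGG19 counts the layers that carry learnable weights. One way to check this without downloading the pretrained checkpoint is to build the bare architecture with ``weights=None`` and count layer types. The sketch below does this and assumes torchvision 0.13 or later, where the ``weights`` argument is available.

.. code-block:: python

    import torch.nn as nn
    from torchvision.models import vgg19

    # Architecture only: weights=None skips downloading the pretrained checkpoint
    model = vgg19(weights=None)

    # Count layer types in the two Sequential children
    convs = [m for m in model.features if isinstance(m, nn.Conv2d)]
    pools = [m for m in model.features if isinstance(m, nn.MaxPool2d)]
    linears = [m for m in model.classifier if isinstance(m, nn.Linear)]

    print(len(convs), len(pools), len(linears))  # 16 5 3
    # 16 convolution layers + 3 fully connected layers = 19 weight layers

Only ``model.features`` is kept in this tutorial, so the three fully connected layers play no role in the style transfer.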
PyTorch’s implementation of VGG is a module divided into two child
``Sequential`` modules: ``features`` (containing convolution and pooling
layers) and ``classifier`` (containing fully connected layers). We will use
the ``features`` module because we need the output of the individual
convolution layers to measure content and style loss. Some layers behave
differently during training than during evaluation, so we must set the
network to evaluation mode using ``.eval()``.

.. GENERATED FROM PYTHON SOURCE LINES 264-269

.. code-block:: default


    cnn = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth

    0%| | 0.00/548M [00:00

`__, we will use the L-BFGS algorithm to run our gradient descent. Unlike
when training a network, we optimize the input image itself, not the
network weights, to minimize the content/style losses. We will create a
PyTorch L-BFGS optimizer ``optim.LBFGS`` and pass our image to it as the
tensor to optimize.

.. GENERATED FROM PYTHON SOURCE LINES 394-401

.. code-block:: default


    def get_input_optimizer(input_img):
        # the input image itself is the tensor being optimized
        # (the network weights are left untouched)
        optimizer = optim.LBFGS([input_img])
        return optimizer


.. GENERATED FROM PYTHON SOURCE LINES 402-413

Now we define the function that performs the neural transfer. At each
iteration, the network is fed the updated input and computes new losses.
We call ``backward`` on the weighted sum of the losses to dynamically
compute the gradients. The optimizer requires a “closure” function, which
reevaluates the model and returns the loss.

We still have one final constraint to address. The optimizer may push the
input image's values outside the 0 to 1 tensor range. We address this by
clamping the input values to the 0 to 1 range each time the network is run.
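
The closure-plus-clamp pattern can be tried on a toy problem before running it on an image. In the sketch below, a hypothetical three-element tensor ``x`` stands in for the input image and is pulled toward a made-up target that lies partly outside [0, 1]. The closure clamps ``x`` before each evaluation, and a final clamp follows the last L-BFGS step, mirroring the approach described here. This is a heuristic: L-BFGS itself is unaware of the constraint, so the clamps keep the values in range rather than solving a properly bounded problem.

.. code-block:: python

    import torch
    import torch.optim as optim

    # Toy stand-in for the image; the hypothetical target lies partly outside
    # [0, 1], so unconstrained optimization would push values out of range.
    target = torch.tensor([-0.5, 0.3, 1.7])
    x = torch.full((3,), 0.5, requires_grad=True)
    optimizer = optim.LBFGS([x])

    def closure():
        # Clamp in place without recording the operation for autograd
        with torch.no_grad():
            x.clamp_(0, 1)
        optimizer.zero_grad()
        loss = ((x - target) ** 2).sum()
        loss.backward()
        return loss

    for _ in range(5):
        optimizer.step(closure)

    # The last L-BFGS update happens after the last closure call, so clamp once more
    with torch.no_grad():
        x.clamp_(0, 1)

    print(x.detach())
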

.. GENERATED FROM PYTHON SOURCE LINES 413-475

.. code-block:: default


    def run_style_transfer(cnn, normalization_mean, normalization_std,
                           content_img, style_img, input_img, num_steps=300,
                           style_weight=1000000, content_weight=1):
        """Run the style transfer."""
        print('Building the style transfer model..')
        model, style_losses, content_losses = get_style_model_and_losses(cnn,
            normalization_mean, normalization_std, style_img, content_img)

        # We want to optimize the input and not the model parameters so we
        # update all the requires_grad fields accordingly
        input_img.requires_grad_(True)
        # We also put the model in evaluation mode, so that specific layers
        # such as dropout or batch normalization layers behave correctly.
        model.eval()
        model.requires_grad_(False)

        optimizer = get_input_optimizer(input_img)

        print('Optimizing..')
        run = [0]
        while run[0] <= num_steps:

            def closure():
                # correct the values of updated input image
                with torch.no_grad():
                    input_img.clamp_(0, 1)

                optimizer.zero_grad()
                model(input_img)
                style_score = 0
                content_score = 0

                for sl in style_losses:
                    style_score += sl.loss
                for cl in content_losses:
                    content_score += cl.loss

                style_score *= style_weight
                content_score *= content_weight

                loss = style_score + content_score
                loss.backward()

                run[0] += 1
                if run[0] % 50 == 0:
                    print("run {}:".format(run))
                    print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                        style_score.item(), content_score.item()))
                    print()

                return style_score + content_score

            optimizer.step(closure)

        # a last correction...
        with torch.no_grad():
            input_img.clamp_(0, 1)

        return input_img


.. GENERATED FROM PYTHON SOURCE LINES 476-478

Finally, we can run the algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 478-489

.. code-block:: default


    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img)

    plt.figure()
    imshow(output, title='Output Image')

    # sphinx_gallery_thumbnail_number = 4
    plt.ioff()
    plt.show()



.. image-sg:: /advanced/images/sphx_glr_neural_style_tutorial_004.png
   :alt: Output Image
   :srcset: /advanced/images/sphx_glr_neural_style_tutorial_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Building the style transfer model..
    Optimizing..
    run [50]:
    Style Loss : 5.217325 Content Loss: 4.085981

    run [100]:
    Style Loss : 1.126348 Content Loss: 3.028276

    run [150]:
    Style Loss : 0.708760 Content Loss: 2.643001

    run [200]:
    Style Loss : 0.476622 Content Loss: 2.491493

    run [250]:
    Style Loss : 0.341410 Content Loss: 2.401098

    run [300]:
    Style Loss : 0.260281 Content Loss: 2.348618





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 38.791 seconds)


.. _sphx_glr_download_advanced_neural_style_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: neural_style_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: neural_style_tutorial.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_