#!/usr/bin/env python
# coding: utf-8

# # Texture synthesis using deep convolutional networks

# This numerical tour implements the method detailed in the paper of [Gatys et al.](https://arxiv.org/pdf/1505.07376.pdf). The implementation is intended to be as simple as possible, using PyTorch hooks so that it applies to any network (as opposed to the [style transfer implementation](https://github.com/leongatys/PytorchNeuralStyleTransfer/blob/master/NeuralStyleTransfer.ipynb)).
#
# This tour can be used as a gentle introduction to convolutional networks, where one uses a pre-trained network to perform a non-trivial vision task (involving in particular an optimization by back-propagation through the network).

# In[1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Uncomment if you want to save results to your own Google Drive.

# In[2]:

# from google.colab import drive
# drive.mount('/content/drive')

# Check if CUDA is available.

# In[3]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# ## Image loading and normalization

# Load the image $f \in \mathbb{R}^{3 \times n_0 \times n_0}$, where $3$ is the number of input channels.

# In[4]:

from urllib.request import urlopen
import io

# Pick one of the following example textures (uncomment the one you want).
# file_address = 'https://raw.githubusercontent.com/leongatys/DeepTextures/master/Images/pebbles.jpg'
# file_address = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/radishes256.o.jpg'
# file_address = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/olives256.o.jpg'
# file_address = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/tomatoes256.o.jpg'
file_address = 'http://graphics.stanford.edu/projects/texture/demo/texture_data/original/eero/yellow-peppers256.o.jpg'

fd = urlopen(file_address)
image_file = io.BytesIO(fd.read())
f_pil = Image.open(image_file)

plt.imshow(f_pil)
plt.axis('off');

# Image normalization (to match the normalization used when the network was trained):
# $$ \forall \ell \in \{0,1,2\}, \quad f[\ell,\cdot,\cdot] \leftarrow (f[\ell,\cdot,\cdot]-m_\ell)/\sigma_\ell $$
# where $m$ and $\sigma$ are the empirical mean and standard deviation of the training set (here the [ImageNet](http://www.image-net.org/) dataset).

# In[5]:

n = 256

def normalize(f):
    preprocess = transforms.Compose([
        transforms.Resize(n),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # .to(device) works on CPU as well as GPU; the deprecated Variable wrapper is not needed
    return preprocess(f).unsqueeze(0).to(device)

def deprocess(image):
    # undo the normalization; expects a channel-last (H, W, 3) tensor
    return image * torch.tensor([0.229, 0.224, 0.225], device=image.device) \
         + torch.tensor([0.485, 0.456, 0.406], device=image.device)
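# As a quick check of these two helpers, the round trip through `normalize` and `deprocess` should recover the input image up to resizing. This is a minimal sketch; `f_test` is just a throwaway name.

# In[ ]:

f_test = normalize(f_pil)                            # shape (1, 3, n, n), on `device`
img = deprocess(f_test.squeeze(0).permute(1, 2, 0))  # back to channel-last (n, n, 3)
plt.imshow(img.cpu().clamp(0, 1))
plt.axis('off');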
# ## Neural network loading and testing

# Load a neural network architecture pre-trained on ImageNet.

# In[6]:

# nn_type = 'resnet'   # alternative architecture
nn_type = 'vgg'

if nn_type=='vgg':
    cnn = models.vgg19(pretrained=True)
elif nn_type=='resnet':
    cnn = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)

# no need to store the gradient of the network weights
for param in cnn.parameters():
    param.requires_grad = False

if torch.cuda.is_available():
    cnn.cuda()

# Starting from the input image $f_0 = f \in \mathbb{R}^{m_0 \times n_0 \times n_0}$ (with $m_0=3$), the "feature" part of the network alternates layers of the form
# $$
# f_{i+1} \equiv \Psi_i(f_i)
# \equiv [ \text{ReLU}( f_{i} \star w_i) ]_{\downarrow s_i}
# = \Psi_i \circ \Psi_{i-1} \circ \ldots \circ \Psi_0(f_0).
# $$
# Here $f_{i} \in \mathbb{R}^{m_i \times n_i \times n_i}$ has $m_i$ channels and $n_i^2$ pixels (we assume square images for simplicity), and $w_i[\ell,\ell',\cdot,\cdot]$ are the convolution filters, so that
# $$
# \forall 0 \leq \ell < m_{i+1}, \quad
# ( f_{i} \star w_i )[ \ell,x,y ] = \sum_{\ell',x',y'}
# w_i[\ell,\ell',x-x',y-y'] \, f_{i}[\ell',x',y'].
# $$
# ReLU is the Rectified Linear Unit non-linearity $\text{ReLU}(s)=\max(s,0)$, applied entrywise to a tensor.
# The operation $[\cdot]_{\downarrow s_i}$ is a downsampling by a factor $s_i \in \{1,2\}$. If $s_i=1$, nothing is done (so that $n_{i+1}=n_i$), while if $s_i=2$ the number of pixels is reduced by a factor $4$ and $n_{i+1}=n_i/2$.
# The most usual sub-sampling operator (when $s_i=2$) is max-pooling, where
# $$
# (A_{\downarrow 2})[\ell,x,y] \equiv \max(A[\ell,2x,2y],\,A[\ell,2x+1,2y],\,A[\ell,2x,2y+1],\,A[\ell,2x+1,2y+1]).
# $$
#
# In the following, we denote
# $$
# \Phi_i(f_0) \equiv f_i
# \quad \text{i.e.} \quad
# \Phi_i \equiv \Psi_{i-1} \circ \ldots \circ \Psi_0,
# $$
# the map from the input image $f_0$ to the output of the $i$-th layer of the network.
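# To make the layer definition concrete, here is a minimal sketch of one layer $\Psi_i$ with $s_i=2$, using `torch.nn.functional`; the channel counts and the $3 \times 3$ filter size are illustrative choices, not read off from the network.

# In[ ]:

m_i, m_next, n_i = 3, 64, 256          # channels in/out and spatial size (illustrative)
f_i = torch.randn(1, m_i, n_i, n_i)    # a feature image f_i (batch dimension first)
w_i = torch.randn(m_next, m_i, 3, 3)   # convolution filters w_i[l, l', ., .]
h = F.conv2d(f_i, w_i, padding=1)      # f_i * w_i, same spatial size
h = F.relu(h)                          # entrywise ReLU
f_next = F.max_pool2d(h, 2)            # s_i = 2, so n_{i+1} = n_i / 2
print(f_next.shape)                    # torch.Size([1, 64, 128, 128])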
# Note that for image classification tasks, this "feature" part is followed by a "classification" part, composed of a few fully connected (i.e. non-convolutional) layers and a final soft-max layer which outputs a probability vector over the classes of the dataset. During the training phase, a multi-class classification loss is minimized by stochastic gradient descent to tune the convolution filters $(w_i)_i$. This training is assumed to already be done, and we do not optimize the weights $(w_i)_i$ in this tour.

# We can display the architecture of the network, which has 16 convolutional feature layers, only 5 of which are followed by a pooling with $s_i=2$. This means that the final spatial size of the feature image is $n_{16}=n_0/2^5$. The number of channels grows as $m_0=3, 64, 128, 256$, up to $m_{16}=512$.

# In[7]:

cnn

# Function to save the activations when applying the network.

# In[8]:

activation = {}
def get_activation(name):
    def hook(model, input, output):
        # do not detach the output: gradients must later flow through these activations
        activation[name] = output
    return hook

# Register hooks so that evaluating the network stores the activations of a list $I$ of layers, $(\Phi_i(f))_{i \in I}$.

# In[9]:

# sub-select only a subset of the layer outputs
if nn_type=='vgg':
    # I = range(0,37)        # all the layers
    # I = [0]                # first layer only
    # I = [36]               # last layer only
    I = [0, 4, 9, 18, 27]    # first layer and the four following pooling layers
    it = 0
    for i in I:
        cnn.features[i].register_forward_hook(get_activation(it))
        it = it+1
elif nn_type=='resnet':
    it = 0
    cnn.conv1.register_forward_hook(get_activation(it)); it = it+1
    cnn.layer1[2].register_forward_hook(get_activation(it)); it = it+1
    cnn.layer2[3].register_forward_hook(get_activation(it)); it = it+1
    cnn.layer3[5].register_forward_hook(get_activation(it)); it = it+1
    cnn.layer4[2].register_forward_hook(get_activation(it)); it = it+1

# In[10]:

if nn_type=='vgg':
    for i in I:
        print( cnn.features[i] )

# Apply the network; this saves the activations in the variable `activation`.

# In[11]:

f = normalize(f_pil)
cnn(f);

# In[12]:

for a in activation:
    print(activation[a].shape)

# Display the first channel of each saved activation, i.e. $(f_i[0,\cdot,\cdot])_{i \in I}$.

# In[13]:

for a in activation:
    plt.subplot(2,3,a+1)
    plt.imshow(activation[a].cpu()[0,0,:,:].squeeze())

# Now display the total activation summed over the channels, i.e. $(\sum_\ell f_i[\ell,\cdot,\cdot])_{i \in I}$.

# In[14]:

for a in activation:
    plt.subplot(2,3,a+1)
    plt.imshow(torch.sum(activation[a].cpu(), axis=1).squeeze())

# ## Texture synthesis through optimization
#
# The general idea of statistical texture synthesis (as opposed to "copy-based" methods) is to draw a random noise image and then coerce it so that some of its empirical statistics match those of the input texture.
#
# The initial idea appears in the early work of [Heeger and Bergen](https://www.cns.nyu.edu/heegerlab/content/publications/Heeger-siggraph95.pdf), which simply matches histograms over a wavelet transform. This was refined by [Zhu and Mumford](https://link.springer.com/article/10.1023/A:1007925832420) and by [Portilla and Simoncelli](https://www.cns.nyu.edu/pub/eero/portilla99-reprint.pdf), who use more complex statistical descriptors (in particular higher-order moments). Gatys' method follows the same idea, except that it replaces the linear wavelet transform by a non-linear neural network.

# In this neural network texture model, one only makes use of second-order moments.
#
# We denote $C(h) \in \mathbb{R}^{m \times m}$ the empirical covariance (Gram matrix) of a feature image $h \in \mathbb{R}^{m \times n \times n}$, defined as
# $$
# \forall 0 \leq \ell,\ell' < m, \quad
# C(h)[\ell,\ell'] \equiv \frac{1}{n^2} \sum_{x,y} h[\ell,x,y] \, h[\ell',x,y].
# $$
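# To synthesize a texture, one then draws a noise image $g$ and minimizes the squared distance between its Gram matrices $(C(\Phi_i(g)))_{i \in I}$ and those of the exemplar. The following is a minimal sketch of this loop, not a reference implementation: it reuses `f`, `cnn`, `activation` and `deprocess` from the cells above, and the optimizer choice and iteration count are illustrative.

# In[ ]:

def gram(h):
    # h has shape (1, m, n, n); C(h)[l, l'] = (1/n^2) * sum_{x,y} h[l,x,y] h[l',x,y]
    _, m, n1, n2 = h.shape
    H = h.reshape(m, n1 * n2)
    return H @ H.t() / (n1 * n2)

cnn.eval()                                    # freeze batch-norm/dropout behaviour

# target statistics of the exemplar texture, one Gram matrix per hooked layer
cnn(f)
C_target = [gram(activation[a]).detach() for a in sorted(activation)]

# start from white noise and optimize the image pixels directly
g = torch.randn_like(f).requires_grad_(True)
optimizer = optim.LBFGS([g])

for it in range(20):                          # illustrative iteration count
    def closure():
        optimizer.zero_grad()
        cnn(g)                                # the forward hooks refill `activation`
        loss = sum(((gram(activation[a]) - C)**2).sum()
                   for a, C in zip(sorted(activation), C_target))
        loss.backward()
        return loss
    optimizer.step(closure)

plt.imshow(deprocess(g.detach().squeeze(0).permute(1, 2, 0)).cpu().clamp(0, 1))
plt.axis('off');

# Note that only the image `g` is updated; the weights $(w_i)_i$ stay frozen. L-BFGS is the optimizer used in Gatys' original code, but a first-order optimizer such as Adam on the pixels also works.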