#!/usr/bin/env python
# coding: utf-8

# ## Generative Adversarial Networks

# In this notebook we'll see how to build a very popular type of deep learning model called a generative adversarial network, or GAN for short. Interest in GANs has exploded in the last few years thanks to their unique ability to generate complex, human-like results. GANs have been used for image synthesis, style transfer, music generation, and many other novel applications. We'll build a particular type of GAN called the Deep Convolutional GAN (DCGAN) from scratch and train it to generate MNIST images.
#
# Much of the code used in this notebook is adapted from the excellent [Keras-GAN](https://github.com/eriklindernoren/Keras-GAN) library created by Erik Linder-Noren.
#
# Let's get started by importing a few libraries and loading MNIST into an array.

# In[1]:


import time
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist


# In[2]:


(X, _), (_, _) = mnist.load_data()
X = X / 127.5 - 1.
X = np.expand_dims(X, axis=3)
X.shape


# Note that we've scaled the pixel values from [0, 255] to [-1, 1], which matches the range of the tanh activation the generator uses for its output. We've also added another dimension for the channel, which usually holds the color values (red, green, blue). Since these images are greyscale it's not strictly necessary, but Keras expects an input array in this shape.
#
# The first thing we need to do is define the generator. This is the model that will learn how to produce new images. The architecture looks a lot like a normal convolutional network, except in reverse. We start with an input that is essentially random noise and mold it into an image through a series of convolution blocks. All of these layers should be familiar to anyone who's worked with convolutional networks before, with the possible exception of the "upsampling" layer, which essentially doubles the size of the tensor along two axes by repeating rows and columns.

# In[3]:


from keras.models import Model
from keras.layers import Activation, BatchNormalization, Conv2D
from keras.layers import Dense, Input, Reshape, UpSampling2D

def get_generator(noise_shape):
    i = Input(noise_shape)
    # Project and reshape the noise into a 7x7 feature map,
    # then upsample twice: 7 -> 14 -> 28
    x = Dense(128 * 7 * 7)(i)
    x = Activation("relu")(x)
    x = Reshape((7, 7, 128))(x)
    x = UpSampling2D()(x)
    x = Conv2D(128, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)
    x = UpSampling2D()(x)
    x = Conv2D(64, kernel_size=3, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Activation("relu")(x)
    x = Conv2D(1, kernel_size=3, padding="same")(x)
    o = Activation("tanh")(x)
    return Model(i, o)


# It helps to see the summary output for the model and follow along as the shape changes. One interesting thing to note is how many of the model's parameters live in that initial dense layer.

# In[4]:


g = get_generator(noise_shape=(100,))
g.summary()
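# Since the upsampling layer is the one unfamiliar piece here, the following is a quick standalone sketch (separate from the GAN itself, and purely illustrative) showing how `UpSampling2D` doubles a small tensor by repeating its rows and columns.

# In[ ]:


# Illustrative demo of UpSampling2D on a tiny 2x2 "image" (batch size 1, one channel).
# Each value in the input appears as a 2x2 block in the output.
demo_in = Input((2, 2, 1))
demo = Model(demo_in, UpSampling2D()(demo_in))

t = np.array([[1., 2.], [3., 4.]]).reshape(1, 2, 2, 1)
print(demo.predict(t)[0, :, :, 0])
# [[1. 1. 2. 2.]
#  [1. 1. 2. 2.]
#  [3. 3. 4. 4.]
#  [3. 3. 4. 4.]]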
# Next up is the discriminator. This model will be tasked with looking at an image and deciding whether it's really an MNIST image or a fake. The discriminator looks a lot more like a standard convolutional network. It takes an image tensor as its input and runs it through a series of convolution blocks, decreasing in size until it outputs a single activation representing the probability that the tensor is a real MNIST image.

# In[5]:


from keras.layers import Dropout, Flatten, LeakyReLU, ZeroPadding2D

def get_discriminator(img_shape):
    i = Input(img_shape)
    x = Conv2D(32, kernel_size=3, strides=2, padding="same")(i)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.25)(x)
    x = Conv2D(64, kernel_size=3, strides=2, padding="same")(x)
    # Pad the 7x7 feature map to 8x8 so the next strided conv divides evenly
    x = ZeroPadding2D(padding=((0, 1), (0, 1)))(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.25)(x)
    x = Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.25)(x)
    x = Conv2D(256, kernel_size=3, strides=1, padding="same")(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    o = Dense(1)(x)
    o = Activation("sigmoid")(o)
    return Model(i, o)


# In[6]:


d = get_discriminator(img_shape=(28, 28, 1))
d.summary()


# With both components defined, we can now bring it all together and create the GAN model. It's worth spending some time examining this part closely as it can be confusing at first. Notice that we first compile the discriminator by itself. Then we define the generator and feed its output to the discriminator, but only after setting the discriminator so that training is disabled. This is because we don't want the discriminator to update its weights during this part, only the generator. Finally, a third model is created combining the original input to the generator and the output of the discriminator. Don't worry if you find this confusing; it takes some time to fully comprehend.

# In[7]:


from keras.optimizers import Adam

def DCGAN():
    optimizer = Adam(lr=0.0002, beta_1=0.5)
    discriminator = get_discriminator(img_shape=(28, 28, 1))
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
    generator = get_generator(noise_shape=(100,))
    z = Input(shape=(100,))
    g_out = generator(z)
    discriminator.trainable = False
    d_out = discriminator(g_out)
    combined = Model(z, d_out)
    combined.compile(loss="binary_crossentropy", optimizer=optimizer)
    return generator, discriminator, combined


# In[8]:


generator, discriminator, combined = DCGAN()
combined.summary()
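# One way to convince yourself the freezing works is to compare trainable weight counts. This is a quick, purely illustrative sanity check using the models defined above. (The standalone discriminator still updates when we call `train_on_batch` on it directly, because it was compiled *before* its `trainable` flag was flipped to False; Keras captures trainability at compile time.)

# In[ ]:


# Sanity check (illustrative): the combined model should expose only the
# generator's weights as trainable, since the discriminator was frozen
# before the combined model was built and compiled.
print(len(generator.trainable_weights))  # generator's trainable weights
print(len(combined.trainable_weights))   # same count: the discriminator contributes none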
# Create a function to save the generator's output so we can visually inspect the results. We'll use this in the training loop below.

# In[12]:


def save_image(generator, batch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_images = generator.predict(noise)
    # Rescale from [-1, 1] back to [0, 1] for display
    gen_images = 0.5 * gen_images + 0.5
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap="gray")
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig("mnist_%d.png" % batch)
    plt.close()


# The last challenge with building a GAN is training the whole thing. With something this complex we can't just call the standard fit function because there are multiple steps involved, so let's roll our own and see how the algorithm works. Start by generating some random noise and feeding it through the generator to produce a batch of "fake" images. Next, select a batch of real images from the training data. Now train the discriminator on both the fake image batch and the real image batch, setting labels appropriately (0 for fake and 1 for real). Finally, train the generator by feeding the "combined" model a batch of noise with the "real" label.
#
# You may find this last step confusing. Remember what this "combined" model is doing. It runs the input (random noise) through the generator, then runs that output through the discriminator (which is frozen), and gets a loss. What this ends up doing, in essence, is saying "train the generator to minimize the discriminator's ability to discern this output from a real image". This has the effect, over time, of causing the generator to produce realistic images!

# In[13]:


def train(generator, discriminator, combined, X, batch_size, n_batches):
    t0 = time.time()
    for batch in range(n_batches + 1):
        t1 = time.time()
        noise = np.random.normal(0, 1, (batch_size, 100))

        # Create fake images
        fake_images = generator.predict(noise)
        fake_labels = np.zeros((batch_size, 1))

        # Select real images
        idx = np.random.randint(0, X.shape[0], batch_size)
        real_images = X[idx]
        real_labels = np.ones((batch_size, 1))

        # Train the discriminator
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = combined.train_on_batch(noise, real_labels)

        t2 = time.time()

        # Report progress
        if batch % 100 == 0:
            print("Batch %d/%d [Batch time: %.2f s, Total: %.2f m] [D loss: %f, Acc: %.2f%%] [G loss: %f]"
                  % (batch, n_batches, (t2 - t1), ((t2 - t0) / 60), d_loss[0], 100 * d_loss[1], g_loss))
        if batch % 500 == 0:
            save_image(generator, batch)


# Run the training loop for a while. We can't read too much into the losses reported by this process because the models are competing with each other. Both losses end up fairly stable over the training period, but that doesn't really capture what's going on: both models are getting better at their task by roughly equal amounts. If either model's loss went to zero, that would likely indicate a problem that needs to be fixed.

# In[14]:


train(generator, discriminator, combined, X=X, batch_size=32, n_batches=3000)


# Take a look at the generator's output at various stages of training. In the beginning it's clear that the model is producing noise, but after just a few hundred batches it's already making progress. After a couple thousand batches the images look quite realistic.

# In[15]:


from IPython.display import Image, display

names = ["mnist_0.png", "mnist_500.png", "mnist_1000.png",
         "mnist_1500.png", "mnist_2000.png", "mnist_2500.png",
         "mnist_3000.png"]

for name in names:
    display(Image(filename=name))
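# As a final experiment, we can also sample fresh digits directly from the trained generator rather than relying on the images saved during training. This is a minimal sketch using the same noise distribution the model was trained on; the grid size is arbitrary.

# In[ ]:


# Minimal sketch: sample a few fresh digits from the trained generator.
# The generator maps 100-dimensional Gaussian noise to a 28x28x1 image in [-1, 1].
noise = np.random.normal(0, 1, (4, 100))
samples = 0.5 * generator.predict(noise) + 0.5  # rescale to [0, 1] for display

fig, axs = plt.subplots(1, 4)
for k in range(4):
    axs[k].imshow(samples[k, :, :, 0], cmap="gray")
    axs[k].axis("off")
plt.show()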