In this notebook we'll see how to build a very popular type of deep learning model called a generative adversarial network, or GAN for short. Interest in GANs has exploded in the last few years thanks to their unique ability to generate complex human-like results. GANs have been used for image synthesis, style transfer, music generation, and many other novel applications. We'll build a particular type of GAN called the Deep Convolutional GAN (DCGAN) from scratch and train it to generate MNIST images.
Much of the code used in this notebook is adapted from the excellent Keras-GAN library created by Erik Linder-Noren.
Let's get started by importing a few libraries and loading up MNIST into an array.
import time
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
Using TensorFlow backend.
(X, _), (_, _) = mnist.load_data()
X = X / 127.5 - 1.
X = np.expand_dims(X, axis=3)
X.shape
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz 11493376/11490434 [==============================] - 1s 0us/step
(60000, 28, 28, 1)
Note that we've added another dimension for the channel which is usually for the color value (red, green, blue). Since these images are greyscale it's not strictly necessary but Keras expects an input array in this shape.
The first thing we need to do is define the generator. This is the model that will learn how to produce new images. The architecture looks a lot like a normal convolutional network except in reverse. We start with an input that is essentially random noise and mold it into an image through a series of convolution blocks. All of these layers should be familiar to anyone that's worked with convolutional networks before with the possible exception of the "upsampling" layer, which essentially doubles the size of the tensor along two axes by repeating rows and columns.
from keras.models import Model
from keras.layers import Activation, BatchNormalization, Conv2D
from keras.layers import Dense, Input, Reshape, UpSampling2D
def get_generator(noise_shape):
i = Input(noise_shape)
x = Dense(128 * 7 * 7)(i)
x = Activation("relu")(x)
x = Reshape((7, 7, 128))(x)
x = UpSampling2D()(x)
x = Conv2D(128, kernel_size=3, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation("relu")(x)
x = UpSampling2D()(x)
x = Conv2D(64, kernel_size=3, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation("relu")(x)
x = Conv2D(1, kernel_size=3, padding="same")(x)
o = Activation("tanh")(x)
return Model(i, o)
It helps to see the summary output for the model and follow along as the shape changes. One interesting thing to note is how many of the parameters in the model are in that initial dense layer.
g = get_generator(noise_shape=(100,))
g.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) (None, 100) 0 _________________________________________________________________ dense_1 (Dense) (None, 6272) 633472 _________________________________________________________________ activation_1 (Activation) (None, 6272) 0 _________________________________________________________________ reshape_1 (Reshape) (None, 7, 7, 128) 0 _________________________________________________________________ up_sampling2d_1 (UpSampling2 (None, 14, 14, 128) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 14, 14, 128) 147584 _________________________________________________________________ batch_normalization_1 (Batch (None, 14, 14, 128) 512 _________________________________________________________________ activation_2 (Activation) (None, 14, 14, 128) 0 _________________________________________________________________ up_sampling2d_2 (UpSampling2 (None, 28, 28, 128) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 28, 28, 64) 73792 _________________________________________________________________ batch_normalization_2 (Batch (None, 28, 28, 64) 256 _________________________________________________________________ activation_3 (Activation) (None, 28, 28, 64) 0 _________________________________________________________________ conv2d_3 (Conv2D) (None, 28, 28, 1) 577 _________________________________________________________________ activation_4 (Activation) (None, 28, 28, 1) 0 ================================================================= Total params: 856,193 Trainable params: 855,809 Non-trainable params: 384 _________________________________________________________________
Next up is the discriminator. This model will be tasked with looking at an image and deciding if it's really an MNIST image or if it's a fake. The discriminator looks a lot more like a standard convolutional network. It takes a tensor of an image as it's input and runs it through a series of convolution blocks, decreasing in size until it outputs a single activation representing the probability of the tensor being a real MNIST image.
from keras.layers import Dropout, Flatten, LeakyReLU, ZeroPadding2D
def get_discriminator(img_shape):
i = Input(img_shape)
x = Conv2D(32, kernel_size=3, strides=2, padding="same")(i)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.25)(x)
x = Conv2D(64, kernel_size=3, strides=2, padding="same")(i)
x = ZeroPadding2D(padding=((0, 1), (0, 1)))(x)
x = BatchNormalization(momentum=0.8)(x)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.25)(x)
x = Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.25)(x)
x = Conv2D(256, kernel_size=3, strides=1, padding="same")(x)
x = BatchNormalization(momentum=0.8)(x)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.25)(x)
x = Flatten()(x)
o = Dense(1)(x)
o = Activation("sigmoid")(o)
return Model(i, o)
d = get_discriminator(img_shape=(28, 28, 1))
d.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) (None, 28, 28, 1) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 14, 14, 64) 640 _________________________________________________________________ zero_padding2d_1 (ZeroPaddin (None, 15, 15, 64) 0 _________________________________________________________________ batch_normalization_3 (Batch (None, 15, 15, 64) 256 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 15, 15, 64) 0 _________________________________________________________________ dropout_2 (Dropout) (None, 15, 15, 64) 0 _________________________________________________________________ conv2d_6 (Conv2D) (None, 8, 8, 128) 73856 _________________________________________________________________ batch_normalization_4 (Batch (None, 8, 8, 128) 512 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 8, 8, 128) 0 _________________________________________________________________ dropout_3 (Dropout) (None, 8, 8, 128) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 8, 8, 256) 295168 _________________________________________________________________ batch_normalization_5 (Batch (None, 8, 8, 256) 1024 _________________________________________________________________ leaky_re_lu_4 (LeakyReLU) (None, 8, 8, 256) 0 _________________________________________________________________ dropout_4 (Dropout) (None, 8, 8, 256) 0 _________________________________________________________________ flatten_1 (Flatten) (None, 16384) 0 _________________________________________________________________ dense_2 (Dense) (None, 1) 16385 _________________________________________________________________ activation_5 (Activation) (None, 1) 0 ================================================================= Total params: 387,841 Trainable params: 386,945 Non-trainable params: 896 _________________________________________________________________
With both components defined, we can now bring it all together and create the GAN model. It's worth spending some time examining this part closely as it can be confusing at first. Notice that we first compile the distriminator by itself. Then we define the generator and feed its output to the discriminator, but only after setting the discriminator so that training is disabled. This is because we don't want the discriminator to update its weights during this part, only the generator. Finally, a third model is created combining the original input to the generator and the output of the discriminator. Don't worry if you find this confusing, it takes some time to fully comprehend.
from keras.optimizers import Adam
def DCGAN():
optimizer = Adam(lr=0.0002, beta_1=0.5)
discriminator = get_discriminator(img_shape=(28, 28, 1))
discriminator.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
generator = get_generator(noise_shape=(100,))
z = Input(shape=(100,))
g_out = generator(z)
discriminator.trainable = False
d_out = discriminator(g_out)
combined = Model(z, d_out)
combined.compile(loss="binary_crossentropy", optimizer=optimizer)
return generator, discriminator, combined
generator, discriminator, combined = DCGAN()
combined.summary()
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) (None, 100) 0 _________________________________________________________________ model_4 (Model) (None, 28, 28, 1) 856193 _________________________________________________________________ model_3 (Model) (None, 1) 387841 ================================================================= Total params: 1,244,034 Trainable params: 855,809 Non-trainable params: 388,225 _________________________________________________________________
Create a function to save the generator's output so we can visually see the results. We'll use this in the training loop below.
def save_image(generator, batch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
gen_images = generator.predict(noise)
gen_images = 0.5 * gen_images + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap="gray")
axs[i, j].axis('off')
cnt += 1
fig.savefig("mnist_%d.png" % (batch))
plt.close()
The last challenge with building a GAN is training the whole thing. With something this complex we can't just call the standard fit function because there are multiple steps involed. Let's roll our own and see how the algorithm works. Start by generating some random noise and feeding it through the generator to produce a batch of "fake" images. Next, select a batch of real images from the training data. Now train the distriminator on both the fake image batch and the real image batch, setting labels appropriately (0 for fake and 1 for real). Finally, train the generator by feeding the "combined" model a batch of noise with the "real" label.
You may find this last step confusing. Remember what this "combined" model is doing. It runs the input (random noise) through the generator, then runs that output through the discriminator (which is frozen), and gets a loss. What this ends of doing in essence is saying "train the generator to minimize the discriminator's ability to discern this output from a real image". This has the effect, over time, of causing the generator to produce realistic images!
def train(generator, discriminator, combined, X, batch_size, n_batches):
t0 = time.time()
for batch in range(n_batches + 1):
t1 = time.time()
noise = np.random.normal(0, 1, (batch_size, 100))
# Create fake images
fake_images = generator.predict(noise)
fake_labels = np.zeros((batch_size, 1))
# Select real images
idx = np.random.randint(0, X.shape[0], batch_size)
real_images = X[idx]
real_labels = np.ones((batch_size, 1))
# Train the discriminator
d_loss_real = discriminator.train_on_batch(real_images, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# Train the generator
noise = np.random.normal(0, 1, (batch_size, 100))
g_loss = combined.train_on_batch(noise, real_labels)
t2 = time.time()
# Report progress
if batch % 100 == 0:
print("Batch %d/%d [Batch time: %.2f s, Total: %.2f m] [D loss: %f, Acc: %.2f%%] [G loss: %f]" %
(batch, (n_batches), (t2 - t1), ((t2 - t0) / 60), d_loss[0], 100 * d_loss[1], g_loss))
if batch % 500 == 0:
save_image(generator, batch)
Run the training loop for a while. We can't read too much into the losses reported by this process because the models are competing with each other. They both end up being pretty stable over the training period but it doesn't really capture what's going on because both models are getting better at their task by roughly equal amounts. If either model went to zero loss that would likely indicate there's a problem that needs fixed.
train(generator, discriminator, combined, X=X, batch_size=32, n_batches=3000)
Batch 0/3000 [Batch time: 0.03 s, Total: 0.00 m] [D loss: 0.687554, Acc: 62.50%] [G loss: 1.088619] Batch 100/3000 [Batch time: 0.03 s, Total: 0.05 m] [D loss: 0.937827, Acc: 50.00%] [G loss: 1.534353] Batch 200/3000 [Batch time: 0.04 s, Total: 0.10 m] [D loss: 0.890954, Acc: 46.88%] [G loss: 1.568295] Batch 300/3000 [Batch time: 0.03 s, Total: 0.16 m] [D loss: 0.757705, Acc: 48.44%] [G loss: 1.378710] Batch 400/3000 [Batch time: 0.03 s, Total: 0.21 m] [D loss: 0.752167, Acc: 53.12%] [G loss: 1.317879] Batch 500/3000 [Batch time: 0.03 s, Total: 0.26 m] [D loss: 0.792607, Acc: 46.88%] [G loss: 1.668158] Batch 600/3000 [Batch time: 0.03 s, Total: 0.31 m] [D loss: 0.681811, Acc: 65.62%] [G loss: 1.534778] Batch 700/3000 [Batch time: 0.03 s, Total: 0.36 m] [D loss: 0.787874, Acc: 45.31%] [G loss: 1.265980] Batch 800/3000 [Batch time: 0.03 s, Total: 0.41 m] [D loss: 0.720734, Acc: 51.56%] [G loss: 1.346431] Batch 900/3000 [Batch time: 0.03 s, Total: 0.46 m] [D loss: 0.816940, Acc: 48.44%] [G loss: 1.155806] Batch 1000/3000 [Batch time: 0.03 s, Total: 0.51 m] [D loss: 0.793951, Acc: 42.19%] [G loss: 1.316577] Batch 1100/3000 [Batch time: 0.03 s, Total: 0.56 m] [D loss: 0.710668, Acc: 54.69%] [G loss: 1.433707] Batch 1200/3000 [Batch time: 0.03 s, Total: 0.61 m] [D loss: 0.822841, Acc: 45.31%] [G loss: 1.055463] Batch 1300/3000 [Batch time: 0.03 s, Total: 0.66 m] [D loss: 0.572396, Acc: 71.88%] [G loss: 1.093200] Batch 1400/3000 [Batch time: 0.03 s, Total: 0.71 m] [D loss: 0.753564, Acc: 50.00%] [G loss: 1.097862] Batch 1500/3000 [Batch time: 0.03 s, Total: 0.76 m] [D loss: 0.617892, Acc: 67.19%] [G loss: 0.939282] Batch 1600/3000 [Batch time: 0.03 s, Total: 0.81 m] [D loss: 0.647099, Acc: 59.38%] [G loss: 1.145073] Batch 1700/3000 [Batch time: 0.03 s, Total: 0.86 m] [D loss: 0.686766, Acc: 59.38%] [G loss: 1.009623] Batch 1800/3000 [Batch time: 0.03 s, Total: 0.91 m] [D loss: 0.757178, Acc: 43.75%] [G loss: 1.162026] Batch 1900/3000 [Batch time: 0.03 s, Total: 0.95 m] [D loss: 0.626452, Acc: 62.50%] [G loss: 1.149147] Batch 2000/3000 [Batch time: 0.03 s, Total: 1.00 m] [D loss: 0.706682, Acc: 54.69%] [G loss: 1.052020] Batch 2100/3000 [Batch time: 0.03 s, Total: 1.05 m] [D loss: 0.783831, Acc: 57.81%] [G loss: 1.271681] Batch 2200/3000 [Batch time: 0.03 s, Total: 1.10 m] [D loss: 0.630528, Acc: 65.62%] [G loss: 1.128464] Batch 2300/3000 [Batch time: 0.02 s, Total: 1.15 m] [D loss: 0.693636, Acc: 57.81%] [G loss: 1.398235] Batch 2400/3000 [Batch time: 0.03 s, Total: 1.19 m] [D loss: 0.735144, Acc: 51.56%] [G loss: 1.137753] Batch 2500/3000 [Batch time: 0.03 s, Total: 1.24 m] [D loss: 0.599575, Acc: 62.50%] [G loss: 0.984566] Batch 2600/3000 [Batch time: 0.03 s, Total: 1.29 m] [D loss: 0.774844, Acc: 51.56%] [G loss: 1.256624] Batch 2700/3000 [Batch time: 0.03 s, Total: 1.34 m] [D loss: 0.708373, Acc: 51.56%] [G loss: 1.276029] Batch 2800/3000 [Batch time: 0.03 s, Total: 1.39 m] [D loss: 0.641630, Acc: 60.94%] [G loss: 1.084028] Batch 2900/3000 [Batch time: 0.03 s, Total: 1.44 m] [D loss: 0.612454, Acc: 65.62%] [G loss: 1.232962] Batch 3000/3000 [Batch time: 0.03 s, Total: 1.49 m] [D loss: 0.706665, Acc: 53.12%] [G loss: 1.106143]
Take a look at the generator's output at various stages of training. In the beginning it's clear that the model is producing noise, but after just a few hundred batches it's already making progress. After a couple thousand batches the images look quite realistic.
from IPython.display import Image, display
names = ["mnist_0.png",
"mnist_500.png",
"mnist_1000.png",
"mnist_1500.png",
"mnist_2000.png",
"mnist_2500.png",
"mnist_3000.png"]
for name in names:
display(Image(filename=name))