In contrast to the simple notebook DataAugmentation.ipynb, where we explained how to do augmentation, we now show the real value of augmenting the training data: by creating more training examples we reduce overfitting. This notebook runs a bit longer; its runtime can be controlled via the following variables:
dataSize = 2000
verbose = False  # set to True to print the per-epoch training progress
epochs_no_augmentation = 200
epochs_augmentation = 200
#from load_mnist import load_data_2d
#X,y,PIXELS = load_data_2d('../../data/mnist.pkl.gz')
import cPickle as pickle
import gzip
with gzip.open('mnist_4000.pkl.gz', 'rb') as f:
    (X, y) = pickle.load(f)
PIXELS = len(X[0,0,0,:])
X.shape, y.shape, PIXELS
((4000, 1, 28, 28), (4000,), 28)
A small technicality: for the transformations the data has to be in the range [-1, 1] (a limitation of the skimage.transform library). We therefore rescale the data to the range [0, 256] here and divide by 256 later, when we do the augmentation.
import numpy as np
#maxs = np.max(X[:,0,:,:],axis=(1,2))
#mins = np.min(X[:,0,:,:],axis=(1,2))
#Xs = np.zeros_like(X)
#for i in range(len(X)):
# Xs[i,0,:,:] = (X[i,0,:,:] - maxs[i])/(maxs[i]-mins[i])
#X = Xs
Xs = (X - np.min(X)) / (np.max(X) - np.min(X))
X = Xs * 256.0
np.min(X),np.max(X)
(0.0, 256.0)
We use the standard CNN again and do no augmentation in the first round.
from lasagne import layers
from lasagne import nonlinearities
from nolearn.lasagne import NeuralNet
def createNet():
    return NeuralNet(
        # Geometry of the network
        layers=[
            ('input', layers.InputLayer),
            ('conv1', layers.Conv2DLayer),
            ('pool1', layers.MaxPool2DLayer),
            ('conv2', layers.Conv2DLayer),
            ('pool2', layers.MaxPool2DLayer),
            ('hidden4', layers.DenseLayer),
            ('output', layers.DenseLayer),
        ],
        input_shape=(None, 1, PIXELS, PIXELS),  # None in the first axis indicates that the batch size can be set later
        conv1_num_filters=32, conv1_filter_size=(3, 3), pool1_pool_size=(2, 2),  # pool_size used to be called ds in old versions of lasagne
        conv2_num_filters=64, conv2_filter_size=(2, 2), pool2_pool_size=(2, 2),
        hidden4_num_units=500,
        output_num_units=10, output_nonlinearity=nonlinearities.softmax,

        # Learning rate parameters
        update_learning_rate=0.01,
        update_momentum=0.90,
        regression=False,
        # Train for 10 epochs by default (we override max_epochs below)
        max_epochs=10,
        verbose=1,
        # Training/validation split
        eval_size=0.2,
    )
Using gpu device 0: GRID K520
netnoAug = createNet()
netnoAug.max_epochs = epochs_no_augmentation
netnoAug.verbose = verbose
d = netnoAug.fit(X[0:dataSize,:,:,:],y[0:dataSize]);
# Neural Network with 1166086 learnable parameters

## Layer information

|   # | name    | size     |
|----:|:--------|:---------|
|   0 | input   | 1x28x28  |
|   1 | conv1   | 32x26x26 |
|   2 | pool1   | 32x13x13 |
|   3 | conv2   | 64x12x12 |
|   4 | pool2   | 64x6x6   |
|   5 | hidden4 | 500      |
|   6 | output  | 10       |

  epoch    train loss    valid loss    train/val    valid acc  dur
-------  ------------  ------------  -----------  -----------  -----
      1       1.53637       0.78706      1.95204      0.77354  0.45s
      2       0.57104       0.38058      1.50044      0.88384  0.45s
      3       0.27632       0.33141      0.83375      0.89751  0.45s
      4       0.16645       0.28451      0.58506      0.90532  0.45s
      5       0.10429       0.26261      0.39715      0.91118  0.45s
      6       0.06875       0.26661      0.25786      0.91118  0.45s
      7       0.04346       0.26795      0.16220      0.91900  0.45s
      8       0.02662       0.26630      0.09996      0.92290  0.45s
      9       0.01667       0.27969      0.05961      0.92486  0.45s
     10       0.01165       0.31809      0.03663      0.91704  0.45s
     11       0.00833       0.34051      0.02446      0.92095  0.45s
     12       0.00661       0.33055      0.01999      0.91900  0.45s
     13       0.00526       0.33988      0.01547      0.92095  0.45s
     14       0.00415       0.34305      0.01209      0.92290  0.45s
     15       0.00311       0.34454      0.00902      0.92095  0.45s
    ...           ...           ...          ...          ...  ...
    200       0.00008       0.47857      0.00017      0.92681  0.45s
%matplotlib inline
import pandas as pd
dfNoAug = pd.DataFrame(netnoAug.train_history_)
dfNoAug[['train_loss','valid_loss','valid_accuracy']].plot(title='No Augmentation', ylim=(0,1))
<matplotlib.axes._subplots.AxesSubplot at 0x7f9547fb8fd0>
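To pin down where the validation loss bottoms out, here is a quick check (a minimal sketch using the dfNoAug data frame created above):
# Epoch with the lowest validation loss and the corresponding accuracy
best = dfNoAug['valid_loss'].idxmin()
print(dfNoAug.loc[best, ['epoch', 'valid_loss', 'valid_accuracy']])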
We see clear overfitting: the training loss drops towards 0, but after only a few epochs the validation loss starts to rise again.
This is not astonishing, since with dataSize = 2000 and a validation split of 20% we train on only $2000 \cdot 0.8 = 1600$ examples, i.e. roughly 160 per digit. We now try to create new training data on the fly by manipulating the existing examples.
Below we define small random rotations, scalings, and translations, which keep the labels intact.
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as imgplot
import numpy as np
from skimage import transform as tf
# Allowed rotations (given in degrees, converted to radians) and translations (in pixels)
rots = np.asarray((-20, -10, -5, 5, 10, 20)) / (360 / (2.0 * np.pi))
dists = (-1, 1)
def manipulateTrainingData(Xb):
    retX = np.zeros((Xb.shape[0], Xb.shape[1], Xb.shape[2], Xb.shape[3]), dtype='float32')
    for i in range(len(Xb)):
        dist = dists[np.random.randint(0, len(dists))]
        rot = rots[np.random.randint(0, len(rots))]
        scale = np.random.uniform(0.9, 1.10)
        tform = tf.SimilarityTransform(rotation=rot, translation=dist, scale=scale)
        # "Float images" may only contain values between -1 and 1,
        # hence the division by 256 before warping and the rescaling afterwards
        retX[i, 0, :, :] = 256.0 * tf.warp(Xb[i, 0, :, :] / 256.0, tform)
    return retX
Xb = np.copy(Xs[0:100,:,:,:])
Xb = manipulateTrainingData(Xb)
fig = plt.figure(figsize=(10,10))
for i in range(18):
    a = fig.add_subplot(6, 6, 2 * i + 1, xticks=[], yticks=[])
    plt.imshow(-Xs[i, 0, :, :], cmap=plt.get_cmap('gray'))
    a = fig.add_subplot(6, 6, 2 * i + 2, xticks=[], yticks=[])
    plt.imshow(-Xb[i, 0, :, :], cmap=plt.get_cmap('gray'))
The idea is to apply these random transformations on the fly, each time a new minibatch is processed. We override the BatchIterator again as follows:
from nolearn.lasagne import BatchIterator
# Our BatchIterator
class SimpleBatchIterator(BatchIterator):

    def transform(self, Xb, yb):
        # The incoming and outgoing shape is (batch_size, 1, 28, 28)
        Xb, yb = super(SimpleBatchIterator, self).transform(Xb, yb)
        return manipulateTrainingData(Xb), yb  # <-- here we manipulate the training batch
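As a quick sanity check (a minimal sketch; it assumes the nolearn BatchIterator can be called on a small slice of the data and then iterated over), we can pull a single transformed minibatch out of the iterator and look at its shapes:
# Fetch one transformed minibatch to verify that shapes and labels survive the augmentation
bi = SimpleBatchIterator(16)
for Xb_t, yb_t in bi(X[0:16, :, :, :], y[0:16]):
    print(Xb_t.shape)  # expected: (16, 1, 28, 28)
    print(yb_t.shape)  # expected: (16,)
    break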
# Setting the new batch iterator
net1Aug = createNet()
net1Aug.max_epochs = epochs_augmentation
net1Aug.batch_iterator_train = SimpleBatchIterator(256)
net1Aug.verbose = verbose
d = net1Aug.fit(X[0:dataSize,:,:,:],y[0:dataSize])
# Neural Network with 1166086 learnable parameters

## Layer information

|   # | name    | size     |
|----:|:--------|:---------|
|   0 | input   | 1x28x28  |
|   1 | conv1   | 32x26x26 |
|   2 | pool1   | 32x13x13 |
|   3 | conv2   | 64x12x12 |
|   4 | pool2   | 64x6x6   |
|   5 | hidden4 | 500      |
|   6 | output  | 10       |

  epoch    train loss    valid loss    train/val    valid acc  dur
-------  ------------  ------------  -----------  -----------  -----
      1       2.30608       1.96445      1.17391      0.35434  0.74s
      2       2.01773       1.45139      1.39021      0.56610  0.73s
      3       1.81158       1.27584      1.41991      0.54657  0.72s
      4       1.73366       1.02107      1.69788      0.71207  0.72s
      5       1.47019       0.78629      1.86978      0.77549  0.72s
      6       1.45336       0.98191      1.48013      0.69449  0.72s
      7       1.21955       0.67311      1.81182      0.82627  0.72s
      8       1.02982       0.63660      1.61768      0.80191  0.72s
      9       0.90394       0.59653      1.51533      0.80387  0.72s
     10       0.76318       0.52048      1.46631      0.83460  0.72s
     11       0.95245       0.55810      1.70661      0.83604  0.72s
     12       0.67583       0.40352      1.67484      0.89412  0.72s
     13       0.59991       0.39262      1.52797      0.87850  0.72s
     14       0.54010       0.46575      1.15962      0.88045  0.72s
     15       0.48017       0.36385      1.31969      0.89947  0.72s
    ...           ...           ...          ...          ...  ...
    200       0.01986       0.22074      0.08996      0.95755  0.72s
Let's have a look at the training history to see whether we are overfitting:
dfAug = pd.DataFrame(net1Aug.train_history_)
dfAug[['train_loss','valid_loss','valid_accuracy']].plot(title='With Augmentation')
<matplotlib.axes._subplots.AxesSubplot at 0x7f9543b94990>
We still see some overfitting, but it is far less severe than without augmentation.
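For a direct numeric comparison, here is a quick check using the two history data frames from above:
# Best validation accuracy reached by each of the two runs
print('best valid acc without augmentation: %.4f' % dfNoAug['valid_accuracy'].max())
print('best valid acc with augmentation:    %.4f' % dfAug['valid_accuracy'].max())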
Just a small detour: if you have R installed, you can make nice plots (e.g. using ggplot2).
%load_ext rpy2.ipython
%Rpush dfAug
%Rpush dfNoAug
%%R
library(ggplot2)
ggplot() + aes(x=epoch, colour='Loss') +
geom_line(data=dfAug, aes(y = train_loss, colour='Training'), size=2) +
geom_line(data=dfAug, aes(y = valid_loss, colour='Validation'), size=2) +
geom_line(data=dfNoAug, aes(y = train_loss, colour='Training'), size=2) +
geom_line(data=dfNoAug, aes(y = valid_loss, colour='Validation'), size=2) +
xlab('Epochs') + ylab('Loss') +
ylim(c(0,0.75))
%%R
library(ggplot2)
ggplot() + aes(x=epoch, colour='Loss') +
geom_line(data=dfAug, aes(y = valid_accuracy, colour='Augmented'), size=2) +
geom_line(data=dfNoAug, aes(y = valid_accuracy, colour='Not Augmented'), size=2) +
xlab('Epochs') + ylab('Accuracy on validation set') +
ylim(c(0.75,1))
We see that augmenting the training data not only mitigates overfitting but also results in a better validation accuracy (roughly 96% compared to about 92.7%). By using all the data and doing more clever augmentations, the test accuracy can reach 99.65%, see http://arxiv.org/abs/1003.0358
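One of the "more clever" augmentations used in that line of work is an elastic deformation of the digits. A minimal, hedged sketch of such a deformation (the function name elastic_distortion and the parameters alpha and sigma are ad-hoc choices for illustration, not taken from the paper) could look like this:
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates

def elastic_distortion(img, alpha=8.0, sigma=3.0):
    # Displace every pixel by a smoothed random field; alpha controls the
    # strength of the displacement, sigma the smoothness of the field.
    dx = gaussian_filter(np.random.uniform(-1, 1, img.shape), sigma) * alpha
    dy = gaussian_filter(np.random.uniform(-1, 1, img.shape), sigma) * alpha
    x, y = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
    coords = np.vstack([(y + dy).ravel(), (x + dx).ravel()])
    return map_coordinates(img, coords, order=1).reshape(img.shape)

# Example: distort a single training image (values stay in the original range)
distorted = elastic_distortion(X[0, 0, :, :])
Such a deformation could be plugged into manipulateTrainingData in exactly the same way as the similarity transform above.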