import tensorflow as tf
import numpy as np
import os.path as op
import os
import shutil
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# Location of the MNIST data on disk; read_data_sets downloads it on first run.
data_dir = op.expanduser("~/data/mnist")
mnist = read_data_sets(data_dir, one_hot=True)
# Directory where the TensorBoard event files are written.
logs_dir = '/tmp/tensorflow_logs'
Extracting /home/ogrisel/data/mnist/train-images-idx3-ubyte.gz Extracting /home/ogrisel/data/mnist/train-labels-idx1-ubyte.gz Extracting /home/ogrisel/data/mnist/t10k-images-idx3-ubyte.gz Extracting /home/ogrisel/data/mnist/t10k-labels-idx1-ubyte.gz
def vec_normalize(vec):
    """Return *vec* rescaled to unit Euclidean length.

    A small epsilon is added to the norm so that an all-zero input does
    not divide by zero.
    """
    squared_sum = tf.reduce_sum(tf.square(vec))
    return vec / (tf.sqrt(squared_sum) + 1e-7)
# Start from a clean graph so repeated notebook runs do not accumulate ops.
tf.reset_default_graph()
sess = tf.Session()
dtype = tf.float32
# The learning rate is a Variable so it can be changed in-graph at runtime
# (see the lr_up / lr_down assign ops below).
learning_rate = tf.Variable(tf.constant(0.001, dtype=dtype))
with tf.name_scope('input'):
    # Flattened 28x28 MNIST images and their one-hot class labels.
    x = tf.placeholder(dtype=dtype, shape=[None, 784], name='x-input')
    y = tf.placeholder(dtype=dtype, shape=[None, 10], name='y-input')
with tf.name_scope('variables'):
    # Softmax-regression parameters: one weight column and one bias per class.
    W = tf.Variable(tf.truncated_normal(shape=(784, 10), stddev=0.1,
                                        dtype=dtype),
                    name='W')
    tf.histogram_summary('weights', W)
    b = tf.Variable(tf.zeros(shape=(10,), dtype=dtype), name='b')
    tf.histogram_summary('biases', b)
    # Moving averages of the unit-normalized, flattened weight gradient,
    # tracked at two time scales (see the 'updates' scope below).
    slow_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))
    fast_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))
    # Inner product of the two (unit) direction vectors: the cosine of the
    # angle between the slow and fast gradient-direction averages.
    dir_similarity = tf.matmul(tf.reshape(slow_direction, [1, -1]),
                               tf.reshape(fast_direction, [-1, 1]))[0, 0]
    tf.scalar_summary('dir_similarity', dir_similarity)
with tf.name_scope('model'):
    # Linear model followed by a softmax: multinomial logistic regression.
    preactivations = tf.matmul(x, W) + b
    tf.histogram_summary('preactivations', preactivations)
    y_pred = tf.nn.softmax(preactivations)
    tf.histogram_summary('predicted_probabilities', y_pred)
with tf.name_scope('loss'):
    # softmax_cross_entropy_with_logits takes the *pre-softmax* activations
    # (old TF 0.x positional API: logits first, labels second) and returns
    # one loss value per example.
    cross_entropies = tf.nn.softmax_cross_entropy_with_logits(preactivations, y)
    cross_entropy = tf.reduce_mean(cross_entropies, name='cross_entropy')
    tf.scalar_summary('cross_entropy', cross_entropy)
with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
        # A prediction is correct when the arg-max class matches the label.
        correct_prediction = tf.equal(tf.argmax(y, 1),
                                      tf.argmax(y_pred, 1))
    # BUG FIX: this inner scope was also named 'correct_prediction'
    # (copy-paste duplicate); naming it 'accuracy' groups the mean-accuracy
    # op correctly in the graph visualization.
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype))
    tf.scalar_summary('accuracy', accuracy)
with tf.name_scope('gradient_directions'):
    # Gradients of the mean loss with respect to the parameters.
    [gW, gb] = tf.gradients(cross_entropy, [W, b])
    # Build the squared norm of gW once and reuse it: the original graph
    # constructed tf.reduce_sum(tf.square(gW)) twice.
    gW_sq_sum = tf.reduce_sum(tf.square(gW))
    gW_norm = tf.sqrt(gW_sq_sum)
    g_norm = tf.sqrt(gW_sq_sum + tf.reduce_sum(tf.square(gb)))
    tf.scalar_summary('gradient norm', g_norm)
    # Unit-length, flattened weight gradient fed to the direction trackers.
    gW_normed = tf.reshape(gW / (gW_norm + 1e-7), [-1])
with tf.name_scope('updates'):
    # Plain SGD parameter updates.
    W_update = W.assign_add(-learning_rate * gW)
    b_update = b.assign_add(-learning_rate * gb)
    # Slow exponential moving average of the normalized gradient direction,
    # re-normalized after each update; the reset op snaps it to the current
    # gradient direction.
    slow_rate = 0.05
    new_slow_dir = slow_rate * gW_normed + (1 - slow_rate) * slow_direction
    slow_dir_update = slow_direction.assign(vec_normalize(new_slow_dir))
    slow_dir_reset = slow_direction.assign(gW_normed)
    # Fast exponential moving average of the same direction.
    fast_rate = 0.5
    new_fast_dir = fast_rate * gW_normed + (1 - fast_rate) * fast_direction
    fast_dir_update = fast_direction.assign(vec_normalize(new_fast_dir))
    fast_dir_reset = fast_direction.assign(gW_normed)
    # In-graph learning-rate schedule ops, driven manually from Python.
    lr_up = learning_rate.assign(2 * learning_rate)
    lr_down = learning_rate.assign(0.1 * learning_rate)
def data_dict(train=True, batch_size=128):
    """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
    if train:
        images, labels = mnist.train.next_batch(batch_size)
    else:
        images, labels = mnist.test.images, mnist.test.labels
    feed = {x: images.astype(np.float32), y: labels.astype(np.float32)}
    return feed
sess.run(tf.unpack(cross_entropies, num=128), feed_dict=data_dict(train=True))
[3.7324717, 1.8400075, 3.6116686, 1.7056865, 2.0967669, 2.3198204, 2.0171783, 1.7580715, 2.3690248, 1.6259246, 3.3400445, 2.0483375, 1.8097279, 2.1330061, 1.8711376, 2.6461897, 3.4301701, 1.9341406, 1.4128025, 3.1168551, 3.7157907, 2.9077392, 3.325181, 0.89536273, 4.3979654, 0.92203724, 2.2007186, 2.2937737, 2.1817045, 1.8752966, 1.6373912, 2.0365462, 1.6608343, 2.6484132, 3.4957781, 1.9901035, 1.9084624, 2.6680474, 2.0449464, 3.0129635, 2.3355575, 2.6466174, 2.5199924, 1.9306296, 2.47153, 2.5479219, 2.2662649, 2.6136937, 1.7118273, 1.3672709, 3.4149623, 2.152194, 2.3145103, 2.0982614, 1.5726066, 5.6496277, 1.0098101, 3.9320879, 1.1875807, 1.6345842, 4.9190865, 2.1639643, 3.1059659, 2.1172271, 1.7091054, 1.7676189, 2.1179478, 3.7244837, 2.3459547, 4.1443148, 2.1852176, 3.9915304, 1.9832186, 2.7600195, 1.882431, 3.7910852, 0.71117669, 2.0551419, 4.1546483, 3.7213643, 1.8376471, 2.0418112, 1.8615329, 3.6794746, 1.9143579, 2.3036268, 2.2760646, 2.079437, 1.5919703, 2.2989087, 1.7845988, 3.588326, 3.0593004, 4.4371295, 4.3214245, 3.2552154, 1.5259778, 1.9014853, 2.7919619, 2.8022759, 2.0109711, 2.3015294, 2.1909297, 2.7841287, 3.3960891, 2.7723904, 2.6843922, 2.5696578, 1.873759, 4.0515051, 2.7884703, 3.1664705, 2.413384, 2.5950832, 1.4605098, 1.8590325, 3.5861123, 1.8300626, 3.8167338, 2.6522479, 4.1535492, 2.3777215, 1.7563944, 2.1654656, 2.0148871, 2.0356688, 1.4784111, 4.0939426]
def cosine_similarities(x):
    """Return the matrix of cosine similarities between the rows of *x*.

    BUG FIX: the original divided the whole matrix by its single global
    (Frobenius) norm via vec_normalize, so the result's diagonal was not 1
    and the entries were not true cosine similarities.  Each row is now
    rescaled to unit length individually before taking inner products.
    """
    # Per-row L2 norms, epsilon-guarded against zero rows.
    row_norms = tf.sqrt(tf.reduce_sum(tf.square(x), 1, keep_dims=True))
    x_unit = x / (row_norms + 1e-7)
    return tf.matmul(x_unit, tf.transpose(x_unit))
# flat_gW = tf.reshape(gW, [-1])
# NOTE(review): tf.gradients sums over the batch, so gWs has the shape of W
# (784 x 10), not one gradient per example -- the similarities below are
# between the 784 pixel rows, not between the 128 examples (hence the
# (784, 784) result shape); confirm this is intended.
[gWs, gbs] = tf.gradients(cross_entropies, [W, b])
sims = cosine_similarities(gWs)
# sess.run(tf.initialize_all_variables())
sess.run(sims, feed_dict=data_dict(train=False)).shape
(784, 784)
128 ** 2
16384
# Merge every summary op declared above into one fetchable tensor.
summaries = tf.merge_all_summaries()
# Start from empty log directories.  ROBUSTNESS FIX: ignore_errors avoids
# a crash on the first run, when logs_dir does not exist yet (a plain
# shutil.rmtree raises FileNotFoundError in that case).
shutil.rmtree(logs_dir, ignore_errors=True)
train_writer = tf.train.SummaryWriter(logs_dir + '/train', sess.graph)
test_writer = tf.train.SummaryWriter(logs_dir + '/test')
sess.run(tf.initialize_all_variables())
# Book-keeping for the (currently disabled) learning-rate adaptation below.
last_lr_change = 0
cool_down = 100
for i in range(10000):
    if i % 100 == 0:
        # Evaluate on test set (no update ops are run on these iterations).
        test_summaries, test_acc, test_dir_similarity, lr = sess.run(
            [summaries, accuracy, dir_similarity, learning_rate],
            feed_dict=data_dict(train=False))
        test_writer.add_summary(test_summaries, i)
        print("Accuracy on test: %0.3f, gdir similarity: %0.3f, lr: %f"
              % (test_acc, test_dir_similarity, lr))
        if lr < 1e-10:
            print('Converged!')
            break
    else:
        # Evaluate on train mini_batch.  Fetching the *_update ops applies
        # one SGD step and refreshes the two gradient-direction moving
        # averages as a side effect of the same sess.run call.
        train_summaries, _, _, _, _, train_dir_similarity = sess.run(
            [summaries, W_update, b_update, slow_dir_update,
             fast_dir_update, dir_similarity],
            feed_dict=data_dict(train=True))
        train_writer.add_summary(train_summaries, i)
        # Experimental learning-rate adaptation, currently disabled: raise
        # the rate while the gradient directions stay aligned, cut it when
        # they start to disagree, with a cool-down between changes.
        # if i - last_lr_change > cool_down:
        #     if train_dir_similarity > 0.5:
        #         print("up")
        #         sess.run([lr_up, slow_dir_reset, fast_dir_reset],
        #                  feed_dict=data_dict(train=True))
        #         last_lr_change = i
        #         cool_down = 1000
        #     elif train_dir_similarity < 0:
        #         print('down')
        #         sess.run([lr_down, slow_dir_reset, fast_dir_reset],
        #                  feed_dict=data_dict(train=True))
        #         last_lr_change = i
        #         cool_down = 1000
Accuracy on test: 0.111, gdir similarity: 0.000, lr: 0.001000 Accuracy on test: 0.138, gdir similarity: 0.960, lr: 0.001000 Accuracy on test: 0.166, gdir similarity: 0.956, lr: 0.001000 Accuracy on test: 0.216, gdir similarity: 0.938, lr: 0.001000 Accuracy on test: 0.273, gdir similarity: 0.907, lr: 0.001000 Accuracy on test: 0.329, gdir similarity: 0.945, lr: 0.001000 Accuracy on test: 0.379, gdir similarity: 0.930, lr: 0.001000 Accuracy on test: 0.421, gdir similarity: 0.949, lr: 0.001000 Accuracy on test: 0.458, gdir similarity: 0.935, lr: 0.001000 Accuracy on test: 0.491, gdir similarity: 0.948, lr: 0.001000 Accuracy on test: 0.524, gdir similarity: 0.904, lr: 0.001000
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) <ipython-input-94-0186f57f8d61> in <module>() 29 fast_dir_update, dir_similarity], 30 feed_dict=data_dict(train=True)) ---> 31 train_writer.add_summary(train_summaries, i) 32 # if i - last_lr_change > cool_down: 33 # if train_dir_similarity > 0.5: /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/tensorflow/python/training/summary_io.py in add_summary(self, summary, global_step) 133 if isinstance(summary, bytes): 134 summ = summary_pb2.Summary() --> 135 summ.ParseFromString(summary) 136 summary = summ 137 event = event_pb2.Event(wall_time=time.time(), summary=summary) /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/message.py in ParseFromString(self, serialized) 183 """ 184 self.Clear() --> 185 self.MergeFromString(serialized) 186 187 def SerializeToString(self): /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py in MergeFromString(self, serialized) 1089 length = len(serialized) 1090 try: -> 1091 if self._InternalParse(serialized, 0, length) != length: 1092 # The only reason _InternalParse would return early is if it 1093 # encountered an end-group tag. /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end) 1125 pos = new_pos 1126 else: -> 1127 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1128 if field_desc: 1129 self._UpdateOneofState(field_desc) /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py in DecodeRepeatedField(buffer, pos, end, message, field_dict) 610 raise _DecodeError('Truncated message.') 611 # Read sub-message. --> 612 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: 613 # The only reason _InternalParse would return early is if it 614 # encountered an end-group tag. 
/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end) 1125 pos = new_pos 1126 else: -> 1127 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1128 if field_desc: 1129 self._UpdateOneofState(field_desc) /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py in DecodeField(buffer, pos, end, message, field_dict) 631 raise _DecodeError('Truncated message.') 632 # Read sub-message. --> 633 if value._InternalParse(buffer, pos, new_pos) != new_pos: 634 # The only reason _InternalParse would return early is if it encountered 635 # an end-group tag. /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end) 1125 pos = new_pos 1126 else: -> 1127 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1128 if field_desc: 1129 self._UpdateOneofState(field_desc) /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py in DecodeField(buffer, pos, end, message, field_dict) 237 else: 238 def DecodeField(buffer, pos, end, message, field_dict): --> 239 (field_dict[key], pos) = decode_value(buffer, pos) 240 if pos > end: 241 del field_dict[key] # Discard corrupt value. /home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py in InnerDecode(buffer, pos) 338 # bit set, it's not a number. In Python 2.4, struct.unpack will treat it 339 # as inf or -inf. To avoid that, we treat it specially. --> 340 if ((double_bytes[7:8] in b'\x7F\xFF') 341 and (double_bytes[6:7] >= b'\xF0') 342 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): KeyboardInterrupt: