%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Plot the first N_ROWS * N_COLS test digits in a grid. Each digit shows the
# network's predicted class and a numbered red box for every glimpse location
# returned by glimpse_function.
plt.figure(figsize=(16, 16))

IMAGE_SIDE = 28                 # images are reshaped to 28x28 below
HALF_SIDE = IMAGE_SIDE // 2     # 14: scale/offset used to map glimpse coords to pixels
GLIMPSE_SIZE = 7                # side length (pixels) of the drawn glimpse box
N_ROWS, N_COLS = 5, 5           # subplot grid -> 25 images

# Hoisted out of the loop: the original re-evaluated mnist.test_set() on every
# iteration. NOTE(review): assumes test_set() returns the same samples on each
# call (no shuffling) — confirm.
test_set = mnist.test_set()

for image_n in range(N_ROWS * N_COLS):
    image = test_set[image_n][0].reshape((IMAGE_SIDE, IMAGE_SIDE))
    # Both the classifier and the glimpse function receive one flattened image
    # of shape (1, 784); build it once instead of reshaping twice.
    flat = image.reshape((1, IMAGE_SIDE * IMAGE_SIDE))
    # Index of the largest score is taken as the predicted class.
    prediction = np.argmax(network.predict(flat))
    glimpses = glimpse_function(flat)

    ax = plt.subplot(N_ROWS, N_COLS, image_n + 1)
    ax.matshow(image, cmap="gray")
    for i, glimpse in enumerate(glimpses):
        # glimpse * 14 + 14 maps [-1, 1] onto [0, 28] pixel coordinates.
        # NOTE(review): assumes glimpse_function emits (x, y) pairs normalized
        # to [-1, 1] — confirm against its definition.
        x, y = glimpse * HALF_SIDE + HALF_SIDE
        # Box centered on the glimpse location; with matshow's upper origin
        # this offset corner is the box's top-left.
        corner = (x - GLIMPSE_SIZE / 2, y - GLIMPSE_SIZE / 2)
        ax.add_patch(Rectangle(corner, GLIMPSE_SIZE, GLIMPSE_SIZE,
                               ec="red", fill=False, alpha=0.8))
        # Label each box with its 1-based glimpse order at that corner.
        ax.annotate(str(i + 1), xy=corner, color='r', weight='bold',
                    fontsize=8, ha='center', va='center', alpha=0.8)
    ax.annotate("Prediction: %d" % prediction, xy=(19, 26), color='r',
                weight='bold', fontsize=12, ha='center', va='center', alpha=0.8)
plt.show()