**Interpreting What a Neural Network Has Learned**

Explainable Artificial Intelligence (XAI): Concepts, taxonomies, opportunities and challenges toward responsible AI, Arrieta, et al., Information Fusion, Volume 58, June 2020, Pages 82-115

"Given a certain audience, explainability refers to the details and reasons a model gives to make its functioning clear or easy to understand."

Here we will examine what the hidden units in a fully connected neural network have learned. This is most intuitive if we focus on classification problems involving images.

import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
import os

from A6mysolution import *

# for regression problem
def rmse(a, b):
    """Return the root-mean-square error between `a` and `b`.

    Both arguments are converted with ``np.asarray`` so plain Python
    sequences work as well as arrays.  The difference is taken elementwise
    (normal NumPy broadcasting applies) and the squared differences are
    averaged over every element before taking the square root.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    return np.sqrt(np.mean((a - b) ** 2))

# for classification problem
def percent_correct(a, b):
    """Return the percentage (0 to 100) of elements at which `a` equals `b`.

    Inputs are converted with ``np.asarray`` so that lists are compared
    elementwise rather than as whole objects.  `a` and `b` should have the
    same shape: NumPy broadcasting applies to the comparison, so mixing an
    ``(n,)`` array with an ``(n, 1)`` array compares every pair of elements
    instead of matching positions.
    """
    return 100 * np.mean(np.asarray(a) == np.asarray(b))

# for classification problem
def confusion_matrix(Y_classes, T):
    """Build a confusion matrix of percentages and print the overall accuracy.

    Parameters
    ----------
    Y_classes : ndarray
        Predicted class for each sample.
    T : ndarray
        True class for each sample.  It is used as a boolean mask over
        `Y_classes`, so the two arrays must have the same shape.

    Returns
    -------
    pandas.DataFrame
        Rows are true classes and columns are predicted classes, both
        labelled with the sorted unique values found in `T`.  Each entry is
        the percentage of samples of the row's class that were predicted as
        the column's class.

    As a side effect, prints the overall percent correct.
    """
    class_names = np.unique(T)
    table = []
    for true_class in class_names:
        # Predictions for just the samples whose target is `true_class`;
        # selected once per row rather than once per cell.
        predicted = Y_classes[T == true_class]
        table.append([100 * np.mean(predicted == Y_class)
                      for Y_class in class_names])
    conf_matrix = pd.DataFrame(table, index=class_names, columns=class_names)
    print(f'Percent Correct is {percent_correct(Y_classes, T)}')
    return conf_matrix

def makeImages(nEach):
    """Generate `nEach` square outlines followed by `nEach` diamond outlines.

    Every shape is drawn with pixels of value 1.0 on an otherwise zero
    20 x 20 image.  Each shape gets a random "radius" (half-width) between
    3 and 7 pixels and a random center chosen so the outline stays inside
    the image.

    Parameters
    ----------
    nEach : int
        Number of images to generate for each of the two classes.

    Returns
    -------
    images : ndarray, shape (2 * nEach, 1, 20, 20)
        The first `nEach` images are squares, the rest are diamonds.
    T : ndarray, shape (2 * nEach, 1)
        Class labels: 0 for squares, 1 for diamonds.
    """
    images = np.zeros((nEach * 2, 1, 20, 20))  # nSamples, nChannels, rows, columns
    radii = 3 + np.random.randint(10 - 5, size=(nEach * 2, 1))
    centers = np.zeros((nEach * 2, 2))
    for i in range(nEach * 2):
        # Radius of this sample's shape; every index expression below uses it.
        r = int(radii[i, 0])
        # Keep the center at least r + 1 pixels away from the low edges and
        # far enough from the high edges that center + r stays in bounds.
        centers[i, :] = r + 1 + np.random.randint(18 - 2 * r, size=(1, 2))
        x = int(centers[i, 0])
        y = int(centers[i, 1])
        if i < nEach:
            # squares: four sides of the outline
            images[i, 0, x - r:x + r, y + r] = 1.0
            images[i, 0, x - r:x + r, y - r] = 1.0
            images[i, 0, x - r, y - r:y + r] = 1.0
            images[i, 0, x + r, y - r:y + r + 1] = 1.0
        else:
            # diamonds: four diagonal edges, each drawn by pairing a row
            # range with a column range (integer-array indexing)
            images[i, 0, range(x - r, x), range(y, y + r)] = 1.0
            images[i, 0, range(x - r, x), range(y, y - r, -1)] = 1.0
            images[i, 0, range(x, x + r + 1), range(y + r, y - 1, -1)] = 1.0
            images[i, 0, range(x, x + r), range(y - r, y)] = 1.0
    # images += np.random.randn(*images.shape) * 0.5
    T = np.zeros((nEach * 2, 1))
    T[nEach:] = 1
    return images, T

nEach = 1000
X, T = makeImages(nEach)
X = X.reshape(X.shape[0], -1)
print(X.shape, T.shape)

Xtest, Ttest = makeImages(nEach)
Xtest = Xtest.reshape(Xtest.shape[0], -1)

plt.plot(T);

(2000, 400) (2000, 1)

plt.imshow(-X[-1, :].reshape(20, 20), cmap='gray')
plt.xticks([])
plt.yticks([])

([], [])
plt.figure(figsize=(10, 3))

for i in range(10):
    # Top row: the first ten samples (squares).
    plt.subplot(2, 10, i + 1)
    plt.imshow(-X[i, :].reshape(20, 20), cmap='gray')
    plt.xticks([])
    plt.yticks([])

    # Bottom row: the last ten samples (diamonds).  Index with -(i + 1)
    # because X[-0] is X[0], which would put a square in this row.
    plt.subplot(2, 10, i + 11)
    plt.imshow(-X[-(i + 1), :].reshape(20, 20), cmap='gray')
    plt.xticks([])
    plt.yticks([])

nnet, learning_curve = train_for_classification(X, T, hidden_layers=[10],
n_epochs=500, learning_rate=0.01)
plt.plot(learning_curve);

nnet

Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
Y = use(nnet, X)
Ytest = use(nnet, Xtest)
Y.shape

(2000, 2)
plt.subplot(2, 1, 1)
plt.plot(Y)
plt.subplot(2, 1, 2)
plt.plot(Ytest)

[<matplotlib.lines.Line2D at 0x7fa89861d370>,
<matplotlib.lines.Line2D at 0x7fa89861d460>]
plt.plot(np.exp(Y))

[<matplotlib.lines.Line2D at 0x7fa898582b50>,
<matplotlib.lines.Line2D at 0x7fa898582c40>]
Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes, 'o', label='Predicted')
plt.plot(T + 0.1, 'o', label='Target')
plt.legend();

Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes_test, 'o', label='Predicted')
# Test predictions are compared with the test targets; the 0.1 offset keeps
# the target markers from landing exactly on the prediction markers.
plt.plot(Ttest + 0.1, 'o', label='Target')
plt.legend();

Y.shape

(2000, 2)
confusion_matrix(Y_classes_test, Ttest)

Percent Correct is 99.05000000000001

0.0 1.0
0.0 99.1 0.9
1.0 1.0 99.0
def forward_all_layers(nnet, X):
    """Pass `X` through `nnet` one layer at a time, keeping every intermediate result.

    Parameters
    ----------
    nnet : iterable of callables
        Layers applied in iteration order, for example a
        ``torch.nn.Sequential``.
    X : array-like
        Input samples.  Converted to a float32 tensor before the first layer.

    Returns
    -------
    list of ndarray
        Element 0 is the input (as float32); element k is the output of the
        k-th layer, so the list has one more entry than `nnet` has layers.
    """
    outputs = [torch.from_numpy(np.asarray(X)).float()]
    # The activations are only inspected, never back-propagated through, so
    # skip building the autograd graph.
    with torch.no_grad():
        for layer in nnet:
            outputs.append(layer(outputs[-1]))
    return [out.detach().numpy() for out in outputs]

Y_square = forward_all_layers(nnet, X[:10, :])
Y_diamond = forward_all_layers(nnet, X[-10:, :])

nnet

Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
len(Y_square)

5
Y_square[0].shape

(10, 400)
Y_square[1].shape

(10, 10)
plt.plot(Y_square[1]);

plt.plot(Y_square[2]);

both = np.vstack((Y_square[2], Y_diamond[2]))

plt.plot(both);

plt.figure(figsize=(15, 3))
for unit in range(10):
plt.subplot(1, 10, unit + 1)
plt.plot(both[:, unit])
plt.tight_layout()

plt.plot(both[:, 9])

[<matplotlib.lines.Line2D at 0x7fa898a4fcd0>]
nnet

Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
nnet[0].parameters()

<generator object Module.parameters at 0x7fa898a1eac0>
list(nnet[0].parameters())

Out[30]:
[Parameter containing:
tensor([[-0.0235,  0.0183, -0.0043,  ..., -0.0283, -0.0472,  0.0493],
[-0.0305,  0.0305,  0.0375,  ...,  0.0076, -0.0254, -0.0216],
[-0.0205, -0.0357, -0.0086,  ..., -0.0464,  0.0402,  0.0307],
...,
[-0.0139,  0.0290,  0.0346,  ..., -0.0098,  0.0007,  0.0092],
[ 0.0226,  0.0153,  0.0103,  ..., -0.0146,  0.0159,  0.0084],
[ 0.0174,  0.0253, -0.0183,  ...,  0.0091,  0.0213,  0.0307]],
Parameter containing:
tensor([-0.8432, -0.7838, -0.5426, -0.7839,  0.6757, -0.6754, -0.7834,  0.7684,
0.8641,  0.5593], requires_grad=True)]
W = list(nnet[0].parameters())[0]
W = W.detach().numpy()
W.shape

(10, 400)
In [32]:
W = W.T
W.shape

plt.plot(W);

plt.plot(W[:, 0])

plt.imshow(W[:, 0].reshape(20, 20), cmap='RdYlGn')
plt.colorbar()

plt.figure(figsize=(15, 3))

for i in range(10):
plt.subplot(2, 10, i + 1)
plt.imshow(W[:, i].reshape(20,20), cmap='RdYlGn')
plt.xticks([])
plt.yticks([])
plt.colorbar()

plt.subplot(2, 10, i + 11)
plt.plot(both[:, i])

X.shape

plt.imshow(X[4,:].reshape(20, 20), cmap='gray')


Let's automate these steps in a function, so we can try different numbers of hidden units and layers.

nnet

Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
Wout = list(nnet[2].parameters())[0]
Wout = Wout.detach().numpy()
Wout = Wout.T
Wout.shape

(10, 2)
def run_again(hiddens):
    """Train a classifier on the squares/diamonds data and visualize its first hidden layer.

    Reads the module-level `X`, `T`, `Xtest` and `Ttest`.  Produces two
    figures (the learning curve, and a two-row grid with one column per
    first-layer hidden unit) and prints confusion matrices for the training
    and test sets.

    Parameters
    ----------
    hiddens : list of int
        Passed to `train_for_classification` as `hidden_layers`.
        `hiddens[0]` is used as the number of columns in the grid.
    """
    nnet, learning_curve = train_for_classification(X, T, hidden_layers=hiddens,
                                                    n_epochs=1000, learning_rate=0.01)
    plt.figure()
    plt.plot(learning_curve)

    # Element 2 of the per-layer outputs is the result after the second
    # module in nnet, taken here for the first ten and last ten samples.
    Y_square = forward_all_layers(nnet, X[:10, :])
    Y_diamond = forward_all_layers(nnet, X[-10:, :])
    both = np.vstack((Y_square[2], Y_diamond[2]))

    # First-layer weights, transposed so that column i belongs to unit i.
    W = list(nnet[0].parameters())[0]
    W = W.detach().numpy()
    W = W.T

    # NOTE(review): nnet[2] appears to be the output layer only when there is
    # a single hidden layer; confirm before reading the titles below for
    # deeper networks.
    Wout = list(nnet[2].parameters())[0]
    Wout = Wout.detach().numpy()
    Wout = Wout.T

    plt.figure(figsize=(15, 3))

    n_units = hiddens[0]

    # Inputs are flattened square images, so the side is sqrt(n_inputs).
    size = int(np.sqrt(X.shape[1]))

    for i in range(n_units):
        # Top row: unit i's input weights shown as an image.
        plt.subplot(2, n_units, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])

        # Bottom row: unit i's values over the 20 selected samples, titled
        # with the two weights in row i of Wout.
        plt.subplot(2, n_units, i + 1 + n_units)
        plt.plot(both[:, i])
        plt.title(f'{Wout[i,0]:.1f},{Wout[i,1]:.1f}')

    Y = use(nnet, X)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    print(confusion_matrix(Y_classes, T))

    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    print(confusion_matrix(Y_classes_test, Ttest))

run_again([10])

Percent Correct is 100.0
0.0    1.0
0.0  100.0    0.0
1.0    0.0  100.0
Percent Correct is 99.45
0.0   1.0
0.0  99.8   0.2
1.0   0.9  99.1

if os.path.isfile('small_mnist.npz'):
    print("Reading data from 'small_mnist.npz'.")
else:
    import shlex
    import subprocess
    cmd = 'curl "https://www.cs.colostate.edu/~anderson/cs545/notebooks/small_mnist.npz" -o "small_mnist.npz"'
    # check=True raises if curl exits with an error, instead of falling
    # through to np.load on a missing file.
    subprocess.run(shlex.split(cmd), check=True)

# Load the archive in both branches; the arrays are read from it by key.
small_mnist = np.load('small_mnist.npz')
X = small_mnist['X']
T = small_mnist['T']

X.shape, T.shape

Reading data from 'small_mnist.npz'.

((1000, 784), (1000, 1))
plt.imshow(-X[0, :].reshape(28, 28), cmap='gray')

<matplotlib.image.AxesImage at 0x7fa87459ccd0>

Randomly partition the data into 60% for training and 40% for testing, using the following code cells.

n_samples = X.shape[0]
n_train = int(n_samples * 0.6)
rows = np.arange(n_samples)
np.random.shuffle(rows)

Xtrain = X[rows[:n_train], :]
Ttrain = T[rows[:n_train], :]
Xtest = X[rows[n_train:], :]
Ttest = T[rows[n_train:], :]

def run_again_mnist(hiddens):
    """Train a classifier on the MNIST subset and show its first-layer weights as images.

    Reads the module-level `Xtrain`, `Ttrain`, `Xtest` and `Ttest`.
    Produces two figures (the learning curve, and a grid with one weight
    image per first-layer hidden unit) and displays confusion matrices for
    the training and test sets.

    Parameters
    ----------
    hiddens : list of int
        Passed to `train_for_classification` as `hidden_layers`.
        `hiddens[0]` is the number of weight images drawn.
    """
    nnet, learning_curve = train_for_classification(Xtrain, Ttrain, hidden_layers=hiddens,
                                                    n_epochs=1000, learning_rate=0.01)
    plt.figure()
    plt.plot(learning_curve)

    # First-layer weights, transposed so that column i belongs to unit i.
    W = list(nnet[0].parameters())[0]
    W = W.detach().numpy().T

    plt.figure(figsize=(15, 15))

    n_units = hiddens[0]

    # Inputs are flattened square images, so the side is sqrt(n_inputs).
    size = int(np.sqrt(Xtrain.shape[1]))

    # Square grid with enough cells for every unit.
    n_rows = int(np.sqrt(n_units) + 1)
    for i in range(n_units):
        plt.subplot(n_rows, n_rows, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])

    Y = use(nnet, Xtrain)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes, Ttrain))

    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes_test, Ttest))

run_again_mnist([20, 20, 20])

Percent Correct is 100.0

0 1 2 3 4 5 6 7 8 9
0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0
5 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0
6 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0
7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0
8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0
9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0
Percent Correct is 84.25

0 1 2 3 4 5 6 7 8 9
0 87.234043 0.000000 0.000000 2.127660 0.000000 10.638298 0.000000 0.000000 0.000000 0.000000
1 0.000000 93.617021 2.127660 0.000000 0.000000 2.127660 0.000000 0.000000 0.000000 2.127660
2 0.000000 2.777778 91.666667 0.000000 0.000000 0.000000 2.777778 0.000000 2.777778 0.000000
3 0.000000 0.000000 0.000000 79.069767 0.000000 11.627907 0.000000 4.651163 4.651163 0.000000
4 0.000000 0.000000 0.000000 0.000000 85.714286 0.000000 5.714286 2.857143 0.000000 5.714286
5 4.347826 0.000000 2.173913 4.347826 0.000000 76.086957 0.000000 0.000000 13.043478 0.000000
6 2.325581 0.000000 0.000000 0.000000 2.325581 0.000000 95.348837 0.000000 0.000000 0.000000
7 2.777778 2.777778 2.777778 0.000000 0.000000 0.000000 0.000000 86.111111 0.000000 5.555556
8 0.000000 3.030303 3.030303 9.090909 0.000000 3.030303 3.030303 9.090909 69.696970 0.000000
9 0.000000 0.000000 2.941176 0.000000 8.823529 2.941176 0.000000 8.823529 2.941176 73.529412
