$\newcommand{\xv}{\mathbf{x}} \newcommand{\Xv}{\mathbf{X}} \newcommand{\yv}{\mathbf{y}} \newcommand{\zv}{\mathbf{z}} \newcommand{\av}{\mathbf{a}} \newcommand{\Wv}{\mathbf{W}} \newcommand{\wv}{\mathbf{w}} \newcommand{\tv}{\mathbf{t}} \newcommand{\Tv}{\mathbf{T}} \newcommand{\muv}{\boldsymbol{\mu}} \newcommand{\sigmav}{\boldsymbol{\sigma}} \newcommand{\phiv}{\boldsymbol{\phi}} \newcommand{\Phiv}{\boldsymbol{\Phi}} \newcommand{\Sigmav}{\boldsymbol{\Sigma}} \newcommand{\Lambdav}{\boldsymbol{\Lambda}} \newcommand{\half}{\frac{1}{2}} \newcommand{\argmax}[1]{\underset{#1}{\operatorname{argmax}}} \newcommand{\argmin}[1]{\underset{#1}{\operatorname{argmin}}}$

**Interpreting What a Neural Network Has Learned**

Explainable Artificial Intelligence (XAI): Concepts, taxonomies, opportunities and challenges toward responsible AI, Arrieta, et al., Information Fusion, Volume 58, June 2020, Pages 82-115

"Given a certain audience, explainability refers to the details and reasons a model gives to make its functioning clear or easy to understand."

Here we will examine what the hidden units in a convolutional neural network have learned. This is most intuitive if we focus on classification problems involving images.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
import os

In [2]:
from A6mysolution import *

In [3]:
# Metric helpers used throughout the notebook.

def rmse(a, b):
    """Root-mean-squared error between arrays a and b (for regression)."""
    squared_errors = (a - b) ** 2
    return np.sqrt(np.mean(squared_errors))


def percent_correct(a, b):
    """Percentage of positions at which a and b agree (for classification)."""
    matches = a == b
    return 100 * np.mean(matches)


def confusion_matrix(Y_classes, T):
    """Print overall percent correct and return a DataFrame whose rows are the
    true classes and whose columns are the predicted classes; each entry is the
    percent of samples of the row's class predicted as the column's class."""
    class_names = np.unique(T)
    # For each true class, the percentage of its samples assigned to each predicted class.
    table = [[100 * np.mean(Y_classes[T == actual] == predicted)
              for predicted in class_names]
             for actual in class_names]
    conf_matrix = pd.DataFrame(table, index=class_names, columns=class_names)
    print(f'Percent Correct is {percent_correct(Y_classes, T)}')
    return conf_matrix

In [ ]:


In [4]:
def makeImages(nEach):
    """Generate a toy dataset of hollow squares and diamonds on 20x20 grids.

    Parameters
    ----------
    nEach : int
        Number of images per class.

    Returns
    -------
    images : ndarray of shape (nEach * 2, 1, 20, 20)
        First nEach images contain square outlines, last nEach diamond
        outlines; outline pixels are 1.0, background 0.0.
    T : ndarray of shape (nEach * 2, 1)
        Class labels: 0 for squares, 1 for diamonds.
    """
    images = np.zeros((nEach * 2, 1, 20, 20))  # nSamples, nChannels, rows, columns
    radii = 3 + np.random.randint(10 - 5, size=(nEach * 2, 1))  # radii in [3, 7]
    centers = np.zeros((nEach * 2, 2))
    for i in range(nEach * 2):
        # BUG FIX: `r` was never defined (the precomputed `radii` were never
        # read), so every call raised NameError. Use this sample's radius.
        r = int(radii[i, 0])
        # Choose a center so the whole shape fits inside the 20x20 grid.
        centers[i, :] = r + 1 + np.random.randint(18 - 2 * r, size=(1, 2))
        x = int(centers[i, 0])
        y = int(centers[i, 1])
        if i < nEach:
            # squares: left/right columns, then top/bottom rows of the outline
            images[i, 0, x - r:x + r, y + r] = 1.0
            images[i, 0, x - r:x + r, y - r] = 1.0
            images[i, 0, x - r, y - r:y + r] = 1.0
            images[i, 0, x + r, y - r:y + r + 1] = 1.0
        else:
            # diamonds: four diagonal edges drawn with paired index ranges
            images[i, 0, range(x - r, x), range(y, y + r)] = 1.0
            images[i, 0, range(x - r, x), range(y, y - r, -1)] = 1.0
            images[i, 0, range(x, x + r + 1), range(y + r, y - 1, -1)] = 1.0
            images[i, 0, range(x, x + r), range(y - r, y)] = 1.0
    # images += np.random.randn(*images.shape) * 0.5
    T = np.zeros((nEach * 2, 1))
    T[nEach:] = 1  # second half of the samples are diamonds
    return images, T

# Build a training set of 1000 squares + 1000 diamonds, flattened to
# (n_samples, 400) row vectors for a fully-connected network, plus a
# same-sized independent test set.
nEach = 1000
X, T = makeImages(nEach)
X = X.reshape(X.shape[0], -1)
print(X.shape, T.shape)

Xtest, Ttest = makeImages(nEach)
Xtest = Xtest.reshape(Xtest.shape[0], -1)

# Targets: 0 for the first half (squares), 1 for the second half (diamonds).
plt.plot(T);

(2000, 400) (2000, 1)

In [5]:
# Show the last image (a diamond). Pixel values are negated so the outline
# appears dark on a light background with the 'gray' colormap.
plt.imshow(-X[-1, :].reshape(20, 20), cmap='gray')
plt.xticks([])
plt.yticks([])

Out[5]:
([], [])
In [6]:
# Show ten squares (top row) and ten diamonds (bottom row).
plt.figure(figsize=(10, 3))

for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(-X[i, :].reshape(20, 20), cmap='gray')
    plt.xticks([])
    plt.yticks([])

    plt.subplot(2, 10, i + 11)
    # BUG FIX: was X[-i, :], and X[-0] is X[0] (a square), so the first
    # bottom-row panel showed a square. -(i + 1) indexes the last ten
    # samples, which are all diamonds.
    plt.imshow(-X[-(i + 1), :].reshape(20, 20), cmap='gray')
    plt.xticks([])
    plt.yticks([])

In [7]:
# Train a fully-connected classifier with one hidden layer of 10 units.
# train_for_classification comes from A6mysolution (imported with * above);
# presumably it returns the trained torch Sequential and the per-epoch loss —
# TODO confirm against A6mysolution.
nnet, learning_curve = train_for_classification(X, T, hidden_layers=[10],
n_epochs=500, learning_rate=0.01)
plt.plot(learning_curve);

In [8]:
nnet

Out[8]:
Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
In [9]:
# Apply the trained network to train and test inputs. The final layer is
# LogSoftmax, so outputs are log-probabilities, one column per class.
Y = use(nnet, X)
Ytest = use(nnet, Xtest)
Y.shape

Out[9]:
(2000, 2)
In [10]:
# Log-probabilities for training samples (top) and test samples (bottom).
plt.subplot(2, 1, 1)
plt.plot(Y)
plt.subplot(2, 1, 2)
plt.plot(Ytest)

Out[10]:
[<matplotlib.lines.Line2D at 0x7fa89861d370>,
<matplotlib.lines.Line2D at 0x7fa89861d460>]
In [11]:
# Exponentiating the LogSoftmax outputs gives class probabilities in [0, 1].
plt.plot(np.exp(Y))

Out[11]:
[<matplotlib.lines.Line2D at 0x7fa898582b50>,
<matplotlib.lines.Line2D at 0x7fa898582c40>]
In [12]:
# Predicted class = column index of the larger log-probability.
Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes, 'o', label='Predicted')
plt.plot(T + 0.1, 'o', label='Target')  # offset by 0.1 so both series are visible
plt.legend();

In [13]:
Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes_test, 'o', label='Predicted')
# BUG FIX: test predictions were plotted against the TRAINING targets T.
# It happened to look right only because T and Ttest share the same
# 0s-then-1s layout; compare against the test targets.
plt.plot(Ttest + 0.1, 'o', label='Target')  # offset by 0.1 so both series are visible
plt.legend();

In [14]:
Y.shape

Out[14]:
(2000, 2)
In [15]:
confusion_matrix(Y_classes_test, Ttest)

Percent Correct is 99.05000000000001

Out[15]:
0.0 1.0
0.0 99.1 0.9
1.0 1.0 99.0
In [16]:
def forward_all_layers(nnet, X):
    """Propagate X through nnet one layer at a time.

    Parameters
    ----------
    nnet : iterable of torch layers (e.g. torch.nn.Sequential)
    X : numpy array of inputs, one sample per row

    Returns
    -------
    list of numpy arrays: [input, layer-0 output, layer-1 output, ...]
    """
    activation = torch.from_numpy(X).float()
    outputs = [activation]
    for layer in nnet:
        activation = layer(activation)
        outputs.append(activation)
    # Detach from the autograd graph and convert every stage back to numpy.
    return [stage.detach().numpy() for stage in outputs]

In [17]:
# All-layer activations for 10 square samples and 10 diamond samples.
Y_square = forward_all_layers(nnet, X[:10, :])
Y_diamond = forward_all_layers(nnet, X[-10:, :])

In [18]:
nnet

Out[18]:
Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
In [19]:
len(Y_square)

Out[19]:
5
In [20]:
Y_square[0].shape

Out[20]:
(10, 400)
In [21]:
Y_square[1].shape

Out[21]:
(10, 10)
In [22]:
plt.plot(Y_square[1]);

In [23]:
plt.plot(Y_square[2]);

In [24]:
# Index 2 is the output of the first hidden layer after Tanh; stack the 10
# square rows on top of the 10 diamond rows (20 rows, one column per unit).
both = np.vstack((Y_square[2], Y_diamond[2]))

In [25]:
plt.plot(both);

In [26]:
# One panel per hidden unit. Samples 0-9 are squares, 10-19 diamonds, so a
# step between the two halves means that unit discriminates the classes.
plt.figure(figsize=(15, 3))
for unit in range(10):
    plt.subplot(1, 10, unit + 1)
    plt.plot(both[:, unit])
plt.tight_layout()

In [27]:
plt.plot(both[:, 9])

Out[27]:
[<matplotlib.lines.Line2D at 0x7fa898a4fcd0>]
In [28]:
nnet

Out[28]:
Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
In [29]:
nnet[0].parameters()

Out[29]:
<generator object Module.parameters at 0x7fa898a1eac0>
In [30]:
list(nnet[0].parameters())

Out[30]:
[Parameter containing:
tensor([[-0.0235,  0.0183, -0.0043,  ..., -0.0283, -0.0472,  0.0493],
[-0.0305,  0.0305,  0.0375,  ...,  0.0076, -0.0254, -0.0216],
[-0.0205, -0.0357, -0.0086,  ..., -0.0464,  0.0402,  0.0307],
...,
[-0.0139,  0.0290,  0.0346,  ..., -0.0098,  0.0007,  0.0092],
[ 0.0226,  0.0153,  0.0103,  ..., -0.0146,  0.0159,  0.0084],
[ 0.0174,  0.0253, -0.0183,  ...,  0.0091,  0.0213,  0.0307]],
Parameter containing:
tensor([-0.8432, -0.7838, -0.5426, -0.7839,  0.6757, -0.6754, -0.7834,  0.7684,
0.8641,  0.5593], requires_grad=True)]
In [31]:
# First-layer weight matrix; torch stores it as (out_features, in_features)
# = (10, 400). Detach to leave the autograd graph before converting to numpy.
W = list(nnet[0].parameters())[0]
W = W.detach().numpy()
W.shape

Out[31]:
(10, 400)
In [32]:
# Transpose so each COLUMN holds the 400 input weights of one hidden unit.
W = W.T
W.shape

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-32-4b76fe59cac6> in <module>
----> 1 W = W.TRdYlGn
2 W.shape

AttributeError: 'numpy.ndarray' object has no attribute 'TRdYlGn'
In [ ]:
plt.plot(W);

In [ ]:
plt.plot(W[:, 0])

In [ ]:
# Reshape one unit's weights back onto the 20x20 input grid; the diverging
# RdYlGn colormap shows positive vs negative weights.
plt.imshow(W[:, 0].reshape(20, 20), cmap='RdYlGn')
plt.colorbar()

In [ ]:
# Top row: each hidden unit's weights drawn as an image. Bottom row: that
# unit's output for the 10 squares followed by the 10 diamonds.
plt.figure(figsize=(15, 3))

for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(W[:, i].reshape(20,20), cmap='RdYlGn')
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()

    plt.subplot(2, 10, i + 11)
    plt.plot(both[:, i])

In [ ]:
X.shape

In [ ]:
plt.imshow(X[4,:].reshape(20, 20), cmap='gray')


Let's automate these steps in a function, so we can try different numbers of hidden units and layers.

In [33]:
nnet

Out[33]:
Sequential(
(0): Linear(in_features=400, out_features=10, bias=True)
(1): Tanh()
(2): Linear(in_features=10, out_features=2, bias=True)
(3): LogSoftmax()
)
In [34]:
# Output-layer weights, transposed to (n_hidden, n_classes) = (10, 2):
# one row per hidden unit, one column per class.
Wout = list(nnet[2].parameters())[0]
Wout = Wout.detach().numpy()
Wout = Wout.T
Wout.shape

Out[34]:
(10, 2)
In [35]:
def run_again(hiddens, n_epochs=1000, learning_rate=0.01):
    """Train a fresh classifier on the squares/diamonds data and visualize it.

    Shows the learning curve, each first-layer unit's weights as an image with
    its responses to 10 squares then 10 diamonds, and the train/test confusion
    matrices.

    Parameters
    ----------
    hiddens : list of int
        Hidden layer sizes; hiddens[0] determines how many units are shown.
    n_epochs, learning_rate :
        Training hyperparameters (previously hard-coded at 1000 / 0.01).

    NOTE(review): the Wout annotation indexes nnet[2], which is the output
    layer only when hiddens has a single entry — confirm before calling with
    multiple hidden layers.
    """
    nnet, learning_curve = train_for_classification(X, T, hidden_layers=hiddens,
                                                    n_epochs=n_epochs,
                                                    learning_rate=learning_rate)
    plt.figure()
    plt.plot(learning_curve)

    # First-hidden-layer outputs (index 2 = after Linear + Tanh) for 10
    # squares stacked on 10 diamonds.
    Y_square = forward_all_layers(nnet, X[:10, :])
    Y_diamond = forward_all_layers(nnet, X[-10:, :])
    both = np.vstack((Y_square[2], Y_diamond[2]))

    # Weights with one column per unit (input-to-hidden and hidden-to-output).
    W = list(nnet[0].parameters())[0].detach().numpy().T
    Wout = list(nnet[2].parameters())[0].detach().numpy().T

    plt.figure(figsize=(15, 3))

    n_units = hiddens[0]
    size = int(np.sqrt(X.shape[1]))  # images are size x size

    for i in range(n_units):
        # Top row: the unit's weights drawn on the input grid.
        plt.subplot(2, n_units, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])

        # Bottom row: the unit's responses, titled with its output weights.
        plt.subplot(2, n_units, i + 1 + n_units)
        plt.plot(both[:, i])
        plt.title(f'{Wout[i,0]:.1f},{Wout[i,1]:.1f}')

    Y = use(nnet, X)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    print(confusion_matrix(Y_classes, T))

    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    print(confusion_matrix(Y_classes_test, Ttest))

In [36]:
run_again([10])

Percent Correct is 100.0
0.0    1.0
0.0  100.0    0.0
1.0    0.0  100.0
Percent Correct is 99.45
0.0   1.0
0.0  99.8   0.2
1.0   0.9  99.1

In [ ]:


In [37]:
# BUG FIX: the original cell had an empty `if` body (a SyntaxError) and never
# actually loaded the file, leaving `small_mnist` undefined. Download the
# archive only when it is missing, then load it.
if not os.path.isfile('small_mnist.npz'):
    import shlex
    import subprocess
    cmd = 'curl "https://www.cs.colostate.edu/~anderson/cs545/notebooks/small_mnist.npz" -o "small_mnist.npz"'
    subprocess.call(shlex.split(cmd))

print("Reading data from 'small_mnist.npz'.")
small_mnist = np.load('small_mnist.npz')
X = small_mnist['X']
T = small_mnist['T']

X.shape, T.shape

Reading data from 'small_mnist.npz'.

Out[37]:
((1000, 784), (1000, 1))
In [38]:
plt.imshow(-X[0, :].reshape(28, 28), cmap='gray')

Out[38]:
<matplotlib.image.AxesImage at 0x7fa87459ccd0>

Randomly partition the data into 60% for training and 40% for testing, using the following code cells.

In [39]:
# Shuffle the row indices in place, then split: first 60% of the shuffled
# rows for training, remaining 40% for testing.
n_samples = X.shape[0]
n_train = int(n_samples * 0.6)
rows = np.arange(n_samples)
np.random.shuffle(rows)

Xtrain = X[rows[:n_train], :]
Ttrain = T[rows[:n_train], :]
Xtest = X[rows[n_train:], :]
Ttest = T[rows[n_train:], :]

In [59]:
def run_again_mnist(hiddens, n_epochs=1000, learning_rate=0.01):
    """Train a classifier on the MNIST training split and visualize it.

    Shows the learning curve, each first-layer unit's weights drawn on the
    input pixel grid, and the train/test confusion matrices.

    Parameters
    ----------
    hiddens : list of int
        Hidden layer sizes; hiddens[0] determines how many units are shown.
    n_epochs, learning_rate :
        Training hyperparameters (previously hard-coded at 1000 / 0.01).
    """
    nnet, learning_curve = train_for_classification(Xtrain, Ttrain, hidden_layers=hiddens,
                                                    n_epochs=n_epochs,
                                                    learning_rate=learning_rate)
    plt.figure()
    plt.plot(learning_curve)

    # Input-to-first-hidden-layer weights, one column per hidden unit.
    # (Removed dead code copied from run_again: Y_square / Y_diamond / both
    # and Wout were computed here but never used.)
    W = list(nnet[0].parameters())[0].detach().numpy().T

    plt.figure(figsize=(15, 15))

    n_units = hiddens[0]
    size = int(np.sqrt(Xtrain.shape[1]))  # 28 for these MNIST images

    # Square-ish grid large enough to hold n_units panels.
    n_rows = int(np.sqrt(n_units) + 1)
    for i in range(n_units):
        plt.subplot(n_rows, n_rows, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])

    Y = use(nnet, Xtrain)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes, Ttrain))

    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes_test, Ttest))

In [60]:
run_again_mnist([20, 20, 20])

Percent Correct is 100.0

0 1 2 3 4 5 6 7 8 9
0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0
5 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0
6 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0
7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0
8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0
9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0
Percent Correct is 84.25

0 1 2 3 4 5 6 7 8 9
0 87.234043 0.000000 0.000000 2.127660 0.000000 10.638298 0.000000 0.000000 0.000000 0.000000
1 0.000000 93.617021 2.127660 0.000000 0.000000 2.127660 0.000000 0.000000 0.000000 2.127660
2 0.000000 2.777778 91.666667 0.000000 0.000000 0.000000 2.777778 0.000000 2.777778 0.000000
3 0.000000 0.000000 0.000000 79.069767 0.000000 11.627907 0.000000 4.651163 4.651163 0.000000
4 0.000000 0.000000 0.000000 0.000000 85.714286 0.000000 5.714286 2.857143 0.000000 5.714286
5 4.347826 0.000000 2.173913 4.347826 0.000000 76.086957 0.000000 0.000000 13.043478 0.000000
6 2.325581 0.000000 0.000000 0.000000 2.325581 0.000000 95.348837 0.000000 0.000000 0.000000
7 2.777778 2.777778 2.777778 0.000000 0.000000 0.000000 0.000000 86.111111 0.000000 5.555556
8 0.000000 3.030303 3.030303 9.090909 0.000000 3.030303 3.030303 9.090909 69.696970 0.000000
9 0.000000 0.000000 2.941176 0.000000 8.823529 2.941176 0.000000 8.823529 2.941176 73.529412
In [ ]: