#!/usr/bin/env python # coding: utf-8 # In[1]: import numpy as np import matplotlib.pyplot as plt import IPython.display as ipd # for display and clear_output import time # for sleep # Last time we discussed using pytorch to construct, train, and use neural networks as regression models. "Regression" refers to the output of continuous values, like rainfall amounts or stock prices. # # Today we will modify the code in order to make classification models. These are models that output discrete values, or categorical values, representing class labels for each sample, such as "mask" or "no mask" from images of people, and "excited" or "calm" from speech signals. # # From Regression to Classification # # The primary changes we need to consider is that we use a different output function for the network, a different loss function for classification, and our target outputs are now class labels. # ## Network Output # # For regression, we just output the weighted sum of inputs coming into the output layer as our prediction of the correct output for each sample. # # For classification, we want to convert these output values into class probabilities, or the probabilities that a given input sample should be classified as each of the possible classes. So, if we have 2 classes, like "cat" or "dog", we need 2 outputs. For these two outputs to be probabilities, we want each one to be between 0 and 1 and for a given sample we want the two values to sum to 1. # # We will accomplish this by passing the outputs of the neural network through a "softmax" function. Let's call the two outputs of our network for input sample $n$, $p_{n,0}$ and $p_{n,1}$. 
where # # $$\begin{align*} # 0 \le p_{n,0} \le 1\\ # 0 \le p_{n,1} \le 1\\ # p_{n,0} + p_{n,1} = 1 # \end{align*}$$ # # The softmax function that converts our network outputs, $y_{n,0}$, and $y_{n, 1}$, to class probabilities uses each $y$ as an exponent of base $e$ and # dividing each result by the sum of these exponentiated values: # # $$ \begin{align*} # p_{n,i} = \frac{e^{y_{n,i}}}{\sum_{j=0}^K e^{y_{n,j}}} # \end{align*}$$ # # where $K$ is the number of classes. # # Let's do this in python. # In[2]: def softmax(Y): '''Y is n_samples X n_classes''' expY = np.exp(Y) P = expY / np.sum(expY, axis=1) return P # In[3]: get_ipython().run_line_magic('pinfo', 'softmax') # In[4]: Y = np.array([[-2.3, 8.2]]) Y # In[5]: P = softmax(Y) P # In[6]: np.sum(P) # Does this work for multiple samples, in rows of `Y`? # In[7]: n_samples = 5 n_classes = 3 Y = np.random.uniform(-10, 10, size=(n_samples, n_classes)) Y # In[8]: P = softmax(Y) P # How can we fix this? # In[9]: np.sum(Y, axis=1, keepdims=True) # In[10]: def softmax(Y): '''Y is n_samples X n_classes''' expY = np.exp(Y) P = expY / np.sum(expY, axis=1, keepdims=True) return P # In[11]: P = softmax(Y) P # In[12]: P.sum(axis=1) # Once we have class probabilities, how do we convert these to classes, or categories? # # We just created output matrix `Y` with 5 samples and 3 classes representing, let's say, "hot", "warm", and "cold". Now we want to know for each of the 5 samples, was that sample "hot", "warm", or "cold"? # # To do this, we just need to identify which of the three class probabilities was largest for each sample. Hey, maybe `numpy` has a function for this? # In[13]: np.argmax(P, axis=1) # If we have an `np.array` of class names, we can use these integers as indices into the class names. # In[14]: class_names = np.array(['hot', 'warm', 'cold']) class_names[np.argmax(P, axis=1)] # Whenever we are dividing, we have to watch out for divide-by-zero errors. 
# This could happen in our `softmax` function if all of the `Y` values in a row are large negative values: `np.exp` underflows to exactly 0 for each of them, so the denominator becomes 0.

# In[15]:


Y = np.random.uniform(-10000010, -10000000, size=(5, 3))
Y


# In[16]:


# Every entry underflows to 0.0.
np.exp(Y)


# In[17]:


# Each row becomes 0 / 0, so the result is nan (numpy emits a RuntimeWarning).
softmax(Y)


# We can deal with that by dividing the numerator and denominator by $e^{\max_j(y_{n,j})}$, which is the same as multiplying both by $e^{-\max_j(y_{n,j})}$:
# $$ \begin{align*}
# p_{n,i} = \frac{e^{y_{n,i}}}{\sum_{j=0}^{K-1} e^{y_{n,j}}} \frac{e^{-\max_j(y_{n,j})}}{e^{-\max_j(y_{n,j})}}
#         = \frac{e^{y_{n,i} - \max_j(y_{n,j})}}{\sum_{j=0}^{K-1} e^{y_{n,j} - \max_j(y_{n,j})}}
# \end{align*}$$
# 
# After the shift, the largest exponent in each row is $e^0 = 1$, so the denominator is always at least 1.

# In[18]:


def softmax(Y):
    '''Row-wise softmax of Y (n_samples X n_classes), shifted for numerical stability.

    Returns an array of the same shape whose rows are class probabilities
    summing to 1.
    '''
    # Subtracting each row's maximum leaves the result unchanged mathematically,
    # but keeps the largest exponent at exp(0) = 1, so the row sums cannot underflow to 0.
    row_peaks = np.max(Y, axis=1, keepdims=True)
    numerators = np.exp(Y - row_peaks)
    row_totals = np.sum(numerators, axis=1, keepdims=True)
    return numerators / row_totals


# In[19]:


softmax(Y)


# ## Loss Function

# For regression, we used the `torch.nn.MSELoss()` loss function, because we wanted to minimize the mean-squared-error between all training desired target values and the outputs produced by the neural network.
# 
# For classification, we will instead want to maximize the data likelihood, which is the product of all of the correct class probabilities over all samples. Remember that we are using gradient descent to optimize our loss functions. Now, the gradient (or derivative) of a product of a bunch of things is a very computationally heavy calculation. (Why?) So, instead of optimizing this product of probabilities, we will optimize the log of that product, which converts it into a sum of logs of those probabilities. We can do this because the log is monotonically increasing, so the weight values that optimize the product of probabilities are the same weight values that optimize the log of that product! Working with a sum of logs also avoids numerical underflow: a product of many probabilities, each less than 1, quickly rounds to 0 in floating point.
# 
# So we want to maximize the log-likelihood. Recall that the `torch.optim` functions are designed to minimize a loss (hence the name "loss"). Since we want to maximize the log-likelihood, we must define the negative-log-likelihood to be used by `torch.optim`.
# 
# To do this in our python code, we simply replace
# ```python
# torch.nn.MSELoss()
# ```
# with
# ```python
# torch.nn.NLLLoss()
# ```
# Note that `torch.nn.NLLLoss` expects its input to already be log-probabilities; it does not apply softmax or log itself. That is why the network we build ends with a `torch.nn.LogSoftmax` layer.

# ## Targets as Class Labels
# 
# The final step to convert our code from regression to classification is to construct the correct target, `T`, matrix. For regression, we would create for `T` an `n_samples` x `n_outputs` matrix of desired output values. For classification, we instead create a matrix `T` of `n_samples` values, regardless of how many outputs, or classes, we will have. The values of `T` must be from the set $\{0, 1, \ldots, K-1\}$ where $K$ is the number of different class labels we have. We can get the number of different class labels using python like
# ```python
# n_classes = len(np.unique(T))
# ```

# # Classification Data

# Let's start with a toy problem. Say we have samples, each consisting of three integers. Define the classes as
# * class 0: the first integer is greater than the sum of the other two,
# * class 1: the second integer is greater than the sum of the other two,
# * class 2: if not class 0 or 1.

# Instead of that three-integer problem, we use data with a single input so that we can plot it in the animation below. Each sample $x$ is drawn uniformly from $[0, 10)$, and the classes are
# * class 0: $x < 3$,
# * class 1: $3 \le x < 6$,
# * class 2: $x \ge 6$.

# In[33]:


def make_targets(X, boundaries=(3, 6)):
    '''Return the class label of each row of X as an n_samples x 1 int array.

    Only the first column of X is used. With the default boundaries (3, 6) a
    sample x gets class 0 if x < 3, class 1 if 3 <= x < 6, and class 2 otherwise.
    In general, a sample gets the number of boundaries that are <= its value.
    Boundaries must be increasing.
    '''
    # np.digitize with increasing bins returns i such that bins[i-1] <= x < bins[i].
    return np.digitize(X[:, 0], boundaries).reshape(-1, 1)


n_samples = 100
X = np.random.uniform(0, 10, size=(n_samples, 1))
T = make_targets(X)
# Train and test labels come from the same helper, so the class boundaries cannot drift apart.
Xtest = np.random.uniform(0, 10, size=(n_samples, 1))
Ttest = make_targets(Xtest)
X, T


# In[34]:


np.hstack((X, T))


# In[35]:


np.sum(T == 0), np.sum(T == 1), np.sum(T == 2)


# # Create, Train, and Use our Classifier

# In[36]:


import torch
torch.__version__


# Remember that `pytorch` layers use single-precision (32-bit) floats by default, so our inputs must be converted with `.float()`. The `NLLLoss` function requires the target class labels to be a one-dimensional tensor of 64-bit integers (`.long()`).
# In[37]:


torch.from_numpy(X).float()


# In[38]:


torch.from_numpy(T).reshape(-1).long()


# In[39]:


np.unique(T)


# In[40]:


n_inputs = X.shape[1]
n_classes = len(np.unique(T))

# Inputs are converted to float32 to match the default layer weights; NLLLoss
# needs 1-D int64 class-label targets.
Xt = torch.from_numpy(X).float()
Tt = torch.from_numpy(T).reshape(-1).long()
Xtestt = torch.from_numpy(Xtest).float()
Ttestt = torch.from_numpy(Ttest).reshape(-1).long()

# Three tanh hidden layers (10, 20, 10 units). The final LogSoftmax makes the
# network output log class probabilities, which is what NLLLoss expects.
nnet = torch.nn.Sequential(torch.nn.Linear(n_inputs, 10), torch.nn.Tanh(),
                           torch.nn.Linear(10, 20), torch.nn.Tanh(),
                           torch.nn.Linear(20, 10), torch.nn.Tanh(),
                           torch.nn.Linear(10, n_classes),
                           torch.nn.LogSoftmax(dim=1))

learning_rate = 0.01
n_epochs = 10000
optimizer = torch.optim.SGD(nnet.parameters(), lr=learning_rate)
nll_f = torch.nn.NLLLoss()

# One entry per epoch: exp(-mean NLL), the geometric mean of the probability
# assigned to each sample's correct class. Stored as plain Python floats.
likelihood_trace = []
likelihood_test_trace = []

fig = plt.figure(figsize=(10, 12))


def forward_all_layers(X):
    '''Run X through each module of the global `nnet` in turn.

    Returns a list whose element i is the output of nnet[i] (the input X
    itself is not included).
    '''
    Ys = [X]
    for layer in nnet:
        Ys.append(layer(Ys[-1]))
    return Ys[1:]


# Loop-invariant plotting quantities, computed once.
# nnet alternates Linear and activation modules and ends with Linear + LogSoftmax,
# so it has (len(nnet) - 1) // 2 hidden layers.
n_hidden_layers = (len(nnet) - 1) // 2
nplots = 2 + n_hidden_layers
order_train = np.argsort(X, axis=0).reshape(-1)
order_test = np.argsort(Xtest, axis=0).reshape(-1)

for epoch in range(n_epochs):

    logP = nnet(Xt)
    nll = nll_f(logP, Tt)
    # (the regression version computed mse = mse_f(Y, Tt) here)

    optimizer.zero_grad()
    nll.backward()
    optimizer.step()

    # Likelihood traces for plotting. .item() stores plain floats, so no epoch's
    # autograd graph is kept alive by the lists. The test pass is never
    # back-propagated, so it runs under no_grad.
    likelihood_trace.append((-nll).exp().item())
    with torch.no_grad():
        logPtest = nnet(Xtestt)
        likelihood_test_trace.append((-nll_f(logPtest, Ttestt)).exp().item())

    if epoch % 1000 == 0 or epoch == n_epochs - 1:
        plt.clf()

        plt.subplot(nplots, 1, 1)
        # The traces hold epoch + 1 entries, including the current epoch.
        plt.plot(likelihood_trace)
        plt.plot(likelihood_test_trace)
        # plt.ylim(0, 0.7)
        plt.xlabel('Epochs')
        plt.ylabel('Likelihood')
        plt.legend(('Train', 'Test'), loc='upper left')

        plt.subplot(nplots, 1, 2)
        classes = logPtest.argmax(dim=1).numpy()
        plt.plot(X[order_train, :], T[order_train, :], 'o-', label='Train')
        plt.plot(Xtest[order_test, :], Ttest[order_test, :], 'o-', label='Test')
        plt.plot(Xtest[order_test, :], classes[order_test], 'o-')
        plt.legend(('Training', 'Testing', 'Model'), loc='upper left')
        plt.xlabel('$x$')
        plt.ylabel('Actual and Predicted Class')

        with torch.no_grad():
            Ys = forward_all_layers(Xt)
        Z = Ys[:-1]  # drop the LogSoftmax output
        # Deepest hidden layer first. Z[2 * layeri - 1] is the output of the
        # activation that follows the layeri-th Linear module.
        for ploti, layeri in enumerate(range(n_hidden_layers, 0, -1), start=3):
            plt.subplot(nplots, 1, ploti)
            plt.plot(X[order_train, :], Z[layeri * 2 - 1][order_train, :].numpy())
            plt.xlabel('$x$')
            plt.ylabel(f'Hidden Layer {layeri}')

        ipd.clear_output(wait=True)
        ipd.display(fig)

ipd.clear_output(wait=True)


# In[ ]: