#!/usr/bin/env python
# coding: utf-8

# This post is available for downloading as [this jupyter notebook](http://nbviewer.ipython.org/url/www.cs.colostate.edu/~anderson/wp/notebooks/Fitting%20Polynomial%20to%20Data%20in%20Pytorch%20with%20autograd.ipynb).
#
# Table of Contents
# * [Very Brief Introduction to Autograd](#Very-Brief-Introduction-to-Autograd)
# * [Using Numpy to Fit a Polynomial to Data](#Using-Numpy-to-Fit-a-Polynomial-to-Data)
# * [Now, with Pytorch](#Now,-with-Pytorch)
# * [Pytorch with Autograd](#Pytorch-with-Autograd)
# * [Pytorch with autograd on GPU](#Pytorch-with-autograd-on-GPU)
# * [Wrapped up in one function](#Wrapped-up-in-one-function)
#

# In[1]:


import numpy as np
import time
import torch
import matplotlib.pyplot as plt

try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    # get_ipython only exists inside IPython/Jupyter; when this file is run as a
    # plain Python script there is no inline-plotting magic to enable.
    pass


# # Very Brief Introduction to Autograd

# Let's use torch.autograd to calculate the derivative of the sine function. Here is what we should get.

# In[14]:


x = np.linspace(-2*np.pi, 2*np.pi, 100)
y = np.sin(x)
dy = np.cos(x)
plt.plot(x, y)
plt.plot(x, dy)
# Raw strings: '\s' and '\c' are invalid escape sequences in ordinary string
# literals (a SyntaxWarning since Python 3.12); matplotlib needs the backslashes.
plt.legend((r'$\sin(x)$', r'$\frac{d \sin(x)}{dx} = \cos(x)$'));


# The autograd module in pytorch is designed to calculate gradients of scalar-valued functions with respect to parameters. So, how can we use autograd to calculate $\frac{d \sin(x)}{dx}$ for multiple values of $x$? When the output `y` is not a scalar, we must pass `backward` a vector $v$ with as many elements as `y` has. Autograd then computes the vector-Jacobian product $v^T J$, where $J_{ij} = \partial y_i / \partial x_j$, and adds it to the gradient stored in `x.grad`. This is just what is needed when calculating a gradient of a model's mean squared error, averaged over all outputs, with respect to the model's parameters.
#
# Back to our original problem. We can calculate the derivative of $\sin(x)$ with respect to each value of $x$ by calling `backward` once for each value of $x$, setting the vector argument to all zeros except for a one at the position of the value of $x$ we want the derivative for. (Because $\sin$ is applied elementwise, $J$ is diagonal, so a single call `y.backward(torch.ones_like(y))` would produce the same `x.grad`; the loop below makes the one-element-at-a-time view explicit.)
#
# The following bit of code illustrates this. Sharif first showed this to me. [This site](https://discuss.pytorch.org/t/clarification-using-backward-on-non-scalars/1059) also helps.

# In[15]:


# Since PyTorch 0.4 a tensor created with requires_grad=True records the
# operations applied to it; torch.autograd.Variable is deprecated.
x = torch.linspace(-2*np.pi, 2*np.pi, steps=100, requires_grad=True)


# `x.grad` will contain the gradient of `y` (defined in the next cell) with respect to `x`, but only after `y.backward(...)` is called. Any additional calls to `y.backward(...)` add gradient values to the current gradient values. The following test and call to `x.grad.zero_()` zero the gradient values, to take care of the case when the next three cells are executed additional times.

# In[16]:


if x.grad is not None:
    x.grad.zero_()
y = torch.sin(x)


# In[17]:


dout = torch.zeros(y.shape[0])
for i in range(y.shape[0]):
    dout[:] = 0
    dout[i] = 1
    # retain_graph=True keeps the graph alive for the next backward call.
    y.backward(dout, retain_graph=True)


# In[18]:


plt.plot(x.detach().numpy(), y.detach().numpy())
plt.plot(x.detach().numpy(), x.grad.numpy())
plt.legend((r'$\sin(x)$', r'$\frac{d \sin(x)}{dx} = \cos(x)$'));


# # Using Numpy to Fit a Polynomial to Data

# Let's try to fit a polynomial to the sine function. First, here is the parameterized polynomial model, whose degree is one less than the number of coefficients in `w`, and its derivative with respect to `w`.

# In[7]:


def poly(x, w):
    """Evaluate the polynomial with coefficients ``w`` at each sample in ``x``.

    x : N x 1 numpy array of samples.
    w : (D+1) x 1 numpy array of coefficients for x^0, x^1, ..., x^D.
    Returns the N x 1 array sum_i w_i * x^i.
    """
    nCoeffs = w.size
    # Broadcasting N x 1 against (D+1,) gives the N x (D+1) matrix of powers.
    xPowers = x ** np.arange(nCoeffs)
    return xPowers @ w


# The gradient of a degree-3 polynomial with respect to its coefficients is
# $$\frac{\partial (w_0 + w_1 x + w_2 x^2 + w_3 x^3)}{\partial w_i} = x^i, \quad \text{i.e.} \quad \nabla_w = (1, x, x^2, x^3),$$
# which does not depend on $w$. In python it is

# In[8]:


def dpoly_dw(x, w):
    """Return the N x (D+1) matrix whose column i is x**i, i.e. d poly(x, w) / d w_i.

    ``w`` is used only for its size; the gradient does not depend on its values.
    """
    nCoeffs = w.size
    return x ** np.arange(nCoeffs)


# Let's test these functions.
# In[9]:


# A small example: 20 samples and the 4 coefficients of a cubic, both as column vectors.
x = np.linspace(-5, 5, 20).reshape((-1, 1))
w = 0.1 * np.array([3.0, -2.0, -1.5, 5]).reshape((-1, 1))
x.shape, w.shape


# In[10]:


poly(x, w).shape, dpoly_dw(x, w).shape


# In[11]:


# Top panel: the model output; bottom panel: one curve per coefficient's derivative.
panels = ((poly, 'poly(x,w)'), (dpoly_dw, 'd poly(x,w)/ dw'))
for panelIndex, (panelFunc, panelLabel) in enumerate(panels, start=1):
    plt.subplot(2, 1, panelIndex)
    plt.plot(panelFunc(x, w))
    plt.ylabel(panelLabel)


# Now some data to fit: 100 samples from part of the sine curve.

# In[12]:


x = np.linspace(0, 5, 100).reshape((-1, 1))
y = np.sin(x)
plt.plot(x, y, 'o-');


# Okay, ready to fit the polynomial to this data. The steps are simple: initialize w to zeros, calculate the output of the polynomial model, and update w by the negative gradient of the mean squared error with respect to w, multiplied by a small learning rate. To plot the MSE versus the number of update steps, also record the mean squared error between the model output and the data at every step.

# In[13]:


startTime = time.time()
nSteps = 200000
learnRate = 0.00001
degree = 4
w = np.zeros((degree + 1, 1))
mseTrace = np.zeros(nSteps)
nSamples = x.shape[0]
for step in range(nSteps):
    yModel = poly(x, w)
    residual = y - yModel
    # d MSE / dw = -2/N * (d yModel / dw)^T (y - yModel)
    grad = -2 / nSamples * dpoly_dw(x, w).T @ residual
    w -= learnRate * grad
    if step == 0:
        print('First step gradient:')
        print(grad)
    # The recorded error belongs to the model evaluated before this step's update.
    mse = (residual ** 2).mean()
    mseTrace[step] = mse
elapsed = time.time() - startTime
print('Numpy took {} seconds'.format(elapsed))


# In[16]:


plt.figure(figsize=(15, 5))
# Left: learning curve.  Right: data and model output.
plt.subplot(1, 2, 1)
plt.plot(mseTrace)
plt.subplot(1, 2, 2)
plt.plot(x, y)
plt.plot(x, yModel);


# # Now, with Pytorch

# Now we repeat all of the above function definitions with the changes needed to implement them in torch instead of numpy.
# In[17]:


dtype = torch.FloatTensor


# In[23]:


def poly_torch(x, w):
    """Evaluate the polynomial with coefficients ``w`` at each sample in ``x`` (torch version).

    x : N x 1 tensor of samples.
    w : (D+1) x 1 tensor of coefficients for x^0, x^1, ..., x^D.
    Returns the N x 1 tensor sum_i w_i * x^i.
    """
    # numpy: nCoeffs = w.size  (in torch, Tensor.size is a method, so use shape)
    nCoeffs = w.shape[0]
    # Create the exponents with x's dtype and device so this also works for
    # double-precision or GPU tensors.
    xPowers = x ** torch.arange(nCoeffs, dtype=x.dtype, device=x.device)
    # numpy: return xPowers @ w
    return xPowers.mm(w)


# Same as dpoly_dw, but built from torch operations instead of numpy ones.
def dpoly_dw_torch(x, w):
    """Return the N x (D+1) tensor whose column i is x**i, i.e. d poly_torch(x, w) / d w_i."""
    nCoeffs = w.shape[0]
    return x ** torch.arange(nCoeffs, dtype=x.dtype, device=x.device)


# In[24]:


x_torch = torch.from_numpy(x).type(dtype)
y_torch = torch.from_numpy(y).type(dtype)


# In[25]:


startTime = time.time()
nSteps = 200000
learnRate = 0.00001
degree = 4
# numpy: w = np.zeros((degree+1, 1))
w_torch = torch.zeros((degree+1, 1)).type(dtype)
# numpy: mseTrace = np.zeros(nSteps)
mseTrace = torch.zeros(nSteps)
nSamples = x_torch.shape[0]
for step in range(nSteps):
    yModel = poly_torch(x_torch, w_torch)
    # numpy: grad = -2/nSamples * dpoly_dw(x, w).T @ (y - yModel)
    grad = -2/nSamples * dpoly_dw_torch(x_torch, w_torch).t().mm(y_torch - yModel)
    w_torch -= learnRate * grad
    if step == 0:
        print('First step gradient:')
        print(grad)
    mse = ((y_torch - yModel)**2).mean()
    mseTrace[step] = mse
print('Pytorch took {} seconds'.format(time.time()-startTime))


# In[26]:


plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.plot(mseTrace.numpy())
plt.subplot(1,2,2)
plt.plot(x, y)
plt.plot(x, yModel.numpy());


# # Pytorch with Autograd

# Now let's remove the call to `dpoly_dw_torch` and use autograd to calculate the gradient of the mean squared error with respect to w.
# In[28]:


def poly_torch(x, w):
    """Evaluate the polynomial with coefficients ``w`` at each sample in ``x``.

    x : N x 1 tensor of samples.
    w : (D+1) x 1 tensor of coefficients for x^0, x^1, ..., x^D; may have
        requires_grad=True, in which case the result is part of the autograd graph.
    Returns the N x 1 tensor sum_i w_i * x^i.
    """
    nCoeffs = w.shape[0]
    # The exponents are constants (no gradient needed); match x's dtype and device.
    exponents = torch.arange(nCoeffs, dtype=x.dtype, device=x.device)
    xPowers = x ** exponents
    # numpy: return xPowers @ w
    return xPowers.mm(w)


# In[29]:


startTime = time.time()
nSteps = 200000
learnRate = 0.00001
degree = 4
# Since PyTorch 0.4, torch.autograd.Variable is merged into Tensor: a tensor
# with requires_grad=True is tracked by autograd, plain tensors are constants.
w_torch_Var = torch.zeros((degree+1, 1)).type(dtype).requires_grad_()
x_torch_Var = x_torch
y_torch_Var = y_torch
# numpy: mseTrace = np.zeros(nSteps)
mseTrace = torch.zeros(nSteps)
for step in range(nSteps):
    yModel = poly_torch(x_torch_Var, w_torch_Var)
    mse = ((y_torch_Var - yModel)**2).mean()
    # backward() accumulates into .grad, so clear the previous step's gradient.
    if w_torch_Var.grad is not None:
        w_torch_Var.grad.zero_()
    mse.backward()
    # Replaces: grad = -2/nSamples * dpoly_dw_torch(x_torch, w_torch).t().mm(y_torch - yModel)
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        w_torch_Var -= learnRate * w_torch_Var.grad
    if step == 0:
        print('First step gradient:')
        print(w_torch_Var.grad)
    # mse is a 0-dim tensor; indexing it with [0] (as in PyTorch 0.3) raises
    # an IndexError on PyTorch >= 0.4, so use .item().
    mseTrace[step] = mse.item()
print('Pytorch with autograd took {} seconds'.format(time.time()-startTime))


# In[31]:


plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.plot(mseTrace.numpy())
plt.subplot(1,2,2)
plt.plot(x, y)
plt.plot(x, yModel.detach().numpy());


# # Pytorch with autograd on GPU

# To run our torch implementation on the GPU, we need to change the data type and also call `cpu()` on tensors to move them back to the CPU when needed.
#
# First, here are the details of the GPU on this machine.
# In[28]:


try:
    get_ipython().system('nvidia-smi')
except NameError:
    # Not running under IPython/Jupyter: shell magics are unavailable.
    pass


# In[29]:


try:
    get_ipython().system('uptime')
except NameError:
    pass


# In[33]:


dtype = torch.cuda.FloatTensor


# In[32]:


def poly_torch(x, w):
    """Evaluate the polynomial with coefficients ``w`` at each sample in ``x``.

    x : N x 1 tensor of samples, on the CPU or a CUDA device.
    w : (D+1) x 1 tensor of coefficients for x^0, x^1, ..., x^D, on the same
        device as x.
    Returns the N x 1 tensor sum_i w_i * x^i.
    """
    nCoeffs = w.shape[0]
    # Build the exponents directly on x's device with x's dtype.  The old
    # `.type(type(x.data))` does not do this on PyTorch >= 0.4, where
    # type(x.data) is plain torch.Tensor (the default CPU type), giving a
    # device mismatch for CUDA x.
    exponents = torch.arange(nCoeffs, dtype=x.dtype, device=x.device)
    xPowers = x ** exponents
    # numpy: return xPowers @ w
    return xPowers.mm(w)


# In[42]:


startTime = time.time()
nSteps = 200000
learnRate = 0.00001
degree = 4
w_torch_Var = torch.zeros((degree+1, 1)).type(dtype).requires_grad_()
x_torch_Var = x_torch.type(torch.cuda.FloatTensor)
y_torch_Var = y_torch.type(torch.cuda.FloatTensor)
# numpy: mseTrace = np.zeros(nSteps)
mseTrace = torch.zeros(nSteps).type(torch.cuda.FloatTensor)
for step in range(nSteps):
    yModel = poly_torch(x_torch_Var, w_torch_Var)
    mse = ((y_torch_Var - yModel)**2).mean()
    # backward() accumulates into .grad, so clear the previous step's gradient.
    if w_torch_Var.grad is not None:
        w_torch_Var.grad.zero_()
    mse.backward()
    with torch.no_grad():
        w_torch_Var -= learnRate * w_torch_Var.grad
    if step == 0:
        print('First step gradient:')
        print(w_torch_Var.grad)
    # detach() keeps the value on the GPU (no per-step copy to the host) and
    # drops its autograd history; mse.data[0] fails on PyTorch >= 0.4.
    mseTrace[step] = mse.detach()
# CUDA kernels run asynchronously; wait for them before reading the clock.
torch.cuda.synchronize()
print('Pytorch with autograd on GPU took {} seconds'.format(time.time()-startTime))


# In[52]:


plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.plot(mseTrace.cpu().numpy())
plt.subplot(1,2,2)
plt.plot(x, y)
plt.plot(x, yModel.detach().cpu().numpy());


# # Wrapped up in one function

# We can use the type of the data passed into these functions to select code appropriate for numpy arrays or torch tensors (on the CPU or GPU, with or without autograd).
# In[43]:


def _exponents(x, nCoeffs):
    """Return the exponents 0, 1, ..., nCoeffs-1 in the same library as ``x``.

    For a torch tensor (CPU or GPU, with or without requires_grad) the result
    is a tensor with x's dtype and device; for anything else it is a numpy
    array with x's dtype.
    """
    # On PyTorch >= 0.4, type(x) is torch.Tensor for every tensor (Variables
    # included), so comparing type(x) against torch.FloatTensor or
    # torch.autograd.variable.Variable never matches; use isinstance instead.
    if isinstance(x, torch.Tensor):
        return torch.arange(nCoeffs, dtype=x.dtype, device=x.device)
    return np.arange(nCoeffs, dtype=x.dtype)


def poly(x, w):
    """Evaluate the polynomial with coefficients ``w`` at each sample in ``x``.

    x : N x 1 numpy array or torch tensor of samples.
    w : (D+1) x 1 array/tensor of coefficients for x^0, x^1, ..., x^D, of the
        same kind (and, for torch, on the same device) as x.
    Returns the N x 1 array/tensor sum_i w_i * x^i.
    """
    xPowers = x ** _exponents(x, w.shape[0])
    # `@` is matrix multiplication for both numpy arrays and torch tensors.
    return xPowers @ w


def dpoly_dw(x, w):
    """Return the N x (D+1) matrix whose column i is x**i, i.e. d poly(x, w) / d w_i.

    Works for numpy arrays and torch tensors; ``w`` is used only for its length.
    """
    return x ** _exponents(x, w.shape[0])


# In[44]:


def train(x, y, nSteps=200000, learnRate=0.00001, degree=4, use_torch=False, use_autograd=False, use_gpu=False):
    """Fit a polynomial to (x, y) by gradient descent on the mean squared error.

    x, y : N x 1 numpy arrays of samples and targets.
    nSteps : number of gradient-descent steps.
    learnRate : step size.
    degree : polynomial degree; degree+1 coefficients are fitted.
    use_torch : compute with torch tensors instead of numpy arrays.
    use_autograd : compute the gradient with torch autograd (implies use_torch).
    use_gpu : compute with CUDA tensors (implies use_torch).

    Returns a dict with the MSE trace (one value per step, for the model
    before that step's update), the final weights ``w``, the settings used,
    and the elapsed wall-clock ``seconds``.
    """
    startTime = time.time()
    nSamples = x.shape[0]

    # autograd and the GPU are only available through torch
    if use_gpu or use_autograd:
        use_torch = True

    # Initialize weights to all zeros (float32, so every variant starts identically).
    w = np.zeros((degree+1, 1), dtype=np.float32)
    if use_torch:
        w = torch.from_numpy(w)
        if use_gpu:
            w = w.type(torch.cuda.FloatTensor)
        if use_autograd:
            w.requires_grad_()

    # Change type of input samples, x, and targets, y
    if use_torch:
        x = torch.from_numpy(x).type(torch.FloatTensor)
        y = torch.from_numpy(y).type(torch.FloatTensor)
        if use_gpu:
            x = x.type(torch.cuda.FloatTensor)
            y = y.type(torch.cuda.FloatTensor)

    # Set up array to store trace of MSE values for plotting later
    mseTrace = np.zeros(nSteps, dtype=np.float32)
    if use_torch:
        mseTrace = torch.from_numpy(mseTrace)
        if use_gpu:
            mseTrace = mseTrace.type(torch.cuda.FloatTensor)

    # Train for nSteps passes through data set
    for step in range(nSteps):
        # Forward pass through model, for all samples in x
        yModel = poly(x, w)  # poly uses type of x to figure out what to do
        # MSE, necessary for autograd. For all, needed for mseTrace.
        mse = ((y - yModel)**2).mean()

        if use_autograd:
            # backward() accumulates into .grad, so clear the previous step's gradient.
            if w.grad is not None:
                w.grad.zero_()
            mse.backward()
            # The update itself must not be recorded by autograd.
            with torch.no_grad():
                w -= learnRate * w.grad
        else:
            # Hand-derived gradient; `.T` and `@` work for numpy arrays and 2-D torch tensors.
            grad = -2/nSamples * dpoly_dw(x, w).T @ (y - yModel)
            w -= learnRate * grad

        if use_torch:
            # detach() drops autograd history (otherwise the trace would keep a
            # graph spanning every step) and keeps GPU values on the device.
            # mse.data[0] fails on the 0-dim tensors of PyTorch >= 0.4.
            mseTrace[step] = mse.detach()
        else:
            mseTrace[step] = mse

    if use_gpu:
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    elapsedTime = time.time() - startTime
    return {'mseTrace': mseTrace, 'w': w, 'learnRate': learnRate, 'seconds': elapsedTime,
            'use_torch': use_torch, 'use_autograd': use_autograd, 'use_gpu': use_gpu}


# In[45]:


train(x, y, nSteps=200000, learnRate=0.00001, degree=4, use_torch=False, use_autograd=False, use_gpu=False)


# In[46]:


train(x, y, nSteps=200000, learnRate=0.00001, degree=4, use_torch=True, use_autograd=False, use_gpu=False)


# In[47]:


train(x, y, nSteps=200000, learnRate=0.00001, degree=4, use_torch=True, use_autograd=True, use_gpu=False)


# In[48]:


train(x, y, nSteps=200000, learnRate=0.00001, degree=4, use_torch=True, use_autograd=False, use_gpu=True)


# In[53]:


train(x, y, nSteps=200000, learnRate=0.00001, degree=4, use_torch=True, use_autograd=True, use_gpu=True)


# In[55]:


for use_torch, use_autograd, use_gpu in ((False, False, False),
                                         (True, False, False),
                                         (True, False, True),
                                         (True, True, False),
                                         (True, True, True)):
    result = train(x, y, nSteps=200000, learnRate=0.00001, degree=4,
                   use_torch=use_torch, use_autograd=use_autograd, use_gpu=use_gpu)
    # Labels: numpy, torch, torch-gpu, torch-autograd, torch-autograd-gpu
    if not use_torch:
        label = 'numpy'
    else:
        label = 'torch' + ('-autograd' if use_autograd else '') + ('-gpu' if use_gpu else '')
    # float() turns a numpy scalar or a (possibly CUDA) 0-dim tensor into a Python float.
    print('{:20} {:6.2f} seconds, final error {:.4f}'.format(label, result['seconds'], float(result['mseTrace'][-1])))


# These results are obviously for a small data set and a model with very few parameters. As the size of the data set and the number of parameters increase, the advantage of the GPU will become apparent. It is, however, disappointing that autograd increases execution time about 8 times in this simple example.

# I would appreciate comments on changes to my code that will result in faster autograd execution.