#!/usr/bin/env python
# coding: utf-8
"""Epsilon-greedy action-value methods on a 10-armed bandit testbed.

Each row of ``Qstar`` is an independent bandit task and each column is the
true mean reward of one arm. Agents keep sample-average reward estimates,
act greedily (breaking ties at random) and explore uniformly at random with
probability epsilon. The experiment compares several fixed exploration rates
and a schedule that lowers epsilon late in the run, then plots the per-play
reward averaged over all tasks.

Exported from a Jupyter notebook. The experiment and plots run only when the
file is executed (as a script or in a notebook), not when it is imported.
"""
import copy

import numpy as np
import pandas as pd
from pandas import Series, DataFrame


def get_max_estimate_index(estimates):
    """Return the indices of every entry of *estimates* equal to its maximum.

    All tied indices are returned (as a list of ints, in ascending order) so
    the caller can break ties at random. Works for any number of arms.
    """
    values = np.asarray(estimates)
    return np.flatnonzero(values == values.max()).tolist()


def random(array):
    """Return one element of *array* chosen uniformly at random.

    *array* is not modified. Any indexable sequence works, including
    immutable ones such as ``range``. (The previous shuffle-based version
    raised TypeError on ``range`` because it tried to swap items in place.)

    Raises:
        IndexError: if *array* is empty.
    """
    size = len(array)
    if size == 0:
        raise IndexError("cannot choose from an empty sequence")
    return array[np.random.randint(size)]


def probability(epsilon):
    """Return True with probability *epsilon* (the exploration coin flip).

    ``np.random.rand()`` is uniform on [0, 1), so ``< epsilon`` is True with
    probability exactly epsilon: never for 0, always for 1.
    """
    return bool(np.random.rand() < epsilon)


def _run_epsilon_greedy(Qstar, repeat, epsilon_at):
    """Shared simulation loop; ``epsilon_at(step)`` gives the exploration rate."""
    Qstar = np.asarray(Qstar, dtype=float)
    n_tasks, n_arms = Qstar.shape
    estimates = np.zeros((n_tasks, n_arms))
    counts = np.zeros((n_tasks, n_arms), dtype=int)
    result = []
    for step in range(repeat):
        epsilon = epsilon_at(step)
        total = 0.0
        for i in range(n_tasks):
            if probability(epsilon):
                n = random(range(n_arms))
            else:
                n = random(get_max_estimate_index(estimates[i]))
            # Observed reward: true arm value plus unit-variance Gaussian noise.
            reward = Qstar[i, n] + np.random.randn()
            counts[i, n] += 1
            # Incremental sample average: equals sum(history) / len(history)
            # but costs O(1) per play instead of re-summing the history.
            estimates[i, n] += (reward - estimates[i, n]) / counts[i, n]
            total += reward
        result.append(total / n_tasks)
    return result


def play(Qstar, epsilon, repeat):
    """Run epsilon-greedy sample-average agents on every task in *Qstar*.

    Args:
        Qstar: 2-D array of true arm values, one row per task and one column
            per arm.
        epsilon: probability of picking an arm uniformly at random instead of
            greedily.
        repeat: number of plays per task.

    Returns:
        A list of length *repeat*. Element ``j`` is the reward obtained at
        play ``j``, averaged over all tasks.
    """
    return _run_epsilon_greedy(Qstar, repeat, lambda step: epsilon)


def play2(Qstar, repeat, *, switch_fraction=0.8, epsilon_before=0.1,
          epsilon_after=0.01):
    """Like :func:`play`, but with an exploration rate that drops mid-run.

    Plays with ``step <= repeat * switch_fraction`` use *epsilon_before*; the
    remaining plays use *epsilon_after*. The defaults reproduce the original
    0.1 -> 0.01 schedule switching at 80% of the run.

    Returns:
        A list of length *repeat* with the per-play reward averaged over tasks.
    """
    def epsilon_at(step):
        return epsilon_before if step <= repeat * switch_fraction else epsilon_after

    return _run_epsilon_greedy(Qstar, repeat, epsilon_at)


if __name__ == "__main__":
    # Plotting is only needed when the experiment actually runs, so the
    # simulation helpers above stay importable without matplotlib.
    import matplotlib.pyplot as plt

    try:
        # Notebook-only magic; ``get_ipython`` is undefined in plain Python.
        get_ipython().run_line_magic('matplotlib', 'inline')
    except NameError:
        pass

    # 2000 tasks x 10 arms, with true arm values drawn from N(0, 1).
    Qstar = np.random.randn(2000, 10)

    zero = play(Qstar, 0, 500)
    one = play(Qstar, 0.01, 500)
    ten = play(Qstar, 0.1, 500)
    fifty = play(Qstar, 0.5, 500)
    hundred = play(Qstar, 1, 500)
    one_to_zero = play2(Qstar, 500)

    figures = [
        [(zero, 'e = 0'), (one, 'e = 0.01'), (ten, 'e = 0.1'),
         (fifty, 'e = 0.5'), (hundred, 'e = 1')],
        [(one, 'e = 0.01'), (ten, 'e = 0.1'),
         (one_to_zero, 'e = 0.1 to 0.01')],
    ]
    for curves in figures:
        # A fresh figure per comparison. Outside a notebook, which creates a
        # figure per cell, both comparisons would otherwise share one axes.
        plt.figure()
        plt.title("n-Armed Bandit Problem")
        plt.xlabel("number of times")
        plt.ylabel("average of total reward")
        plt.ylim(-0.2, 1.6)
        plt.xlim(0, 500)
        plt.grid()
        for data, label in curves:
            plt.plot(data, label=label)
        plt.legend(loc='lower right')
    plt.show()