from __future__ import division  # must precede all other statements; a no-op on Python 3

import numpy as np
from quantecon import MarkovChain
class MDP:
    def __init__(self, r, delta, Q):
        self.r = np.array(r)        # reward array, shape (n, m)
        self.delta = delta          # discount factor
        self.Q = np.array(Q)        # transition kernel, shape (n, m, n)
        self.n = self.r.shape[0]    # number of states
        self.m = self.r.shape[1]    # number of actions

    def bellman_operator(self, w):
        # (Tw)(s) = max_a { r(s, a) + delta * sum_t Q(s, a, t) * w(t) }
        w = np.array(w)
        objf = self.r + self.delta * self.Q.dot(w)   # shape (n, m)
        Bw = np.max(objf, axis=1)                    # updated value function
        policy = np.argmax(objf, axis=1)             # greedy policy
        return Bw, policy

    def solve_bellman_equation(self, w0, tol, max_iter):
        # Value function iteration: apply the Bellman operator until
        # successive iterates are within tol of each other in the sup norm
        w_current = np.array(w0, dtype=float)
        for t in range(max_iter):
            w_next, policy = self.bellman_operator(w_current)
            error = np.max(np.abs(w_next - w_current))
            w_current = w_next
            if error < tol:
                break
        # Transition matrix of the controlled chain under the optimal policy
        P = self.Q[np.arange(self.n), policy]
        return w_current, policy, P
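# Quick sanity check (an added sketch, not part of the original model: the toy
# r_toy / Q_toy below are made up). The Bellman operator should be a
# contraction of modulus delta in the sup norm.
rng = np.random.default_rng(0)
r_toy = rng.random((3, 2))                    # toy MDP: 3 states, 2 actions
Q_toy = rng.random((3, 2, 3))
Q_toy /= Q_toy.sum(axis=2, keepdims=True)     # normalize into a stochastic kernel
toy = MDP(r_toy, 0.9, Q_toy)
w1, w2 = rng.random(3), rng.random(3)
Tw1, _ = toy.bellman_operator(w1)
Tw2, _ = toy.bellman_operator(w2)
assert np.max(np.abs(Tw1 - Tw2)) <= 0.9 * np.max(np.abs(w1 - w2)) + 1e-12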
# Parameters
beta = 0.5    # utility curvature: u(c) = c ** beta
delta = 0.9   # discount factor
B = 10        # upper bound of the uniform shock to next period's stock
M = 5         # maximum amount that can be stored

# Period utility: with stock s, storing a leaves consumption s - a;
# actions with a > min(s, M) are infeasible and are assigned -inf
r = np.zeros((B + M + 1, M + 1))
for s in range(B + M + 1):
    for a in range(M + 1):
        if a <= min(s, M):
            r[s, a] = (s - a) ** beta
        else:
            r[s, a] = -np.inf

# Transition kernel: next period's stock is a plus a uniform draw on {0, ..., B}
Q = np.zeros((B + M + 1, M + 1, B + M + 1))
for s in range(B + M + 1):
    for a in range(M + 1):
        for t in range(B + M + 1):
            Q[s, a, t] = 1 / (B + 1) if a <= t <= a + B else 0
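# Sanity check (added): every (state, action) pair should place total
# probability one on next-period states.
assert np.allclose(Q.sum(axis=2), 1.0)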
print(r)
print(Q)
[[ 0.         -inf        -inf        -inf        -inf        -inf]
 [ 1.          0.         -inf        -inf        -inf        -inf]
 [ 1.41421356  1.          0.         -inf        -inf        -inf]
 ...
 [ 3.87298335  3.74165739  3.60555128  3.46410162  3.31662479  3.16227766]]

(Q output abridged: a 16 x 6 x 16 array in which each Q[s, a] row places
probability 1/11 = 0.09090909 on next-period states a through a + 10 and
0 elsewhere.)
kurtzbellman = MDP(r, delta, Q)
value_function, policy, P = kurtzbellman.solve_bellman_equation(np.zeros(B+M+1), tol=1e-10, max_iter=10000)
print(value_function)
print(policy)
print(P)
[ 19.01740222  20.01740222  20.43161578  20.74945302  21.04078099
  21.30873018  21.54479816  21.76928181  21.98270358  22.18824323
  22.3845048   22.57807736  22.76109127  22.94376708  23.11533996
  23.27761762]
[0 0 0 0 1 1 1 2 2 3 3 4 5 5 5 5]
(P output abridged: the 16 x 16 transition matrix under the optimal policy;
row s places probability 1/11 = 0.09090909 on states policy[s] through
policy[s] + 10 and 0 elsewhere.)
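# The result can be cross-checked against quantecon's DiscreteDP class.
# A hedged sketch: it assumes a quantecon version exposing DiscreteDP, whose
# solve() returns an object with attributes v (values) and sigma (policy).
from quantecon import DiscreteDP

ddp = DiscreteDP(r, Q, delta)
res = ddp.solve(method='policy_iteration')
assert np.allclose(res.v, value_function, atol=1e-6)
assert np.array_equal(res.sigma, policy)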
class MDP_pol_iter:
    def __init__(self, r, delta, Q):
        self.r = np.array(r)        # reward array, shape (n, m)
        self.delta = delta          # discount factor
        self.Q = np.array(Q)        # transition kernel, shape (n, m, n)
        self.n = self.r.shape[0]    # number of states
        self.m = self.r.shape[1]    # number of actions

    def value_of_policy(self, policy):
        # Transition matrix and one-period reward under the given policy
        policy = np.array(policy, dtype=int)
        P_pol = self.Q[np.arange(self.n), policy]   # shape (n, n)
        r_pol = self.r[np.arange(self.n), policy]   # shape (n,)
        # Truncated sum of discounted expected payoffs:
        # v = sum_{t=0}^{49} delta**t * P_pol**t @ r_pol
        v_pol = np.zeros(self.n)
        ex_r = r_pol
        discount = 1.0
        for t in range(50):
            v_pol = v_pol + discount * ex_r
            ex_r = np.dot(P_pol, ex_r)
            discount = discount * self.delta
        return v_pol

    def greedy(self, w):
        # Policy that is greedy with respect to the value vector w
        w = np.array(w)
        objf = self.r + self.delta * self.Q.dot(w)
        return np.argmax(objf, axis=1)

    def policy_iteration(self, policy0, max_iter):
        policy_current = np.array(policy0, dtype=int)
        for t in range(max_iter):
            policy_next = self.greedy(self.value_of_policy(policy_current))
            if np.array_equal(policy_next, policy_current):
                break
            policy_current = policy_next
        return policy_current
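# The 50-term sum in value_of_policy truncates the series
# sum_t delta**t * P_pol**t @ r_pol. For a finite MDP the policy value can be
# computed exactly by solving (I - delta * P_pol) v = r_pol; a minimal sketch
# (the helper name value_of_policy_exact is new here):
def value_of_policy_exact(mdp, policy):
    policy = np.asarray(policy, dtype=int)
    P_pol = mdp.Q[np.arange(mdp.n), policy]   # transition matrix under policy
    r_pol = mdp.r[np.arange(mdp.n), policy]   # one-period reward under policy
    return np.linalg.solve(np.eye(mdp.n) - mdp.delta * P_pol, r_pol)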
kurtzsigma = MDP_pol_iter(r, delta, Q)
policy_iter = kurtzsigma.policy_iteration(np.zeros(B + M + 1, dtype=int), max_iter=10000)
print(policy_iter)
[0 0 0 0 1 1 1 2 2 3 3 4 5 5 5 5]
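# The MarkovChain class imported at the top can be used to study the long-run
# behaviour of the chain controlled by the optimal policy (an added usage
# sketch; P is the transition matrix returned by solve_bellman_equation).
mc = MarkovChain(P)
print(mc.stationary_distributions[0])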