Grid world program (policy evaluation, policy iteration, value iteration)
The environment is a 4x4 grid of 16 states; states 0 and 15 are terminal, every step taken from a non-terminal state yields a reward of -1, and a move that would leave the grid leaves the state unchanged.
import random

# 4x4 grid world: 16 states numbered 0..15; states 0 and 15 are terminal.
S = [i for i in range(16)]
# Four actions: north, east, south, west.
A = ["n", "e", "s", "w"]
# Offset each action adds to the state index on the 4x4 grid.
ds_actions = {"n": -4, "e": 1, "s": 4, "w": -1}
# Value function, initialised to zero for every state.
V = [0 for _ in range(16)]
def dynamics(s, a):
    # Deterministic dynamics: return (s', reward, is_end) for action a in state s.
    s_prime = s
    # The state is unchanged if the move would leave the grid or s is terminal.
    if (s <= 3 and a == "n") or (s >= 12 and a == "s") \
            or (s % 4 == 0 and a == "w") or ((s + 1) % 4 == 0 and a == "e") \
            or (s in [0, 15]):
        pass
    else:
        s_prime = s + ds_actions[a]
    reward = 0 if s in [0, 15] else -1
    is_end = s in [0, 15]
    return s_prime, reward, is_end
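A quick sanity check of dynamics (hypothetical calls, not part of the original listing): moves that would leave the grid keep the agent in place, and the terminal states 0 and 15 absorb.

print(dynamics(1, "w"))  # (0, -1, False): stepping west from state 1 reaches terminal state 0
print(dynamics(1, "n"))  # (1, -1, False): blocked at the top edge, the state is unchanged
print(dynamics(0, "e"))  # (0, 0, True): state 0 is terminal and absorbing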
def P(s, a, s1):
    # Transition probability P(s1 | s, a): 1 if s1 is the successor, else 0.
    s_prime, _, _ = dynamics(s, a)
    return s1 == s_prime

def R(s, a):
    # Immediate reward for taking action a in state s.
    _, r, _ = dynamics(s, a)
    return r

def get_value(V, s):
    # Read the value of state s from the value table.
    return V[s]
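P and R wrap the deterministic dynamics as the transition and reward functions of the MDP five-tuple built below. A hypothetical spot check (not in the original listing):

print(P(1, "w", 0))  # True: moving west from state 1 lands in state 0
print(P(1, "w", 2))  # False: state 2 is not the successor of (1, "w")
print(R(1, "w"))     # -1: every step from a non-terminal state costs 1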
def uniform_random_pi(MDP, V=None, s=None, a=None):
    # Uniform random policy: every action has probability 1/|A|.
    _, A, _, _, _ = MDP
    return 1.0 / len(A) if len(A) != 0 else 0.0
def greedy_pi(MDP, V, s, a):
    # Greedy policy w.r.t. V: probability is split evenly among the actions
    # whose successor state has the highest value.
    _, A, _, _, _ = MDP
    max_v = -float('inf')
    a_max_v = []
    for a_opt in A:
        s_prime, _, _ = dynamics(s, a_opt)
        if max_v < get_value(V, s_prime):
            max_v = get_value(V, s_prime)
            a_max_v = [a_opt]
        elif max_v == get_value(V, s_prime):
            a_max_v.append(a_opt)
    return 1.0 / len(a_max_v) if (len(a_max_v) != 0) and (a in a_max_v) else 0.0
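A small tie-breaking illustration (hypothetical, not in the original listing; it builds the same MDP tuple that the cells below construct). With V all zeros every successor looks equally good, so each of the four actions gets probability 0.25.

MDP_demo = S, A, R, P, 1.0
print([greedy_pi(MDP_demo, [0] * 16, 5, a) for a in A])  # [0.25, 0.25, 0.25, 0.25]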
def get_Pi(Pi, MDP, V, s, a):
    # Query policy Pi for the probability of taking action a in state s.
    return Pi(MDP, V, s, a)
def update_v(MDP, V, Pi):
    # One synchronous sweep of the Bellman expectation backup under policy Pi.
    S, A, _, _, gamma = MDP
    V_prime = V.copy()  # write into a copy so the whole sweep reads the old V
    for s in S:
        v_temp = 0.0
        for a in A:
            # Deterministic dynamics: the expectation over s' collapses to one successor.
            s_prime, _, _ = dynamics(s, a)
            v_temp += get_Pi(Pi, MDP, V, s, a) * (R(s, a) + gamma * get_value(V, s_prime))
        V_prime[s] = v_temp
    return V_prime
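update_v implements one synchronous sweep of the Bellman expectation backup; because the grid world is deterministic, the expectation over successor states reduces to the single state returned by dynamics. In LaTeX:

v_{k+1}(s) = \sum_{a \in A} \pi(a \mid s)\,\bigl(R(s,a) + \gamma\,v_k(s')\bigr), \qquad s' = \mathrm{dynamics}(s,a)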
def policy_evaluate(MDP, V, Pi, n):
    # Evaluate policy Pi by running n Bellman expectation sweeps.
    update_V = V.copy()
    for i in range(n):
        update_V = update_v(MDP, update_V, Pi)
    return update_V
def display_V(V):
    # Print the value function as a 4x4 grid.
    for i in range(16):
        print('{0:>6.2f}'.format(V[i]), end=" ")
        if (i + 1) % 4 == 0:
            print("")
    return
def policy_iterate(MDP, V, Pi, n, m):
    # Alternate n evaluation sweeps with greedy improvement, m times.
    # The improvement step is implicit: greedy_pi re-derives the greedy
    # policy from the current V on every call.
    V_prime = V.copy()
    for i in range(m):
        V_prime = policy_evaluate(MDP, V_prime, Pi, n)
        print("Policy iteration %d:" % (i + 1))
        display_V(V_prime)
    return V_prime
def value_iterate(MDP, V, m):
    # Value iteration: m sweeps of the Bellman optimality backup.
    # Updates are in place (each sweep reads values it has already written),
    # a Gauss-Seidel-style variant of the synchronous textbook update.
    S, A, _, _, gamma = MDP
    V_update = V.copy()
    for i in range(m):
        for s in S:
            v_prime_max = -float("inf")
            for a in A:
                s_prime, reward, _ = dynamics(s, a)
                if v_prime_max < (reward + gamma * get_value(V_update, s_prime)):
                    v_prime_max = reward + gamma * get_value(V_update, s_prime)
            V_update[s] = v_prime_max
        print("Iteration %d:" % (i + 1))
        display_V(V_update)
    return V_update
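value_iterate applies the Bellman optimality backup; with the in-place sweep noted above it still converges for this episodic grid world. In LaTeX:

v_{k+1}(s) = \max_{a \in A}\,\bigl(R(s,a) + \gamma\,v_k(s')\bigr), \qquad s' = \mathrm{dynamics}(s,a)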
Policy evaluation
gamma = 1
MDP = S, A, R, P, gamma
update_V = policy_evaluate(MDP, V, uniform_random_pi, 100)
print("Value function under the uniform random policy:")
display_V(update_V)
Value function under the uniform random policy:
0.00 -14.00 -20.00 -22.00
-14.00 -18.00 -20.00 -20.00
-20.00 -20.00 -18.00 -14.00
-22.00 -20.00 -14.00 0.00
gamma = 1
MDP = S, A, R, P, gamma
update_V = policy_evaluate(MDP, V, greedy_pi, 100)
print("Value function under the greedy policy:")
display_V(update_V)
Value function under the greedy policy:
0.00 -1.00 -2.00 -3.00
-1.00 -2.00 -3.00 -2.00
-2.00 -3.00 -2.00 -1.00
-3.00 -2.00 -1.00 0.00
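For readability, the greedy policy implied by a converged V can also be printed. The helper below is not in the original listing; greedy_actions is a hypothetical name that reuses the existing dynamics and get_value:

def greedy_actions(V, s):
    # Hypothetical helper: the actions greedy_pi would allow in state s.
    if s in [0, 15]:
        return "."  # terminal state
    best, acts = -float("inf"), []
    for a in A:
        s_prime, _, _ = dynamics(s, a)
        if get_value(V, s_prime) > best:
            best, acts = get_value(V, s_prime), [a]
        elif get_value(V, s_prime) == best:
            acts.append(a)
    return "".join(acts)

for i in range(16):
    print("{0:>4}".format(greedy_actions(update_V, i)), end=" ")
    if (i + 1) % 4 == 0:
        print("")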
Policy iteration
print("策略迭代价值函数:")
update_V = policy_iterate(MDP,V,greedy_pi,1,100)
策略迭代价值函数:
第 1 次策略迭代:
0.00 -1.00 -1.00 -1.00
-1.00 -1.00 -1.00 -1.00
-1.00 -1.00 -1.00 -1.00
-1.00 -1.00 -1.00 0.00
Policy iteration 2:
0.00 -1.00 -2.00 -2.00
-1.00 -2.00 -2.00 -2.00
-2.00 -2.00 -2.00 -1.00
-2.00 -2.00 -1.00 0.00
Policy iteration 3:
0.00 -1.00 -2.00 -3.00
-1.00 -2.00 -3.00 -2.00
-2.00 -3.00 -2.00 -1.00
-3.00 -2.00 -1.00 0.00
...
Value iteration
print("价值迭代价值函数:")
update_V = value_iterate(MDP,V,100)
价值迭代价值函数:
第1次迭代:
0.00 -1.00 -1.00 -1.00
-1.00 -1.00 -1.00 -1.00
-1.00 -1.00 -1.00 -1.00
-1.00 -1.00 -1.00 0.00
Iteration 2:
0.00 -1.00 -2.00 -2.00
-1.00 -2.00 -2.00 -2.00
-2.00 -2.00 -2.00 -1.00
-2.00 -2.00 -1.00 0.00
Iteration 3:
0.00 -1.00 -2.00 -3.00
-1.00 -2.00 -3.00 -2.00
-2.00 -3.00 -2.00 -1.00
-3.00 -2.00 -1.00 0.00
...
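The fixed sweep count m is generous for this small problem; value iteration is more commonly stopped once the largest per-state change drops below a threshold. A minimal sketch under that assumption (value_iterate_until is a hypothetical name, not part of the original listing):

def value_iterate_until(MDP, V, theta=1e-4):
    # Hypothetical variant: sweep until the largest update is below theta.
    S, A, _, _, gamma = MDP
    V_update = V.copy()
    while True:
        delta = 0.0
        for s in S:
            v_new = -float("inf")
            for a in A:
                s_prime, reward, _ = dynamics(s, a)
                v_new = max(v_new, reward + gamma * get_value(V_update, s_prime))
            delta = max(delta, abs(v_new - V_update[s]))
            V_update[s] = v_new
        if delta < theta:
            return V_update

display_V(value_iterate_until(MDP, V))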