强化学习圣经-GridWorld实现

import numpy as np
import matplotlib.pyplot as plt

# Problem configuration for the GridWorld example.
grid_size = 5        # the grid is grid_size x grid_size; cells are indexed [row, column]
# Special cells: the environment loop further down sends every action taken in
# posA to primeA (reward 10.0) and every action taken in posB to primeB
# (reward 5.0).
posA = [0, 1]
primeA = [4, 1]
posB = [0, 3]
primeB = [2, 3]
discount = 0.9       # factor applied to the successor state's value in each backup
actions = ['L', 'U', 'R', 'D']

# actionProb[i][j][action] is the probability of picking `action` in cell (i, j).
# Each cell gets its OWN dict: the `[[d] * n] * n` idiom would make all cells
# share one dict object, so changing the policy of a single state would
# silently change it for every state.
actionProb = [
    [{action: 1.0 / len(actions) for action in actions} for _ in range(grid_size)]
    for _ in range(grid_size)
]

#environment
# Deterministic model of the grid:
#   NextState[i][j][action]    -> [row, column] reached from cell (i, j)
#   actionReward[i][j][action] -> immediate reward for that move
NextState = []
actionReward = []

# (row delta, column delta) for each move.
_MOVES = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}

for i in range(grid_size):
    state_row = []
    reward_row = []
    for j in range(grid_size):
        # Named `transitions`/`rewards` rather than `next`/`reward` so the
        # builtin next() is not shadowed for the rest of the script.
        transitions = {}
        rewards = {}
        for move, (d_row, d_col) in _MOVES.items():
            new_row, new_col = i + d_row, j + d_col
            if 0 <= new_row < grid_size and 0 <= new_col < grid_size:
                transitions[move] = [new_row, new_col]
                rewards[move] = 0.0
            else:
                # Moving off the grid keeps the agent in place and costs -1.
                transitions[move] = [i, j]
                rewards[move] = -1.0

        # Special cells override the defaults: every action jumps to the
        # paired target cell with a fixed reward. posB is checked after posA,
        # so it takes precedence if both name the same cell.
        for special, target, bonus in ((posA, primeA, 10.0), (posB, primeB, 5.0)):
            if [i, j] == special:
                for move in _MOVES:
                    # Copy the target so the four entries do not alias one
                    # another or the module-level constant.
                    transitions[move] = list(target)
                    rewards[move] = bonus

        state_row.append(transitions)
        reward_row.append(rewards)
    NextState.append(state_row)
    actionReward.append(reward_row)

#iteration
# Two sweeps over the same model, selected by label:
#   'v(s):'     -> value of the policy in actionProb (probability-weighted sum
#                  of the one-step backups over all actions)
#   'opt_v(s):' -> optimal value (maximum one-step backup over all actions)
# The label is also what gets printed and used as the plot title.
choose = ['v(s):', 'opt_v(s):']

for sel in choose:
    stateValue = np.zeros((grid_size, grid_size))
    while True:
        newStateValue = np.zeros((grid_size, grid_size))
        for i, j in np.ndindex(grid_size, grid_size):
            # One-step backup for each action: reward + discounted value of
            # the successor cell under the current estimate.
            backups = []
            for action in actions:
                row, col = NextState[i][j][action]
                backups.append(actionReward[i][j][action] + discount * stateValue[row, col])
            if sel == 'v(s):':
                weighted = [actionProb[i][j][action] * backup
                            for action, backup in zip(actions, backups)]
                newStateValue[i, j] = np.sum(weighted)
            else:
                newStateValue[i, j] = np.max(backups)

        # Stop once the summed absolute change over the whole grid is tiny.
        total_change = np.sum(np.abs(stateValue - newStateValue))
        if total_change < 1e-4:
            print(sel)
            print(newStateValue)
            break
        stateValue = newStateValue

    # Show the converged values as a grey-scale heat map.
    plt.matshow(newStateValue, cmap=plt.cm.Greys)
    plt.colorbar()
    plt.title(sel)
    plt.show()

强化学习圣经-GridWorld实现_第1张图片     强化学习圣经-GridWorld实现_第2张图片     强化学习圣经-GridWorld实现_第3张图片

注：相关文章《基于强化学习的GridWorld（代码+思路）》见 https://www.quantinfo.com/Article/View/725/%E5%9F%BA%E4%BA%8E%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0%E7%9A%84GridWorld%EF%BC%88%E4%BB%A3%E7%A0%81+%E6%80%9D%E8%B7%AF%EF%BC%89.html

你可能感兴趣的:(RL)