Reinforcement Learning Example 9: Temporal Difference (TD)

Temporal Difference (TD) learning is a method that combines ideas from Monte Carlo methods and dynamic programming: like Monte Carlo it learns from sampled experience without a model of the environment, and like dynamic programming it bootstraps, updating one estimate from other estimates.

Via the Monte Carlo method, the value estimate is updated toward the full sampled return:

$$V(s_t) \leftarrow V(s_t) + \alpha \big( G_t - V(s_t) \big), \qquad G_t = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \cdots$$

Via the TD method, it is updated toward a one-step bootstrapped target:

$$V(s_t) \leftarrow V(s_t) + \alpha \big( r_{t+1} + \gamma V(s_{t+1}) - V(s_t) \big)$$

where $r_{t+1} + \gamma V(s_{t+1})$ is called the TD target.
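
To make the two update rules concrete, here is a minimal sketch of both for a tabular value function. It is an illustration only, not the author's code; the episode format, the fixed step size alpha and the function names are assumptions.

import numpy as np

def mc_update(V, episode, alpha=0.1, gamma=0.9):
    # episode: list of (state, reward) pairs, where reward is received after
    # leaving that state; the whole episode must be finished before updating.
    G = 0.0
    for state, reward in reversed(episode):
        G = reward + gamma * G                  # full sampled return G_t
        V[state] += alpha * (G - V[state])      # move V(s_t) toward G_t

def td0_update(V, state, reward, next_state, done, alpha=0.1, gamma=0.9):
    # One-step update, applicable immediately after every transition.
    td_target = reward if done else reward + gamma * V[next_state]
    V[state] += alpha * (td_target - V[state])  # move V(s_t) toward the TD target

The MC update can only be applied once the episode has terminated, while the TD update can be applied online after each step; this difference is exactly what drives the bias/variance trade-off discussed next.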

TD uses the immediate reward together with the value estimate of the next state. Because it bootstraps on an estimate that has not yet converged, the TD target is biased, but its variance is small: only one step of randomness (one reward and one transition) enters the target.

MC, by contrast, uses a complete sampled episode to compute the long-term return, so the estimate has little bias, but its variance is large, since the randomness of every step in the episode accumulates in the return.

SARSA applies the same TD idea to the action-value function Q(s, a): after a transition (s, a, r, s', a') it moves Q(s, a) toward the TD target r + γQ(s', a'). The implementation below uses the visit count N(s, a) as an adaptive step size 1/N instead of a fixed α. The code is as follows:

# TD: SARSA
# SnakeEnv, ModelFreeAgent, timer and eval_game come from the earlier examples
# in this series and are not redefined here.
import numpy as np

class SARSA(object):
    def __init__(self, epsilon=0.0):
        self.epsilon = epsilon  # exploration rate of the epsilon-greedy behaviour policy

    def sarse_eval(self, agent, env):
        # Run one episode and update Q with the one-step SARSA target.
        state = env.reset()
        prev_state = -1
        prev_act = -1
        while True:
            act = agent.play(state, self.epsilon)
            next_state, reward, terminate, _ = env.step(act)
            if prev_act != -1:
                # Target for the previous (state, action) pair: the reward observed
                # on this step plus the discounted Q of the action just chosen
                # (no bootstrap term at termination).
                if terminate:
                    return_val = reward
                else:
                    return_val = reward + agent.gamma * agent.value_q[state][act]
                # Incremental mean: step size 1/N(s, a) instead of a fixed alpha.
                agent.value_n[prev_state][prev_act] += 1
                agent.value_q[prev_state][prev_act] += (
                    (return_val - agent.value_q[prev_state][prev_act])
                    / agent.value_n[prev_state][prev_act]
                )
            prev_act = act
            prev_state = state
            state = next_state
            if terminate:
                break

    def policy_improve(self, agent):
        # Greedy improvement: make the policy greedy with respect to the current Q.
        new_policy = np.zeros_like(agent.pi)
        for i in range(1, agent.s_len):
            new_policy[i] = np.argmax(agent.value_q[i, :])
        if np.all(np.equal(new_policy, agent.pi)):
            return False
        else:
            agent.pi = new_policy
            return True

    def sarsa(self, agent, env):
        # Generalized policy iteration: 2000 evaluation episodes per improvement step.
        for i in range(10):
            for j in range(2000):
                self.sarse_eval(agent, env)
            self.policy_improve(agent)

def td_sarse_demo():
    env = SnakeEnv(10, [3, 6])
    np.random.seed(101)
    agent3 = ModelFreeAgent(env)
    td = SARSA(0.5)
    with timer('Timer sarse Iter'):
        td.sarsa(agent3, env)
    print('return_pi={}'.format(eval_game(env, agent3)))
    print(agent3.pi)

td_sarse_demo()
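
What makes this SARSA on-policy is agent.play(state, epsilon): the same epsilon-greedy policy both selects the actions executed in the environment and supplies the action used in the TD target. ModelFreeAgent.play is defined in an earlier example of this series; the sketch below is only a plausible illustration of such a selection rule, assuming value_q is the agent's Q table, and is not the exact implementation from that example.

import numpy as np

def epsilon_greedy(value_q, state, epsilon):
    # Explore uniformly with probability epsilon, otherwise act greedily on Q.
    if np.random.rand() < epsilon:
        return np.random.randint(value_q.shape[1])
    return int(np.argmax(value_q[state, :]))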

