强化学习之基于gym环境的DQN算法实战(Pytorch)

DQN算法是强化学习与深度学习结合的开端,其利用深度网络来拟合值函数,利用Q-leraning算法进行强化学习。DL为Agent提供学习的大脑,RL提供了计算机制,从而达到真的AI。

之前有写过利用DQN算法去解决Cartpole任务和Mountaincar任务,具体可见强化学习之DQN算法实战(Pytorch):https://blog.csdn.net/MR_kdcon/article/details/109699297

上述的任务用的都是gym自带的环境,本文将参考经典环境Puckworld家用gym实现,并完成DQN算法。

 

 

实战内容:

Puckworld(冰球世界),简单地说就是Agent去追逐世界中随机出现的目标物体。

所需包:gym、torch

pytorch官网:https://pytorch.org/

gym官网:http://gym.openai.com/

开发环境:Pycharm

 

 

一、实际效果:

强化学习之基于gym环境的DQN算法实战(Pytorch)_第1张图片强化学习之基于gym环境的DQN算法实战(Pytorch)_第2张图片

绿色小球是我们的Agent,红色小球是我们的目标,他是会随机移动的。刚开始Agent并不理解这个陌生的环境,因此需要不断探索,在一次次完整序列之后,Agent逐渐理解,最后可以达到一看到目标,就会冲向他,并抓到目标获取最终胜利。

另外参考别人的设计,对Agent的颜色也进行了设置,离目标越远颜色越深,离目标越近颜色越浅。此外小球上的箭头表示当前Agent在状态上,执行动作的方向。

 

 

二、DQN算法:

DQN算法经历了初代版本的NIPS-DQN、本文将要介绍的Nature-DQN、DDQN等。

NIPS版本的DQN没有将Target-Q网络和Eval-Q网络分开来,导致TD目标值和Q预测值公用同一个网络,DQN算法是基于Q-learning的,回顾一下,Q-learning是将Q预测值以一个学习率α靠近TD目标值,从而最终达到Q预测值接近Q真值的结果。但NIPS中,当网络参数改变后,2个值同时改变,就像猫捉老鼠一样,Q预测值很难追得上变化的TD目标值,从而很难收敛,因此Nature-DQN就在这里进行改变,设置了一个隔段时间更新网络参数的Target-Q网络来制作TD目标值,Target-Q网络是个没有训练能力的网络,其参数只来自于Eval-Q网络。

 

NIPS-DQN:

强化学习之基于gym环境的DQN算法实战(Pytorch)_第3张图片

 

 

Nature-DQN:

1、初始化参数N(记忆库),学习率lr,\epsilon贪心策略中的\epsilon,衰减参数\gamma,更新步伐C

2、初始化记忆库D

3、初始化网络Net(本实验选择一个2层的网络,其中隐藏层为50个神经元的FC层),将参数W和W_指定为服从均值为0,方差为0.01高斯分布,并例化出2个相同的网络eval_net(以下简称Q1)和target_net(以下简称Q2)。

4、for episode in range(M):

5、     初始化状态observation

6、     for step in range(T):

7、            通过\epsilon贪心策略选中下个动作action,其中Q(s, a)来自于eval_net的前向推理。

8、            通过环境的反馈获得observation'以及奖励值reward。

9、            将observation、observation'、action、reward打包存入记忆库D中。

10、          当记忆库存满之后,抽取batch个大小的数据(sj, sj', aj, rj)送入2个网络中

11、          以损失函数L =  \frac{1}{batch}*\sum_{j=1}^{batch}( rj + \gamma * max(Q2(sj';W_)) - Q1(sj,aj;W))^2 进行训练

12、          每隔C步,更新target_net的参数W_ = W

强化学习之基于gym环境的DQN算法实战(Pytorch)_第4张图片

Note:

1、本实验中的网络采用了2个全连接层,用Pytorch实现。

2、Target-Q网络不参与训练,只是充当标签的作用,因此需要切断跟踪(即requires_grad=False),这样可以节约计算资源。

3、行为策略中的ε-greedy中动作的选取是根据Eval-Q网络选取的,因此这里也需要用detach或者data切断跟踪,以免干扰到网络的反向传播。

 

三、gym环境编写:

gym有2个核心的类:Env类,Space类。

参考gym包中各种环境的写法,那么如何模仿gym库来编写自己的环境呢?

step1:创建新式环境类并继承于gym.Env,然后最重要的是写出基础六大方法:__init__、seed、step、reset、render、close。这里要注意一下,一般step的输出为s1,r,is_done,info,s1一般都是numpy数组,其输入动作一般都是python数字。

step2:在这个基础上,还可以根据实际情况增加新的方法,比如用action如何控制状态啊这种。

step3:在主函数中导入新式环境,如果调用gym本身写好的环境,用env.make('Cartpole-v0')。

 

Note:这里说一下2个不常用但有时候会见到的2种用法

①:这里特别说一下这里的seed()方法,我们在gym自己编写的环境中总会看到seed这个方法,比如在CartPole-v0中:

def seed(self, seed=None):
    self.np_random, seed = seeding.np_random(seed)  #np_random是seeding模块中的一个函数
    return [seed]

这里的self.np_random是一个类的实例化对象,可以调用许多随机化的方法,比如uniform:

self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None
return np.array(self.state)

等效于random.uniform的作用。

再来说这个seed,这个seed是随机种子,和我们之前认知的一个道理,设置了相同的随机种子之后,下一步的随机化产生相同的结果,用完即消失。

比如还是这个CartPole环境,我们都知道初始化reset是随机位置:

env = gym.make('CartPole-v0')
env.reset()
for _ in range(1000):
    env.render()
    action = env.action_space.sample()  # 随机采样动作空间中的动作
    observation, reward, done, info = env.step(action)

    if done:
        observation = env.reset()
        print(observation)
env.close()

强化学习之基于gym环境的DQN算法实战(Pytorch)_第5张图片

用了env.seed(1)之后:初始位置就固定了,这就是seed()的用法。

强化学习之基于gym环境的DQN算法实战(Pytorch)_第6张图片

②:env.unwrapped方法

env = gym.make('MountainCar-v0')
env = env.unwrapped

这个方法是gym的基类Env中的unwrapped方法,功能是解除限制,比如我们的CartPole、MoutainCar不是有时间限制嘛,200步之内。那么这个东西这么一写,就相当于解除了这个限制。

下图是我绘制的gym常用的类和函数、方法:

强化学习之基于gym环境的DQN算法实战(Pytorch)_第7张图片

 

在Viewer里绘制一个几何图像的步骤如下:
1. 建立该对象需要的数据本身。
2. 使用rendering提供的方法返回一个Geom对象。
3. 对Geom对象进行一些对象颜色、线宽、线型、变换属性的设置,这其中有一个重要的属性就是变换属性,该属性负责对对象在屏幕中的位置、渲染、缩放进行渲染。
   如果某对象在呈现时可能发生上述变化,则应建立关于该对象的变换属性。该属性是一个Transform对象,而一个Transform对象,包括translate、rotate和scale三个属性,每个属性都由以np.array对象描述的矩阵决定。
4. 将新建立的Geom对象添加至viewer的绘制对象列表里,如果在屏幕上只出现一次,将其加入到add_onegeom()列表中,如果需要多次渲染,则将其加入add_geom()。
5. 在渲染整个viewer之前,对有需要的geom的参数进行修改,修改主要基于该对象的Transform对象。
6. 调用Viewer的render()方法进行绘制。

 

gym的常用写法:

self.viewer = rendering.Viewer(self.width, self.height)  # 创建窗口
target = rendering.make_circle(t_rad, 30, True) # 创建Geom小球对象,t_rad为半径,True为填充
(target = rendering.make_circle(t_rad, 30, False) # 创建Geom圆圈对象,t_rad为半径,False为不填充)
target.set_color(0, 0, 0)  # 设置Geom对象的颜色(0,0,0)为黑色   (1,0,0)红色 (0,1,0)绿色 (0,0,1)蓝色 (1,1,1)白色
self.target_trans = rendering.Transform()  # 产生Transform类的对象
target.add_attr(self.target_trans)  # 设置变换对象作为Geom对象的属性
self.viewer.add_geom(target)  # Geom对象添加至viewer的绘制对象列表里
target_trans.set_translation(x, y)  # 修改变换对象的位置
target_trans.set_rotation(degree / RAD2DEG)  # 修改变换对象的角度(输入必须为弧度制)
target_trans.set_scale(x, y)  # 修改变换对象的大小
env.render()  # 进行调用,绘制环境

Note:

1、在窗口中,以窗口左下角为(0, 0)。

2、常用的还有画三角形(这个注意一下,gym的三角形是有一个支点的,便于控制旋转)、矩形、线、点,具体操作和上述画圆类似。

绘制结果:

强化学习之基于gym环境的DQN算法实战(Pytorch)_第8张图片

 

四、任务结果分析:

冰球世界环境:

动作空间:[0,1,2,3,4]分别代表向左、向右、向上、向下、不动。

状态空间:每一个状态都是一个(6, )格式的Array数组,如下所示:

def reset(self):
    self.state = np.array([self._random_pos,  # Agent的x坐标
                           self._random_pos,  # Agent的y坐标
                           0,                 # Agent的x方向速度,向右为正
                           0,                 # Agent的x方向速度,向左为正
                           self._random_pos,  # 目标的x坐标
                           self._random_pos   # 目标的y坐标
                           ])

实时奖励设置:当Agent与目标触碰时,奖励为0;非触碰时,奖励与距离的负数成正比。

完成信号done:当Agent和目标触碰时,输出终止信号Done,准备开启下个episode。

另外,目标每隔100steps就会随机更新位置。

 

强化学习之基于gym环境的DQN算法实战(Pytorch)_第9张图片

如上图所示,这是训练500个episode后的结果,上图第一张图的蓝色部分是每个episode所需的步数,橘色为累计奖励(无衰减)。上图第二张图片为loss的损失。

从橘色线条可以明显看出,Agent对目标的搜寻在几个episode之后就可以达到很低的负奖励,意味着Agent能及时追踪到目标。说明DQN是有效果的。

 

强化学习之基于gym环境的DQN算法实战(Pytorch)_第10张图片

上图是DQN训练之后的测试过程,通过将训练之后的网络参数提取出来,然后加入当测试网络中。测试过程直接使用贪心策略,不需要训练网络,也不用Target-Q网络,直接使用Eval-Q网络进行Q值的近似。

从图中可以看出,Agent可以直接搜寻到目标,且几乎保持稳定,说明了DQN效果比较显著。

 

五、总结:

1、gym是一种学习RL的很不错的GUI环境,可以通过自己编写或者调用gym官网提供的环境进行自己的RL设计。

2、DQN在Q-learning的基础上,结合DL进行值函数近似。DQN伪代码和Q-learning的很像,唯一不同的是,Q-learning存储Q值是通过Q表(或Q-字典,通常Q字典的速度更快),DQN是通过值函数近似,用Q网络输出的值表示Q值。

3、Q-learning的缺陷时无法很好应对连续状态或很多离散状态下的RL问题,而DQN的缺陷时过估计问题,就是其Eval-Q输出的Q值是高于其Q真实值的,根本原因在于Q值更新的公式中,利用Target-Q输出的最大值对应的动作为a',具体的可以查看我的下一篇文章。

 

 

 

你可能感兴趣的:(算法,python,人工智能,深度学习,强化学习)