强化学习入门(三)将神经网络引入强化学习,经典算法 DQN

本文内容源自百度强化学习 7 日入门课程学习整理
感谢百度 PARL 团队李科浇老师的课程讲解

文章目录

  • 一、为什么要引入神经网络
  • 二、DQN 算法
    • 2.1 DQN 约等于 Q-learning + 神经网络
    • 2.2 DQN 的两大创新
      • 2.2.1 经验回放 Experience replay
      • 2.2.2 固定 Q 目标 Fixed Q target
    • 2.3 DQN 流程框架图
    • 2.4 PARL 的 DQN 框架
  • 三、DQN 算法代码详解

一、为什么要引入神经网络

Q 表只能解决少量状态的问题,如果状态数量上涨,那我们面对的可能性呈现指数上涨,这样的话Q表格就没有这个处理能力了

比如:

  • 国际象棋: 1 0 47 10^{47} 1047种状态
  • 围棋: 1 0 170 10^{170} 10170种状态
  • 连续操作的问题:不可数状态(不如弯曲角度)
  • (整个宇宙的原子数量预估: 1 0 80 10^{80} 1080

Q表格不行的时候,我们可以采用:值函数(Q函数)近似

Q表格的作用在于:输入状态和动作,输出Q值

那我们可以用一个 “带参数” 的 Q 函数来进行替代: q π ( s , a )   ≈   q ^ ( s , a , w ) q^π(s,a)\ \approx\ \hat{q}(s,a,\textbf{w}) qπ(s,a)  q^(s,a,w)

  • 多项式函数
  • 神经网络

不同的近似方式:

  • 输入状态 s 和动作 a,输出一个q 值 q ^ ( s , a , w ) \hat{q}(s,a,\textbf{w}) q^(s,a,w)
  • 输入状态 s,输出多个 q 值(不同动作所对应的 q 值) q ^ ( s , a 1 , w )   . . .   q ^ ( s , a m , w ) \hat{q}(s,a_1,\textbf{w})\ ...\ \hat{q}(s,a_m,\textbf{w}) q^(s,a1,w) ... q^(s,am,w)

Q表格方法的缺点:

  • 表格占用大量内存
  • 表格大的时候,查表效率低

值函数近似的优点:

  • 仅需存储有限数量的参数
  • 状态泛化,相似的状态可以输出一样

神经网络可以逼近任意连续的函数

  • 比如 CNN 在强化学习中引入后,可以让强化学习算法根据图片做出决策(输入图片,输出动作)
  • 神经网络的原理在于,定义 cost 为真实值和预测值之间的差距,然后用梯度下降来最小化 cost

二、DQN 算法

DQN 是使用神经网络解决强化学习问题最经典的算法

该算法由谷歌的 DeepMind 团队在 2015 年提出

《Human-level control through deep reinforcement learning》这篇论文被发表在了 Nature 杂志上

通过高维度的输入信息(像素级别的图像),使用了神经网络的 DQN 在 49 个 Atari 游戏中,有 30 个超越了人类水平

使用神经网络代替Q表格以后:

  • 输入可以是一个向量,包含各种值(比如四轴飞行器的高度,角度,转速等)
  • 输入可以是一个图片,包含各个像素点的信息
  • 输出直接是对应的动作

2.1 DQN 约等于 Q-learning + 神经网络

  • 输入 状态 s
  • 输出 q 向量,如果一个状态下有 5 种动作,那 q 就是 5 维的
  • 然后根据我们具体的动作选择,确定 q 值
  • 然后要让输出的 q 值,逼近 目标 q 值(target_q)
    • target_q 的计算公式就是 Q-learning 的方法: q π ( s , a )   =   r   +   γ max ⁡ a ′ q ^ ( s ′ , a ′ , w ) q_π(s,a)\ =\ r\ +\ γ\max\limits_{a'}\hat{q}(s',a',\textbf{w}) qπ(s,a) = r + γamaxq^(s,a,w)
    • 神经网络输出的预测值: q ^ ( s , a , w ) \hat{q}(s,a,\textbf{w}) q^(s,a,w)
    • 计算预测值和目标值的均方差(即 loss): E π [ ( q π ( s , a )   −   q ^ ( s , a , w ) ) 2 ] E_π[(q_π(s,a)\ - \ \hat{q}(s,a,\textbf{w}))^2] Eπ[(qπ(s,a)  q^(s,a,w))2]
  • 使用优化器,最小化 loss
    强化学习入门(三)将神经网络引入强化学习,经典算法 DQN_第1张图片

2.2 DQN 的两大创新

神经网络中由于引入了非线形函数,比如 “relu”

所以在理论上,无法证明训练之后一定会收敛

于是 DQN 提出两大创新,使得训练更有效率,也更稳定

2.2.1 经验回放 Experience replay

作用:

  • 解决序列决策的样本关联性问题
  • 解决样本利用率低的问题

问题来源:

  • 在监督学习中,训练样本是独立的
  • 但是在强化学习中,输入的是状态值,每一个状态都是连续发生,前后状态相互关联,所以样本之间具有关联性

解决方案:

  • 需要打乱,或者切断输入样本之间的联系
  • 这里用到了 Q-learning 的 Off-Policy 特点
  • 先存储一批经验数据
  • 然后打乱
  • 从中随机选取一个小的 batch 来更新网络
  • 这样就打破了样本间的相关性,同时使得网络更有效率

Off-Policy 在经验回放中的作用:

  • 设置经验池:是一个固定长度的队列
  • 一条经验指的是:一组 s t s_t st a t a_t at r t + 1 r_{t+1} rt+1 s t + 1 s_{t+1} st+1
  • 每拿到一条经验就往经验池进行存储
  • 满了以后,弹出旧的经验
  • 从经验池中随机抽取一个 batch
  • 去更新 Q 值(这里就是更新神经网络的系数)

强化学习入门(三)将神经网络引入强化学习,经典算法 DQN_第2张图片

优点:

  • 由于经验池中的数据有可能被重复抽取到,所以相当于经验可以重复利用,即提高了样本的利用率
  • 另外由于是随机抽取,所以打乱了样本间的相关性

2.2.2 固定 Q 目标 Fixed Q target

作用:

  • 解决算法训练不稳定的问题

问题来源:

  • 监督学习中,我们预测值要去逼近真实值,而真实值是固定不变的
  • 但是在 DQN 中,输入状态输出预测的Q,要逼近的是目标Q
  • Q _ t a r g e t   =   r   +   γ   m a x   Q ( s ′ , a ′ , θ ) Q\_target\ =\ r\ +\ γ\ max\ Q(s',a',θ) Q_target = r + γ max Q(s,a,θ)
  • 其中 m a x   Q ( s ′ , a ′ , θ ) max\ Q(s',a',θ) max Q(s,a,θ) 也是神经网络的输出,而神经网络权重系数一旦更新以后,这个值也会发生变化
  • 所以只要我们更新一次神经网络,那目标 Q 值也就会不断变化

解决方法:

  • 我们要想办法把 Q-target 值固定住
  • 也就是我们要把输出 Q-target 的神经网络参数固定一段时间
  • 然后过一段时间以后,再用最新的学习后的神经网络参数,刷新这个神经网络

2.3 DQN 流程框架图

强化学习入门(三)将神经网络引入强化学习,经典算法 DQN_第3张图片
Model:

  • 代替了 Q 表
  • 输入 S 输出 不同动作对应的 Q(预测值)给 Agent
  • 同时设定一个固定一段时间的神经网络用于输出 Q_target
  • 过一段时间更新该固定网络参数

引入神经网络的问题解决:

  • 经验回放
  • 固定目标值

Agent:

  • 和环境交互
  • 交互数据(经验)存储到经验池
  • 提取经验池数据,更新 Model 参数(利用最小化 预测值和目标值之间的 loss)——DQN最核心部分

2.4 PARL 的 DQN 框架

强化学习入门(三)将神经网络引入强化学习,经典算法 DQN_第4张图片
分为 model,algorithm,agent 这 3 个部分

  • model:用来定义神经网络部分的网络结构,同时实现模型复制
  • algorithm:实现具体算法,如何定义损失函数,更新 model,主要包含了 predict() 和 learn() 两个函数
  • agent:负责和环境做交互,数据预处理,构建计算图

强化学习入门(三)将神经网络引入强化学习,经典算法 DQN_第5张图片
总体抽象来说:

  • Agent 包含了 Algorithm 和 Model
  • Algorithm 包含了 Model

PARL 常用的 API:

  • agent.save():保存模型
  • agent.restore():加载模型
  • model.sync_weights_to():把当前模型的参数同步到另一个模型去
  • model.parameters():返回一个 list,包含模型所有参数的名称
  • model.get_weights():返回一个 list,包含模型的所有参数
  • model.set_weights():设置模型参数

PARL 里面打印日志的工具:

  • parl.utils.logger:打印日志,涵盖时间,代码所在文件及行数,方便记录训练时间

PARL 的 API 文档地址:

https://parl.readthedocs.io/en/latest/model.html

三、DQN 算法代码详解

强化学习算法 DQN 解决 CartPole 问题,代码逐条详解

你可能感兴趣的:(强化学习,神经网络,强化学习,人工智能,百度)