模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL

1 GAN (回顾)

GAIL 的设计基于生成判别网络 ( GAN)。这里简单地回顾一下GAN,详细的可见 NTU 课程笔记 7454 GAN_UQI-LIUWJ的博客-CSDN博客
GAN由生成器 (Generator) 和判别器 (Discriminator)组成,它们 各是一个神经网络。
——>生成器负责生成假的样本
——>判别器负责判定一个样本是真是假。
我们的目标是希望生成器生成的内容可以“以假乱真”

1.1 生成器

        生成器 记作 a = G ( s ; θ ) ,其中 θ 是参数。它的输入是向量 s ,向量的每一个元素从均匀分布U(-1,1)或标准正态分布 N (0 , 1) 中抽取。生成器的输出是数据(比如图片)x 模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第1张图片

 

1.2 判别器

判别器 记作\hat{p}=D(x;\phi),其中 ϕ 是参数。
它的输入是图片 x;输出 \hat{p} 是介于 0 1 之间的概率值,0 表示“假的”, 1 表示“真的”。
判别器的功能是二分类器。
模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第2张图片

1.3 训练生成器

        将生成器与判别器相连,固定住判别器的参数,只更新生成器的参数 θ,使得生成的图片 x = G(s; θ) 在判别器的眼里更像真的。

        对于任意一个随机生成的向量 s,应该改变 θ,使得判别器的输出\hat{p}=D(x;\phi)尽量接近 1

        可以用如下函数作为loss function:

 

         我们希望此时D(x;Φ)越大越好,也就是E(s;θ)越小越好

        所以我们用梯度下降来更新生成器的θ

1.4 训练判别器

  •  判别器的本质是个二分类器,它的输出值 \hat{p}=D(x;\phi)表示对图片真伪的预测;
    • \hat{p} 接近 1 表示“真”,
    • \hat{p}接近 0 表示“假”。
判别器的训练如下图所示。
  • 从真实数据集中抽取一个样本,记作x^{real}
  • 再随机生成一个向量 s,用生成器生成 x^{fake}=G(s;\theta)
  • 训练判别器的目标是改进参数 ϕ,让 D(x^{real};\phi) 更接近 1(真),让D(x^{fake};\phi)更接近 0 (假)。
  • ——>也就是说让判别器的分类结果更准确,更好区分真实图片和生成的假图片。

此时的损失函数如下所示

 不难发现,判别器越准确,损失函数F越小

所以我们也用梯度下降更新判别器的θ

 

 1.5 整体训练流程

  • 模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第3张图片
  • 模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第4张图片

 

 2 生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL

2.1 训练数据

GAIL 的训练数据是被模仿的对象(人类专家)操作智能体得到的轨迹

 

数据集中有 k 条轨迹,把数据集记作:

 

 2.2 生成器

GAIL 的生成器是策略网络 π ( a | s ; θ )
策略网络的输入是状态 s,输出是一个向量:

 

输出向量 f 的维度是动作空间的大小 A ,它的每个元素对应一个动作,表示执行该动作
的概率。
给定初始状态 s 1 ,并让智能体与环境交互,可以得到一条轨迹:

 

 其中动作是根据策略网络抽样得到的, a_t \sim \pi(\cdot|s_t;\theta), \forall t=1,\cdots, n

 下一时刻的状态是环境根据状态转移函数计算出来的

 模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第5张图片

 2.3 判别器

GAIL 的判别器记作 D ( s, a ; ϕ )

判别器的输入是状态 s,输出是一个向量:

 

输出向量 \hat{p} 的维度是动作空间的大小 A ,它的每个元素对应一个动作 a ,把一个元素记作:

 

\hat{p_a}接近 1 表示 ( s, a ) 为“真”,即动作 a 是人类专家做的。
\hat{p_a}接近 0 表示 ( s, a ) 为“假”,即动作 a 是策略网络生成的。
模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第6张图片

 2.4 GAIL的训练

2.4.1 训练生成器

\theta_{now}是当前策略网络的参数。用策略网络\pi(a|s;\theta_{now})控制智能体与环境交互,得到一条轨迹:
用判别器评价 (s_t,a_t)的真实情况, D(s_t,a_t;\phi)越大,说明 (s_t,a_t)在判别器的眼里越真实。
我们记第t步的回报为:

 

于是我们的轨迹可以变成

 

 有不同的方法来更新策略网络的参数θ

在GAIL中,使用的是TRPO

 强化学习笔记:置信域策略优化 TRPO_UQI-LIUWJ的博客-CSDN博客

即目标函数为

通过解带约束的最大化问题,得到新的参数

 

 2.4.2 训练判别器

训练判别器的目的是让它能区分真的轨迹与生成的轨迹
我们从训练数据中抽样一条轨迹:

同时用策略网络控制智能体和环境交互,得到另一条轨迹,记作

 注意real和fake轨迹的长度可能不一样

同样地,我们希望D(s_t^{real},a_t^{real};\phi)尽量趋近于1,D(s_t^{fake},a_t^{fake};\phi)尽量趋近于0

于是我们定义损失函数

模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第7张图片

 我们希望损失函数尽量小,也就是说判别器能区分开真假轨迹。可以做梯度下降来更新判别器的参数Φ

 

2.4.3 整体训练流程

每一轮训练更新一个生成器,更新一次判别器。训练重复以下步骤,直 到收敛。
模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第8张图片

 模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL_第9张图片

 

你可能感兴趣的:(强化学习,强化学习)