[经典论文分享] Decision Transformer: Reinforcement Learning via Sequence Modeling

1 背景

无聊时看群聊发现在半年前2021年7月左右新出了一个方法,叫做decision transformer。一直以来都是对attention机制大家族保持着崇高的敬意,于是找到了这篇文章看了一下。看完之后感觉并不是很惊喜,也可能是期待太高。文章核心做的工作是给出了一种新的深度强化学习训练模式,使得能够更加‘端对端’地去用transformer大家族去拟合和训练。截止2022年1月22日,这篇文章在谷歌学术上有了50次引用(半年多)。
论文原文:Decision Transformer: Reinforcement Learning via Sequence Modeling
代码仓库:https://github.com/kzl/decision-transformer

2 模型结构

[经典论文分享] Decision Transformer: Reinforcement Learning via Sequence Modeling_第1张图片
文章并没有提出新的模型结构,本质是在为transformers提供输入的embeddings。RL中的一个轨迹由多个 s t , a t , r t s_t, a_t, r_t st,at,rt按顺序组成,作者将其中的 r t r_t rt换成对于未来的奖励期望–汇报 R t R_t Rt,然后将采样获得的长度为K的轨迹直接顺序拼接起来形成一个输入。然后就可以将其看作正常DL中的 x i x_i xi来做回归训练了。那LOSS是怎么产生的呢?因为整个模型要预测的是下一步动作,所以LOSS是由当前模型预测的下一步的动作 a p r e a^{pre} apre和真实的下一步动作 a a a之间的差得到的。这就很有意思了,相当于标签不仅是标签,还会在某种情况下成为输入。
当然,这里面还是有几个小细节:
1)直接把采样得到的K长的raw feature喂给transformers那怕是有点直接,于是作者在这二者直接加了一层MLP来project一下。如果状态是图像的话,比如Atari里的游戏,那么就通过CNN提取后再拼接回去。
2)虽然作者抛弃了传统的策略改进过程,但是序列决策问题还是序列决策,总是要有个能表示当前步骤的 t t t的,因此作者在 s t , a t , r t s_t, a_t, r_t st,at,rt的embeddings上都加了 t t t的embedding,相当于了positional encoding。
3)作者想用transformers拟合的是动作action,但是作者说其实拟合state和reward也行就是没那么直接。

下面这个伪代码还是很直观的:
[经典论文分享] Decision Transformer: Reinforcement Learning via Sequence Modeling_第2张图片

3 实验

实验部分,作者在Atari和openai的数据集上做了测试,言简意赅概括就是“还行,还不错”:
[经典论文分享] Decision Transformer: Reinforcement Learning via Sequence Modeling_第3张图片
作者的几个小实验放在了discussion里,里面有几个有意思的尝试,挑了两个:
[经典论文分享] Decision Transformer: Reinforcement Learning via Sequence Modeling_第4张图片
一个是说K,也就是采样的长度越长越好。
[经典论文分享] Decision Transformer: Reinforcement Learning via Sequence Modeling_第5张图片
另一个是说解决稀疏、延迟奖励环境下效果也不错。

4 特点总结

1)总的来说,本文还是给DRL领域带来了有趣的尝试,尤其是对于离线DRL来说,能够通过简单的输入构建就可以利用上Transformer大家族的强大模型,比如GPT\BERT,这对于一些问题还是十分重要的。
2)仔细想想,效果可能主要来自于注意力机制对于样本之间的信息交互作用,使得不同样本之间学习到了一下未来或者过去的知识,从而可以直接端对端学习动作。我们想象一个生动的场景:每隔1小时就会复制一个你,然后放到小黑屋里存着,你还是正常做事情。那么过了10个小时,有了10个你的样本,这些样本都知道到他们产生时刻为止事情的一些进展,你让他们来到一起交流分享一下,那么就会从这些不同时刻的片段你中学到到底什么是对的什么是错的。之所以作者会用回报而不是即时奖励,就是因为回报是能代表当前时刻的一个优劣的情况,以方便不同的样本之间进行交互。

你可能感兴趣的:(RL-based文献阅读,神经网络基础模型关键点,transformer,深度学习,强化学习)