Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析

Crowd-Robot Interaction:Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析

Crowd-Robot Interaction:Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解读

近期精读了一篇强化学习论文,在此分享一下,相互学习。

论文亮点

逾越人机交互到人群与机器人交互.

  1. 使用attention机制,重新定义了人与机器人之间的交互对;
  2. 在强化学习框架中包含了人机交互和人人交互的联合模型;

问题模型

这篇文章中,机器人考虑穿过n个人的导航任务,这个场景被描述成一个强化学习序列决策问题。每一个智能体或机器人都有自己的位置速度信息以及目标点的位置和偏好速度。强化学习中的state是一个联合state里面包含机器人的state和环境中人的state。最优策略期盼得到最优的回报return,表述如下:
Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第1张图片
Rt表示在t时刻所获得的奖励,V星是最优价值函数,Vpref是折扣系数中的正则项。
reward函数如下:
Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第2张图片
dt是机器人和人之间的最短距离。

方法

文章中估计V值的时候用了一个神经网络,与其他DQN不同的是他利用Self-attention mechanism 把环境中的一个人和其他人的相关重要性结合起来。这样模型不仅考虑了环境中的每一个人的state同时也考虑了其他人对环境中某一个人的影响,这与人的下一步动作有关。

从环境中获得的参数有:
Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第3张图片
S是机器人的state,Wi是第i个人的state。dg是机器人到目标点的距离,di是机器人到第i个人的距离。实际还需要设置一个超参数theta。这样state一共有13个,但是这些state是不能包含环境中的其他人对某一个的影响的信息的,而这个信息他是通过建格子的方式获得的。具体操作如下:
Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第4张图片
Hi表示第i个人,以第i个人为中心以机器人运动方向为x轴,建16个11m的格子,看环境中的其他人是否在这16个小个子中,如果在的话还需要把这些人的速度vx和vy记录下来。建立一个163的数组,把数组中48个元素分成16组每组表示一个格子,每组有三个元素第一个元素表示这个格子中是否有人,如果有人第二、三个元素表示速度vx、vy,如果格子中没人则三个元素都置零。用这48个state表示环境中某一个人被环境中其他人影响的信息,这样联合state为13+48=61个。在同一时刻,机器人的state不变,环境中每个人的state是不同的,假设环境中有五个人因此输入模型中时的state也是61*5个(batch size=100)。神经网络输入是500×51的矩阵,输出是100×1的矩阵,表示状态值V。

Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第5张图片

self-attentionCrowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第6张图片

a1、a2、a3、a4表示机器人和第i个人的joint_state(61),如果直接输入网络中,我们不知道他们之间的关系,Self-attention 可以帮助我们获得它们之间的关系。self-attention中关键的三个参数QKV,是由输入向量分别乘三个矩阵(可以是神经网络)得到。比如我想知道第一个人与环境中人的关系,分四步:1、求q1,qi=Wq×a1
2、求K,K= Wk×A
3、计算q1与K的相关性
4、把q1与K的相关性参数进行softmax操作

第三步中常用计算相关性方法有两种,Dot-product更为常用。
Crowd-aware Robot Navigation with Attention-based Deep Reinforcement Learning 论文解析_第7张图片
但这篇文章中使用的是第二种additive的方法,同时他把tanh函数省略了。并且在计算q的时候,作者把所有人的q值求了一个均值得到q‘,他假设了一个人,这个人的参数由环境中所有人提供,得到的相关性是这个假想的一个人与其他人的相关性。相关性得到之后是一个0-1的数值,相关性与V进行点乘(V 是通过A乘一个矩阵得到的),点乘之后的结果是考虑了其他人对”均值人“的影响之后的一个矩阵。之后在N个人维度上进行线性相加,得到人群的表述。

代码框架

代码中分为两部分
1、imitation learning
2、reinforcement learning

未完待续

你可能感兴趣的:(paper,reading,自动驾驶,pytorch,深度学习,强化学习,人工智能)