【论文翻译】One-Shot Imitation Learning

这篇论文看的想爆炸了。。context network实在是看不懂。。。。写了一半暂时放弃,有缘再回来继续嚼


  • Abstract

    • 理想的情况是:agent可以从“关于给定任务的少量demonstration”中进行学习,并且泛化到相同任务的新情况,并且不需要特殊的工程。假设有一个任务集合(用桌上的木块搭建一个塔/用桌上的木块搭建两个塔),每个任务有许多实例(不同的实例意味着:木块具有不同的初始状态)。训练时,神经网络的输入为:“第一个demonstration”和“采样自第二个demonstration的state序列”,输出:states序列对应的action。测试的时候,给出一个关于new task实例的demonstration,希望训练出的神经网络能在这个new task的新实例上表现得很好。我们的期望是:通过大量不同的task的训练,我们可以把demonstration转变为鲁棒的policy,从而完成各种各样的任务。
  • 1.Introduction

    • demonstration是一种方便的信息形式,可以交流关于task的操作,提供特殊的细节,相比于语言更好用。迄今为止的imitation learning都是特殊的特征工程,而我们期望的是:向agent演示几次,agent就可以泛化到相同任务的新情况,不需要很长的系统交互时间。我们的policy的输入时:(1)当前observation (2)能够解决相同任务的不同实例的demonstration。输出当前的control。
  • 2 .Related Work

  • 3. One Shot Imitation Learning

    • 3.1 Problem Formalization

      • 任务分布T,从任务分布中采样一个任务t~T
      • 关于任务t的demonstration的分布为D(t),策略是Pi theta(a|o,d),其中a是动作,o是observation,d是一个demonstration,theta是策略参数。
      • 从demonstration中采样一个demonstration :d~D(t),d是一个“由observation和action组成的序列”:d=[(o1,a1),(o2,a2),......,(oT,aT)]。
      • 我们假设任务分布T是给定的,并且我们可以从任务中获取成功的demonstration。
      • Rt(d)是一个标量评估函数。我们的目标是最大化策略的性能表现。
    • 3.2 Block Stacking Tasks

      • 我们用来举例的任务是:使用一个7自由度的机器人,把数量可变的方块搭成塔。
    • 3.3 Algorithm

      • 为了训练策略神经网络,我们使用imitation learning algorithms比如说behavioral cloning和DAGGER,优点在于只需要demonstration,不需要人为精心设置的奖励函数。前者相对来说容易实现得多。
      • 我们对每个任务收集一些demonstration,同时给动作添加噪声,为了获取更广泛的轨迹空间。我们训练策略,使得agent在给定current observation和demonstration的时候,能输出我们希望输出的action。
  • 4. Architecture

    • 我们希望训练一个通用的神经网络,输入demonstration和current observation,输出合适的action。本文的主要贡献就是一种学习搭积木的网络架构。架构有三个部分组成:demonstration network,context network,manipulation network。
    • 4.1 Demonstration Network

      • 网络的输入:demonstration trajectory,输出:demonstration的embedding。产生的embedding会供policy使用,embedding的长度随着demonstration长度和block数量的增加而线性增长。
      • Temporal Dropout:

        • 对于搭积木任务,demonstration的长度会达到上千,训练很消耗算力。因此在训练期间,我们随机扔掉demonstration的子集,扔掉p%。在测试期间,我们使用经过downsample的轨迹。
      • Neighborhood Attention:

        • 对经过downsample的轨迹,我们对其使用dilated temporal convolution 和 neighborhood attention。
        • 我们的神经网络需要能够处理具有木块数量可变的demonstrations,因此架构里面必须有能够处理可变长输入的模块。soft attention是一种:把变长输入映射到定长输入的操作。但是,直接使用soft attention势必会丢失信息。我们需要的是一种操作:可以把变长输入映射到变长输出。
        • 简单的介绍soft attention:
          • 输入:a query q, a list of context vectors cj , a list of memory vectors mj。q关于第i个context vector的attention weight为:
          • 其中v是学习到的权重。输出的是:memory关于attention weight的加权(对wi做了softmax处理):
          • 注意,输出的维度和memory vector的维度相同。这种attention操作可以泛化到n个query head的情况,此时也会有n个输出的结果。
        • Neighborhood Attention:

          • 假设环境中中B个木块。机器人的状态记为s robot。每个木块的坐标记为:(x1,y1,z1),(x2,y2,z2),....,(xB,yB,zB)。

          • neighbor attention的输入为:embedding列表 h1in,h2in,.....,hBin。使用两个独立的线性层,使用embedding来计算query vector和context embedding:

          • memory由“木块的坐标,串联,输入的embedding h”组成。第i个query的查询结果,是由一个soft attention生成的:

          • 【论文翻译】One-Shot Imitation Learning_第1张图片

          • 直觉上,Neighborhood Attention这个操作允许每个木块i查询和它相关的其他木块,并且提取出查询的信息也就是result i。在每个时间步,第i个木块查询到的result,结合第i个木块自己的信息,产生输出的embedding。具体的,也就是:

          • 实际中,我们使用的是多个query head,所以result的长度正比于query head的数量。

      • Demonstration Network实现细节:

        • 假设demonstration有t个时间步,有B个木块。此架构只使用demonstration当中的observation,每个observation是维度为3*B+2的向量,包含B个木块相对于机器人抓手的(x,y,z)坐标,剩余2个向量用来表示机器人的抓手是关闭还是打开。
        • 【论文翻译】One-Shot Imitation Learning_第2张图片
        • 输入一个完整的d,然后使用temporal-dropout,得到d'。然后把observation分成block_state和robot_state。机器人的状态被广播到每个木块中,所以block_state的维度是:Tnew * B* 3 ,robot_state的维度是:Tnew * B* 2。
        • 先对每个时间步中,每个block的block_state进行1*1卷积,作为每个block的embedding,记为h。
        • 对每个时间步,每个block对其同一时间步的block都进行一次Neighborhood Attention。把查询到的result,当前block的state,robot的state串联起来,进行1*1卷积,得到h‘。
        • 把h和h’直接相加,最终输出demonstration embedding,维度是Tnew * B* D。
    • 4.2 Context network

      • context network是架构中最重要的部分,输入:current state和demonstration embedding(Demonstration Network生成的embedding),输出:context embedding。注意,context embedding的长度不依赖于demonstration的长度,也不依赖于block的数量。因此,context network只捕捉相关的信息,并且传入到manipulation network。
      • Attention over demonstration:

        • context network先计算输入的current state的query vector。然后用query vector查询不同时间步的demonstration embedding。把相同时间步内的权重相加,可以得到一个向量(长度和时间步长度相关)。【这个地方没看懂】

你可能感兴趣的:(机器学习)