Meta Network||论文笔记

 

元学习论文总结||小样本学习论文总结

2017-2019年计算机视觉顶会文章收录 AAAI2017-2019 CVPR2017-2019 ECCV2018 ICCV2017-2019 ICLR2017-2019 NIPS2017-2019

 

Meta-Learning论文笔记:Meta Network

Meta Network||论文笔记_第1张图片

本文是对2017年ICML的一篇Meta-Learning论文的笔记论文连接。

MetaNet 是Meta Networks的缩写,具有用于跨任务快速泛化的体系结构和训练流程。

名词说明:Fast Weight 和 Slow Weight

模型的跨任务快速概括依赖于fast weight。神经网络中的参数通常是根据目标函数中的梯度下降来更新的,这个过程对于小样本学习是很慢的。一种更快的学习方法是利用一个神经网络预测另一个神经网络的参数,生成的参数称为快权值即fast weight。普通的基于SGD等优化的参数被称为慢权值即slow weight。

在 MetaNet 中,损失梯度信息被作为meta information ,用来生成快权重。在神经网络中,将慢权值和快权值结合起来进行预测。

Meta Network||论文笔记_第2张图片

多层叠加Layer Augmentation

模型:

整体架构:

如图,MetaNet的训练包括三个主要过程: meta information的获取、以及fast weight的生成和slow weight的优化,由base learner和meta learner共同执行。

Meta Network||论文笔记_第3张图片

MetaNet的整体结构

数据集和主要的函数说明:

  • 训练数据包含两种数据集:支持集 和训练集 
  • Base learner简写为 b,是一个函数或神经网络。通过任务损失  估计主要任务目标。它的参数由慢权值 example-level的快权值  构成
  • 动态表征函数 u,对样本学习到一个嵌入。参数由慢权值 example-level快权值  组成
  • Meta learner由快速权值生成函数 和 组成,参数为 和 G,它们的输入由损失梯度 和  构成,经过映射后生成 和 和其对应慢权值维度相同

训练过程:

1. 表征函数的学习:将随机采样的支持集数据输入到表征(嵌入)函数 中,为了得到数据集的嵌入,利用表征损失  来捕获表示学习目标,并将梯度作为meta information获取。其中损失函数为:它的具体计算是随机抽取 对支持集样本的来计算嵌入损失:其中  是辅助标签:其实也就是个二分类,属于所有的支持集样本嵌入做距离计算后经过映射或  函数转化为概率,就成为二分类问题。每次任务损失反向传播得到其损失梯度信息:对函数 每次任务损失反向传播得到其梯度信息  ,通过快权值生成函数 的映射得到快权值  :2. 快权值的生成:对每个支持集样本数据输入到Base learner函数 中,之后计算出预测的标签和支持集实际的标签通过交叉熵等损失函数计算  :生成Base learner 的快权值需要支持集的meta information,即利用支持集的损失梯度信息:函数 从损失梯度  中学到一个映射,映射后得到快权值  :

这个快权值  存储在  中。

3. 建立支持集的索引:利用参数为快权值  和慢权值  的表征函数 支持集进行建立索引(有快权值的支持集的嵌入)  :4. 建立训练集的索引:与上一步类似,通过具有慢权值和快权值的表征函数 训练集建立查询索引(对训练集的嵌入):5. 对快权值的读取:如果参数  存储在  中且索引  已经建立,用attention(这里的attention用余弦相似度计算存储索引和输入索引)在之前建立的所有支持集的索引  和每一个训练集的索引计算一个相似分数:然后经过归一化后用于读取存储  得到最终的快权值:6. 训练集标签的预测:Base learner函数 有了慢权值  和快权值  后那么执行one-shot分类为:这里的  是对  的预测输出,另外这里的输入也可以用训练集的嵌入  代替。最终训练集损失的计算:整个网络的训练参数是  ,通过像反向传播算法去最小化任务损失 。

MetaNet的训练算法如图所述:

Meta Network||论文笔记_第4张图片

MetaNet论文中的算法

论文在Omniglot、Mini-ImageNet 和 MNIST 三种数据集上做了One-Shot实验,实验结果都不错,具体可以看一下论文。

总结:该模型利用损失梯度作为元信息来计算快权值,能够快速适应新的不同的任务,增强在训练样本少的情况下的学习效果。效果其实也不是很强,有很多可以改善的点,并且具体训练的时候因为生成快权值的神经网络参数较多或用的LSTM这样的网络所以比较慢。

你可能感兴趣的:(深度学习,元学习)