元学习基础理解

元学习基础理解

  • 简介
    • 基础概念
    • 损失函数
    • 总体框架
  • MAML方法
    • MAML好的原因
    • 公式简单说明
    • 更新方式
  • RNN学习梯度更新
  • 小结

简介

最近看了李宏毅老师的元学习课程,在这里记录一下学习内容。元学习的目标可以定义为,如何通过学习大量任务,以实现处理新任务时可以完成快速学习。

基础概念

元学习的基础框架如下图所示,基本概念可以理解为:
(1) 在深度学习中,我们的目标是拟合学习一个分类器或回归器f,也就是先设定好神经网络模型架构,然后不断通过反向传播训练f直到收敛。在学习过程中主要使用数据集中的训练集进行训练,并使用测试集进行评估。
(2)在元学习中,我们的主要思想是学习如何去学习,也就是我们要学习一个F,F可以获得最终学习器f。训练时使用的是多个训练任务,训练任务也包括训练集和测试集两部分。在训练时,首先会使用训练集学习元学习器F,再使用测试集训练分类器f。在测试时会使用新的测试任务,使用测试任务中的少量训练集进行快速学习。并使用测试集进行评估。其中学习的具体目标有很多种,具体学习内容包括神经网络的初始化(MAML);优化器的选择(SGD, ADAM);网络架构的搜索等(神经网络层数,卷积核长宽高)。
元学习基础理解_第1张图片
元学习基础理解_第2张图片

损失函数

损失函数:使用训练任务的测试集得到的损失作为损失,将多个任务的损失求和。
元学习基础理解_第3张图片

总体框架

元学习基础理解_第4张图片

MAML方法

MAML好的原因

在元学习中,最经典的一项工作就是MAML。MAML时用来学习如何初始化网络的初始化的参数的。也就是希望通过元学习学习到一种初始化方法,使得这种初始化参数再新的任务中可以仅使用few-shot(少量数据集的训练)就可以完成任务。
MAML的方法取得了很好的效果,至于MAML为什么会取得好的效果,李宏毅老师说了两种可能:1. 可以帮助网络快速完成新的任务的拟合。2. 已经在多任务上训练过的网络已经较为普适。他说是第二种可能起到了作用。
元学习基础理解_第5张图片

公式简单说明

在MAML中,我们主要希望学习 ϕ \phi ϕ来学习如何学习,也就是使用左上角第一个式子来对 ϕ \phi ϕ使用梯度下降更新。而其中 ϕ \phi ϕ即是网络初始化的参数。其中的损失函数是各个训练任务的测试集计算得到的损失的和。而在更新我们最终需要的网络参数时,我们会在初始化结束的参数 ϕ \phi ϕ基础上,使用新任务进行更新完成训练。
元学习基础理解_第6张图片

更新方式

MAML进行更新时首先对于第一个任务采集m个数据样本进行训练一次得到 θ m \theta_{m} θm,在 θ m \theta_{m} θm的基础上再进行一次更新获得新的更新方向(我理解这个新的更新方向指的就是梯度的梯度)。使用这个新的更新方向在原有元初始化参数 ϕ 0 \phi_{0} ϕ0的基础上进行更新,得到 ϕ 1 \phi_{1} ϕ1。之后对于不同的任务依次进行循环训练。
元学习基础理解_第7张图片

RNN学习梯度更新

除了MAML外,李宏毅老师还讲了一种使用LSTM来学习梯度的算法。该算法的初衷在于进行梯度更新时,我们的标准化系数和学习率步长 η \eta η都是固定的,但如果可以在学习时自适应的学习就好了。
首先看一下LSTM的模型如下,包括细胞c用来保存长期信息(更新变化小),隐藏层h用老保存短期信息(更新变化大)和输出y。
元学习基础理解_第8张图片
在这里我们发现细胞c的更新方式与梯度下降是非常像的。因此可以把LSTM变成下图的形式,就可以在学习 LSTM变体的过程中自适应的学习规范化参数和学习率步长了。

元学习基础理解_第9张图片

小结

因此我们可以知道元学习本质就是学习如何学习的过程。这里看到的有学习网络初始化和梯度下降的方法。

你可能感兴趣的:(元强化学习,元学习)