Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)研读笔记

这里是引用

MAML全文目录

  • 论文地址
  • 摘要
  • 介绍
  • 相关概念
    • model-agnostic
    • N-way K-shot
    • Task
    • 5-way 5-shot的实验设置
  • 算法流程
  • fine-tune算法流程
  • 参考文献

论文地址

https://arxiv.org/abs/1703.03400

摘要

  • 一种与模型无关的元学习算法
  • 适用于适用梯度下降更新训练的模型,分类、回归、强化学习等
  • 元学习的目标:在大量不同的任务上预先训练出一个模型,用这个模型可以在新任务的少量样本上进行训练。预训练,新任务小样本
  • 本文只需要经过几步梯度更新的微调就可以取得不错的效果
  • 效果:在小样本的图像分类成绩最好,小样本回归问题和强化学习策略梯度中也取得了好的成绩

介绍

  • 元学习阶段,使用全部任务做训练样本
  • 没有扩展模型参数,也没有限制模型结构
  • 学习过程是使新任务的损失函数对参数的敏感度最大化,当敏感度高时,参数微小的变化就可以使损失函数大幅度变化
  • 与最先进的小样本学习模型相比,模型使用了更少的参数。回归任务中,在任务可变性下加速强化学习,大大优于初始化的直接预训练
  • 训练过程:先从任务中抽样一个任务 T i T_i Ti,然后从 q i q_i qi中抽取 K K K个样本训练模型,然后得到 T i T_i Ti的损失函数 L T i L_{T_i} LTi,然后在新样本上进程测试,通过 q i q_i qi中抽样的新数据的测试误差随参数的变化情况,来提高模型的性能。
  • 测试误差和训练误差:在所有抽样任务 T i T_i Ti上的测试误差构成了元学习训练阶段的训练误差。
  • 评估阶段(最后阶段):从任务分布 P ( T ) P(T) P(T)中抽样一些新任务,每个任务具有K个样本,通过 K K K个样本的学习,来作为最后的模型评估
  • 元测试的任务在元训练期间
  • 创新:例如,神经网络可能学习广泛适用于 P ( T ) P(T) P(T)中所有任务的内部特征,而不是单个任务的特征表示
  • 方法:就是为了让模型 f f f在抽样的任务上快速适应,同时不产生过拟合。就是去找一组对任务中变化比较敏感的参数
  • 梯度更新方法: θ i ′ = θ − α ∇ θ L T i ( f θ ) \theta_i^\prime=\theta-\alpha\nabla_\theta L_{T_i}(f_\theta) θi=θαθLTi(fθ)
  • f ( θ ) f(\theta) f(θ)是由参数 θ \theta θ表示的模型, L T i ( f θ ) L_{T_i}(f_\theta) LTi(fθ)是任务 T i T_i Ti的损失函数, θ i ′ \theta_i^\prime θi是是在 T i T_i Ti经过一次或者多次梯度下降更新得到的参数更新
  • 优化方法:
    • 模型的参数通过优化所有任务上的 f θ ′ f_\theta^{'} fθ来进行更新, min ⁡ θ ∑ T i ∼ p ( T ) L T i ( f θ ′ ) = ∑ T i ∼ p ( T ) α ∇ θ L T i ( f θ ) L T i ( f θ − α ∇ θ L T i ( f θ ) ) \min_\theta\sum_{T_i \sim p(T)}L_{T_i}(f_\theta^{'})=\sum_{T_i \sim p(T)}\alpha\nabla_\theta L_{T_i}(f_\theta)L_{T_i}(f_{\theta-\alpha\nabla_\theta L_{T_i}(f_\theta)}) θminTip(T)LTi(fθ)=Tip(T)αθLTi(fθ)LTi(fθαθLTi(fθ))
      – 元学习阶段的优化是在模型参数 θ \theta θ上进行的,而上述目标是使用更新过的 θ ′ \theta^{'} θ得到的,提出的方法要在新任务上通过一次或几次梯度更新来优化模型参数
    • meta的优化,通过梯度下降来进行更新的: θ ← θ − β ∇ θ ∑ T i ∼ p ( T ) L T i ( f θ ) \theta \gets \theta-\beta\nabla_\theta\sum_{T_i \sim p(T)} L_{T_i}(f_\theta) θθβθTip(T)LTi(fθ)

相关概念

model-agnostic

  • model-agnostic:
    • model-agnostic即模型无关。
    • MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner用于训练base-learner
    • meta-learner即MAML的精髓
    • base-learner则是在目标数据集上被训练,并实际用于预测任务的真正的数学模型。
    • 大多数深度学习模型都可以作为base-learner嵌入MAML
    • meta-learner → \rightarrow base-learner

N-way K-shot

  • N-way K-shot:
    • 是few-shot learning(小样本学习)中常用的实验设置。小样本学习指利用很少的被标记数据训练数学模型的过程(MAML擅长的)
    • N-way指训练数据中有N个类别
    • K-shot指每个类别下有K个被标记数据

Task

  • Task:
    • 假设一个场景:我们需要利用MAML训练一个数学模型 M f i n e − t u n e M_{fine-tune} Mfinetune(fine-tune为微调),目的是对未知标签的图片做分类,类别包括 P 1 ∼ P 5 P_1 \sim P_5 P1P5每个类别有5个已标注样本用于训练。另外每个类别有15个已标注样本用于测试,一共100个已标注样本)。我们的训练数据除了 P 1 ∼ P 5 P_1 \sim P_5 P1P5 中已标注的样本外,还包括另外10个类别的图片 C 1 ∼ C 10 C_1 \sim C_{10} C1C10(每种类别有30个已标注样本,一共300个已标注),用于帮助训练元学习模型 M m e t a M_{meta} Mmeta 。我们的实验设置为5-way 5-shot也就是 C 1 ∼ C 10 C_1 \sim C_{10} C1C10的随机抽取的5个类别中的样本是先来训练元学习模型 M m e t a M_{meta} Mmeta
    • 训练过程大概为:
      • MAML首先利用 C 1 ∼ C 10 C_1 \sim C_{10} C1C10 的数据集训练元模型 M m e t a M_{meta} Mmeta,再在 P 1 ∼ P 5 P_1 \sim P_5 P1P5的数据集上精调(fine-tune)得到最终的模型 M f i n e − t u n e M_{fine-tune} Mfinetune ,下面的算法流程主要就是这部分
    • C 1 ∼ C 10 C_1 \sim C_{10} C1C10meta-train classes C 1 ∼ C 10 C_1 \sim C_{10} C1C10 包含的共计300个样本,即 D m e t a − t r a i n D_{meta-train} Dmetatrain ,是用于训练 M m e t a M_{meta} Mmeta的数据集;
    • 与之相对的, P 1 ∼ P 5 P_1 \sim P_5 P1P5meta-test classes P 1 ∼ P 5 P_1 \sim P_5 P1P5 包含的共计100个样本,即 D m e t a − t e s t D_{meta-test} Dmetatest ,是用于训练和测试 M f i n e − t u n e M_{fine-tune} Mfinetune 的数据集

5-way 5-shot的实验设置

  • 5-way 5-shot的实验设置:
    • task T,相当于普通深度学习模型训练过程中的一条训练数据:
      • 在训练 M m e t a M_{meta} Mmeta阶段,从 C 1 ∼ C 10 C_1 \sim C_{10} C1C10中随机取5个类别,每个类别再随机取20个已标注样本,组成一个task T
      • 每个类别随机取20个已标注的样本中,其中的5个已标注的样本称为Tsupport set,另外15个样本称为Tquery set
      • 这个task T,相当于普通深度学习模型训练过程中的一条训练数据。类似于SGD,要搞batch,即反复从训练数据分布中抽取若干(分布解释看李沐的)个这样的task T,组成一个batch
    • 训练 M f i n e − t u n e M_{fine-tune} Mfinetune阶段,task、support set、query set的含义和训练 M m e t a M_{meta} Mmeta阶段相同

算法流程

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks(MAML)研读笔记_第1张图片

  • MAML预训练阶段的算法的目的:得到模型 M m e t a M_{meta} Mmeta
  • 第一行的Require:
    • 指的是在 D m e t a − t r a i n D_{meta-train} Dmetatraintask的分布。这里就是反复随机抽取的task T,形成若干个T(例如抽取1000个)组成的task池,作为MAML的训练集
    • 这里有一个问题:训练样本 D m e t a − t r a i n D_{meta-train} Dmetatrain数量有限,要组合形成那么多的task,岂不是不同task之间会存在样本的重复?或者某些taskquery set会成为其他tasksupport set
      • 答:MAML的目的,就在于fast adaptation,即通过对大量task的学习,获得足够强的泛化能力,从而面对新的、从未见过的task时,通过fine-tune就可以快速拟合。task之间,只要存在一定的差异即可
    • MAML的训练是基于task的,而这里的每个task就相当于普通深度学习模型训练过程中的一条训练数据
  • 第二行的require:
    • α , β \alpha,\beta α,βstep size其实就是学习率
    • MAML是基于二重梯度(gradient by gradient)(下面有讲什么是二重梯度),每次迭代包括两次参数更新的过程,所以有两个学习率需要调整
  • 步骤1:
    • 随机初始化模型的参数 θ \theta θ
  • 步骤2:
    • 是一个循环,可以理解为一轮迭代过程或一个epoch,预训练的过程是可以有多个epoch的
  • 步骤3:
    • 从分布中随机采样若干数量的task(例如5个),形成一个batch
  • 步骤4-7:
    • 是第一次梯度更新的过程
    • copy了原模型,然后在copy的模型上计算出新的参数,用于第二个梯度的计算过程中
    • 每一个task更新一次参数(5个task就是更新5次),就是一个batch结束,这个在算法中可以反复执行多次
  • 步骤5:
    • 使用batch中某个task中的support set(一个task就是k个,例如5w5k实验设置中的就是5个已标注的样本),来计算参数的梯度,总的support setN*K个(5w5k就是25个,取5个类别,从 C 1 ∼ C 10 C_1 \sim C_{10} C1C10中随机取5个类别,再从每个类别的20个已标注样本中取5个
    • Loss方法,回归任务,就是MSE;分类任务,就是cross-entropy(交叉熵)(李沐视频有讲什么是交叉熵)
  • 步骤6:
    • 第一次参数梯度的更新
  • 步骤4-7完,MAML完成了第一次梯度更新。根据第一次梯度更新得到的参数,计算第二次梯度更新(步骤8)。第二次的梯度更新时计算出的梯度,直接通过SGD作用在原模型(下面说)上,也就是模型真正用于更新其参数的梯度。也就是第一次梯度更新是为了第二次梯度更新,第二次梯度才是真正更新模型参数。
    • 二重梯度:
      • 原模型 θ a \theta_a θa,先复制一份原模型, θ a → c o p y θ b \theta_a \rightarrow_{copy} \theta_b θacopyθb,得到 θ b \theta_b θb
      • θ b \theta_b θb上,做反向传播(这个要搞懂什么意思)及更新参数,得到第一次梯度更新的结果 θ b ′ \theta_b^{'} θb
      • θ b ′ \theta_b^{'} θb上,计算第二次梯度更新,计算出来,不更新 θ b ′ \theta_b^{'} θb,更新原模型 θ a \theta_a θa
      • 这里需要理解是copy的模型是每一个task都会copy一份,如10个task会copy10个临时模型,在10个临时模型上,在各自的task上独立更新一个梯度(步骤4-7),然后整合起来用于步骤8,也就是更新原模型
      • 这样做,就是因为每一个task都会更新一次参数,用原模型,会导致使用上一个task的更新过的参数
      • 从原模型的角度来看,只进行了一次梯度更新(步骤8),但是第二次梯度更新(步骤8)依赖于第一次(步骤4-7)
      • 总结:第一次梯度,不作用于原模型,第二次梯度用于原模型
  • 步骤8:
    • 第二次梯度更新的过程
    • 与步骤7不同处:
      • 1.不是分别利用每个task的Loss更新梯度,直接和常用的模型训练一样,计算一个batch的loss和,对梯度进行随机梯度下降SGD
      • 2.这里的样本是taskquery set(如5w5k中 15*5个样本),是为了增强模型在task上的泛化,避免过拟合support set
      • 该步骤结束后,即完成在当前batch中的训练,回到步骤3,采样下一个batch

fine-tune算法流程

  • 完成以上步骤,就是MAML预训练得到 M m e t a M_{meta} Mmeta的全部过程
  • 接下来要完成的就是面对新的task,在 M m e t a M_{meta} Mmeta的基础上,精调得到 M f i n e − t u n e M_{fine-tune} Mfinetune
    • 步骤1中,fine-tune不再随机初始化参数,而是利用训练好的 M m e t a M_{meta} Mmeta初始化参数
    • 步骤3中,fine-tune只需要抽取一个task进行学习,也不用形成batch,fine-tune利用这个tasksupport set训练模型,利用query set测试模型。
      • 实际操作时,会在 D m e t a − t e s t D_{meta-test} Dmetatest上随机抽取许多个task(例如500个),分别微调模型 M m e t a M_{meta} Mmeta,并对最后的测试结果进行平均,从而避免极端情况
    • 没有步骤8,因为taskquery set是用来测试模型的,标签对模型是未知的。因此这个过程没有第二次梯度更新,直接用第一次梯度计算的结果更新参数

参考文献

1.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 作者:徐不知
2.[meta-learning] 对MAML的深度解析 作者:周威
3.MAML 论文及代码阅读笔记 作者:Rust-in
4.MAML原论文

你可能感兴趣的:(Meta-learning,元学习,#,论文研读,深度学习,人工智能,算法,python)