[解读] GTN: Generative Teaching Networks

链接: https://arxiv.org/abs/1912.07768v1

参考:

https://www.leiphone.com/news/201912/FBZsLSCZSgyD5fIq.html

https://cloud.tencent.com/developer/news/492236

Generative Teaching Network (GTN), 它可以生成数据和训练环境, 让模型在进行目标任务之前先进行一些 SGD 步骤训练, 而使用 GTN 在合成数据训练的神经网络比利用真实数据训练得更快. 在利用 GTN-NAS 搜索新的神经网络架构时, 速度比使用实际数据快9倍, 而且使用的计算量比典型NAS方法少几个数量级.

相关的工作

生成式对抗网络, 它可以生成高分辨率的图像, 但它被训练来模仿真实数据, 而不是学会学习. 数据集蒸馏 (Dataset distillation [1]) 方法利用元学习的思想, 通过在真实数据集上进行测试, 来优化随机生成的数据. 这与本文的工作很接近, 不同点在于本文的合成数据是使用一个生成器来生成的, 好处是能够进一步扩展此方法, 因为合成的数据之间具有某种规则性.

方法

GTN 的核心思想是训练一个数据生成网络, 生成的数据用于学习器的学习, 使得学习器在目标任务上能够快速的达到较好的效果.

[解读] GTN: Generative Teaching Networks_第1张图片

上图是 GTN 结构图. 内循环由生成器 G ( z , y ) G(\mathbf{z},\mathbf{y}) G(z,y) 和学习器组成. z \mathbf{z} z y \mathbf{y} y 分别是高斯噪声和样本标签, 做为生成器的输入. 生成器输出合成数据 x x x, 学习器在合成数据上进行训练, 使用梯度下降算法优化参数. 生成器也可以只将 z z z 做为输入, 输出合成数据和标签.

在外层循环中, 学习器在真实的训练数据上进行评估, 得到外层循环W的元训练损失. 这个损失用来计算元训练梯度, 反馈到内部循环中, 从而优化生成器. 具体来说, 云损失 一方面用来计算内循环超参数的梯度, 另一方面计算生成器的参数梯度.

为了简洁表达算法的思想, 下面自制的伪代码是文中 All shuffled 模式的.

[解读] GTN: Generative Teaching Networks_第2张图片

其中 g g g 是生成器, 其参数为 η \boldsymbol{\eta} η, 其中 f f f 是学习器, 其参数为 θ \boldsymbol{\theta} θ, L inner , L outer \mathcal{L}_{\text{inner}} ,\mathcal{L}_{\text{outer}} Linner,Louter 分别为内循环和外循环的损失, L ( ⋅ ) \mathcal{L}(\cdot) L() 是一个损失函数, x , y \mathbf{x}, \mathbf{y} x,y 分别是真实的数据和标签.

补充

元训练学习: https://lilianweng.github.io/lil-log/2018/11/30/meta-learning.html

课程学习: https://blog.csdn.net/qq_25011449/article/details/82914803

参考

[1] Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A. Efros. Dataset distillation, 2019b.

本人才疏学浅, 如有遗漏或错误之处, 请多多指教!

你可能感兴趣的:(机器学习,深度学习)