ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法

背景

大多数深层神经网络(CNN)往往消耗巨大的计算资源和存储空间为了将模型部署到性能受限的设备(如移动设备),通常需要加速网络的压缩现有的一些加速压缩算法,如知识蒸馏等,可以通过训练数据获得有效的结果。然而,在实际应用中,由于隐私、传输等原因,训练数据集通常不可用因此,作者提出了一种不需要原始训练数据的模型压缩方法。

原理

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第1张图片点击添加图片描述(最多60个字)

上图是本文提出的总体结构通过一个给定的待压缩网络(教师网络),作者训练一个生成器生成与原始训练集分布相似的数据然后,利用生成的数据,基于知识提取算法对学生网络进行训练,从而实现无数据的模型压缩。

那么,在没有数据的情况下,如何在给定的教师网络上训练一个可靠的生成器呢作者提出了以下三个损失来指导发电机的学习。

(1)在图像分类任务中,对于真实数据,网络的输出往往接近一个热向量其中,分类类别的输出接近于1,其他类别的输出接近于零因此,如果生成器生成的图像接近真实数据,那么它在教师网络上的输出应该类似于一个热向量因此,作者提出了一个 One-hotloss:

其中YT是通过教师网络生成的图片的输出,T是伪标签,并且由于生成的图片不具有标签,所以作者将YT中的最大值设置为伪标签。Hcross表示交叉熵函数。

(2)另外,在神经网络中,输入真实数据往往比输入的随机噪声在特征图上有更大的响应值因此,作者建议激活损失约束生成的数据:

其中fT表示通过教师网络提取生成的数据的特征,||·||1表示|1范数。

(3)此外,为了使网络得到更好的训练,训练数据往往需要类别平衡因此,为了平衡同一类别中生成的数据,引入信息熵损失来度量类别平衡度:

其中,Hinfo表示信息熵,yT表示每张图片的输出如果信息熵较大,则对输入的图片集中的每个类别的平均数进行平均,从而确保生成的图片类别的平均数。

最后,结合以上三个损耗函数,可以得到发电机培训使用的损耗:

通过优化上述损失,您可以训练生成器,然后通过生成器生成的样本执行知识蒸馏在知识提取中,要压缩的网络(教师网络)通常具有较高的精度,但存在冗余参数学生网络是一个轻量级设计和随机初始化网络利用教师网络的输出来指导学生网络的输出,可以提高学生网络的精度,达到模型压缩的目的这个过程可以用以下公式表示:

其中,ys和yt分别表示学生网络和教师网络的输出,Hcross表示交叉熵函数。

算法1表示项目方法的流程首先,通过优化上述损耗,获得与原始数据集具有相似分布的发生器其次,通过生成器生成的图像,将教师网络的输出通过知识蒸馏迁移到学生网络中学生网络的参数较少,支持无数据压缩方法。

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第2张图片点击添加图片描述(最多60个字)

结果

MNIST数据集上的分类结果。

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第3张图片点击添加图片描述(最多60个字)

所提出的无数据学习方法的不同组成部分的有效性。

CIFAR数据集上的分类结果。

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第4张图片点击添加图片描述(最多60个字)

CelebA数据集上的分类结果

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第5张图片点击添加图片描述(最多60个字)

在各种数据集上的分类结果。

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第6张图片点击添加图片描述(最多60个字)

 可视化每个类别中的平均图像(从0至9)

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第7张图片点击添加图片描述(最多60个字)

第一卷积层中过滤器的可视化,在MNIST数据集上学习。第一行显示训练有素的过滤器,使用原始训练数据集,并且底线显示使用通过所提出的方法生成的样本获得的过滤器。

ICCV 2019推荐Pytorch实现一种无需原始训练数据的模型压缩算法_第8张图片点击添加图片描述(最多60个字)

总结

常规方法需要原始训练数据集,用于微调压缩的深度神经网络具有可接受的精度。但是,训练集和给定深度网络的详细架构信息,由于某些隐私和传输限制,通常无法使用。

作者在本文中,我们提出了一个新颖的框架来训练生成器以逼近原始没有训练数据的数据集。然后,一个便携式网络通过知识提炼方案可以有效地学习。

在基准数据集上的实验表明,所提出的方法DAFL方法能够无需任何培训即可学习便携式深度神经网络数据。

相关论文源码下载地址:关注“图像算法”微信公众号 回复“DAFL”

 

你可能感兴趣的:(图像算法)