centerloss之中心点计算

以minist为例:
中心点有2种计算方式:
方式1:
人算:
批次取400,平均每批次每个数字有400/10=40个点,如果取40个点的中心为每批次训练的中心,数据太少,误差大, 所以先定义一个队列,队列长度为400*20,前20个周期不训练中心损失,只训练分类损失,把每批次网络计算的点(形状为【400,2】)装进队列,从第20个周期开始,计算队列里对应数字的中心,然后以此中心计算中心损失,因为队列先进先出的特性,随着训练进行,新数据进入队列,早期数据将被抛弃。
这种方式比较麻烦,实现复杂,且训练慢。
方式2:
网络自己算:
自定义centerloss损失函数类,类的可训练参数采用正态分布初始化,此方式如果不自定义backword函数,则训练此网络有2个优化器,一个优化网络参数,另一个优化损失函数类实例参数,定义backword函数需要自己计算梯度,比较麻烦。
启示:
损失函数也可带有可训练参数,prelu就是其中之一。

你可能感兴趣的:(PythonLearn,深度学习)