TensorFlow实现center loss

TensorFlow实现center loss:Center loss是ECCV2016中一篇论文提出来的概念,主要思想就是在softmax loss基础上额外加入一个正则项,让网络中每一类样本的特征向量都能够尽量聚在一起。

具体的原理推导等请参考论文,论文作者放出了Caffe实现,网上还能找到mxnet的实现,这里我放出一个TensorFlow版的实现及详细注释,代码很短,如下:


def get_center_loss(features, labels, alpha, num_classes):
     # alpha:中心的更新比例
     # 获取特征长度
     len_features = features.get_shape()[ 1 ]
     # 建立一个变量,存储每一类的中心,不训练 9
     centers = tf.get_variable( 'centers' , [num_classes, len_features], dtype=tf.float32,
         initializer=tf.constant_initializer( 0 ), trainable=False)
     # 将特征reshape成一维
     labels = tf.reshape(labels, [- 1 ])
 
     # 获取当前batch每个样本对应的中心
     centers_batch = tf.gather(centers, labels)
     # 计算center loss的数值
     loss = tf.nn.l2_loss(features - centers_batch)
 
     # 以下为更新中心的步骤
     diff = centers_batch - features
 
     # 获取一个batch中同一样本出现的次数,这里需要理解论文中的更新公式
     unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
     appear_times = tf.gather(unique_count, unique_idx)
     appear_times = tf.reshape(appear_times, [- 1 , 1 ])
 
     diff = diff / tf.cast(( 1 + appear_times), tf.float32)
     diff = alpha * diff
     # 更新中心
     centers = tf.scatter_sub(centers, labels, diff)
 
     return loss, centers


center loss代码注释(caffe新添加层)
http://blog.csdn.net/liyuan123zhouhui/article/details/60139981


你可能感兴趣的:(深度学习)