centerloss tensorflow代码分析以及疑点

本文借鉴网上已有的center loss tensorflow版本代码,记录自己在解读该代码时所遇到的知识点以及疑问。

由于python知识浅薄,很多地方不懂,不要见笑。

def get_center_loss(features,labels,alpha=0.5,num_class=10):

#处理数据集为MNIST,所以num_class为10

#get feature dimension 这是为了初始化centers 使用get_shape()来取得维度 但是不明白为什么后面跟着[1] 待研究

len_features = features.get_shape()[1]

#initailizer class center 用get_variable函数 定义centers

#因为center loss里的center是根据公式计算更新而不是梯度下降更新,所以属性trainable为false

centers = tf.get_variable('centers',[num_class,len_features],dtype=tf.flloat32,initializer = tf.constant_initializer(0),trainable = False)

#为了节省计算量,center loss的中心更新都是在mini batch内进行的,所以需要获得mini batch内的centers

#center的获得是通过tf.gather函数来获取。该函数以labels作为index,从以初始化的centers中抽取出minibatch的centers。(tf.gather(params,indexs),根据indexs讲符合indexs的数据从param中抽取出来

labels = tf.reshape(label,[-1])

centers = tf.gather(centers,labels)

#compute loss 根据公式求得loss 但是不明白为什么这里要沿着axis=-1来取平均值。公式里不是乘了二分之一么?

loss = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(features,centers_batch),axis = -1))

#update centers 接下来是center的更新 按照公式里给的,一步一步求

diff = centers_batch - features

# compute delta c

#获取同一样本类别出现的次数  unique_with_count(labels)这个函数返回的unique_label

是指labels中所有出现过的样本,idx表示unique_label在labels中出现的位置,counts则表示unique_label在labels中出现的次数。

unique_label,idx,count = tf.unique_with_counts(labels)

appear_times = tf.gather(count,idx)

appear_time = tf.reshape(appears_time,shape[-1,1])      #不是很能理解shape[-1,1] 需要看看python的基础知识

#前面的准备工作终于结束,重点来了,根据公式开始计算delta c  tf.cast()函数是将函数的输入进行数据类型转换 这里转换为float32

diff = diff /tf.cast((1+appear_times),dtype=tf.float32)

#以上delta 计算完毕,接下来计算中心点的更新

diff = alpha*diff

#update center

#centers[labels[i]]-delta diff[i]

#tf.scatter_sub()在特定位置进行减法,这里是将labels作为索引,把centers和计算后的diff对应逐一相减

center_update_op = tf.scatter_sub(centers,labels,diff)

return loss,centers,center_update_op

你可能感兴趣的:(centerloss tensorflow代码分析以及疑点)