本文借鉴网上已有的center loss tensorflow版本代码,记录自己在解读该代码时所遇到的知识点以及疑问。
由于python知识浅薄,很多地方不懂,不要见笑。
def get_center_loss(features,labels,alpha=0.5,num_class=10):
#处理数据集为MNIST,所以num_class为10
#get feature dimension 这是为了初始化centers 使用get_shape()来取得维度 但是不明白为什么后面跟着[1] 待研究
len_features = features.get_shape()[1]
#initailizer class center 用get_variable函数 定义centers
#因为center loss里的center是根据公式计算更新而不是梯度下降更新,所以属性trainable为false
centers = tf.get_variable('centers',[num_class,len_features],dtype=tf.flloat32,initializer = tf.constant_initializer(0),trainable = False)
#为了节省计算量,center loss的中心更新都是在mini batch内进行的,所以需要获得mini batch内的centers
#center的获得是通过tf.gather函数来获取。该函数以labels作为index,从以初始化的centers中抽取出minibatch的centers。(tf.gather(params,indexs),根据indexs讲符合indexs的数据从param中抽取出来)
labels = tf.reshape(label,[-1])
centers = tf.gather(centers,labels)
#compute loss 根据公式求得loss 但是不明白为什么这里要沿着axis=-1来取平均值。公式里不是乘了二分之一么?
loss = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(features,centers_batch),axis = -1))
#update centers 接下来是center的更新 按照公式里给的,一步一步求
diff = centers_batch - features
# compute delta c
#获取同一样本类别出现的次数 unique_with_count(labels)这个函数返回的unique_label
是指labels中所有出现过的样本,idx表示unique_label在labels中出现的位置,counts则表示unique_label在labels中出现的次数。
unique_label,idx,count = tf.unique_with_counts(labels)
appear_times = tf.gather(count,idx)
appear_time = tf.reshape(appears_time,shape[-1,1]) #不是很能理解shape[-1,1] 需要看看python的基础知识
#前面的准备工作终于结束,重点来了,根据公式开始计算delta c tf.cast()函数是将函数的输入进行数据类型转换 这里转换为float32
diff = diff /tf.cast((1+appear_times),dtype=tf.float32)
#以上delta 计算完毕,接下来计算中心点的更新
diff = alpha*diff
#update center
#centers[labels[i]]-delta diff[i]
#tf.scatter_sub()在特定位置进行减法,这里是将labels作为索引,把centers和计算后的diff对应逐一相减
center_update_op = tf.scatter_sub(centers,labels,diff)
return loss,centers,center_update_op