好记性不如烂笔头。
总结: 博主认为,正是由于Accuracy计算时,采用的是相对最大概率,所以存在计算的loss与accuracy不成正比关系,即当accuracy很高时,有可能存在loss很高,这是由于,我们的目的是:使得最大下标的概率很大,接近于1,但实际是相对最大,导致accuracy很高,loss也很高,如对于三分类,我们期望的预测结果是[1, 0, 0]或者[0.95, 0.02, 0.03],但可能实际得到的是[0.45, 0.25, 0.3],此时accuracy的计算结果不变,但是loss值却变动很大。
改进:因此一般训练时,看loss为主。但在发表的论文中存在以accuracy为准则(如VGG论文)。如果以accuracy为主,可以添加约束条件,例如,对相对最大值进行判断。即如果相对最大值大于0.8,认为正确,否则不正确。
其中, L是单个样本的loss,k为分类的数目。如果batch size的话需要对loss求均值。
python代码为:
def my_categorical_crossentropy(y_true:'list', y_pred:'list'):
'''
return: list for each example
'''
loss = []
for ii in range(len(y_true)):
tem_crossentropy_loss = 0
for jj in range( len(y_true[ii]) ):
tem_crossentropy_loss += (-1 * y_true[ii][jj] * np.log( y_pred[ii][jj] + 1e-10))
loss.append( tem_crossentropy_loss )
return loss
# category cross entropy loss in keras
# return list for each example
# categorical_crossentropy input is Tensor in TensorFlow
from keras.losses import categorical_crossentropy
如果需要计算均值
(注意:keras中 categorical_crossentropy输入是Tensorflow中的Tensor):
loss_list = my_categorical_crossentropy(y_true=y_true, y_pred=y_pred)
loss = np.mean( loss_list )
## or
sess = tf.InteractiveSession()
tf_loss = categorical_crossentropy( y_true=tf.convert_to_tensor(y_true, dtype=tf.float32), y_pred=tf.convert_to_tensor(y_pred, dtype=tf.float32) ) )
loss_list = sess.run( tf_loss )
loss = np.mean( loss_list )
sess.close()
对于一个样本的预测值,其最大概率的下标作为该样本的预测结果,然后与y_true的最大下标进行对比,如果下标相同,则预测正确。
(注意: 博主认为,正是由于该计算方式的原因,所以存在计算的loss与accuracy不成正比关系,即当accuracy很高时,有可能存在loss很高,这是由于目的是:使得最大下标的概率很大,接近于1,但实际是相对最大,导致accuracy很高,loss也很高)
python代码实现:
def my_categorical_accuracy(y_true:'list', y_pred:'list'):
y_ = np.argmax( y_true, axis=-1 )
y = np.argmax(y_pred, axis=-1)
acc = np.equal(y_, y)
acc = np.mean(acc)
return acc
# accuracy in keras
# return list for each example
# categorical_accuracy input is Tensor in TensorFlow
from keras.metrics import categorical_accuracy
通过实验验证上述代码的正确性。
输入:
y_pred = [[0.9, 0.05, 0.05], [0.49, 0.2, 0.31], [0.1, 0.2, 0.7], [0.15, 0.25, 0.6]]
y_true = [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
结果: