正确率计算

def acc(output, label):
    # output: (batch, num_output) float32 ndarray
    # label: (batch, ) int32 ndarray
    return (output.argmax(axis=1) == label.astype(‘float32’)).mean().asscalar()

在Gluon文档里有这个计算accuracy的函数,就一行看不懂,分析一下。

首先argmax,argmax的意思是返回最大值的坐标。

  1. axis缺省为全局最大(直接用报错,可以np.argmax(a) )
  2. axis = 0 为每列最大
  3. axis = 1为每行最大
x = nd.array(((1,2,3),(3,4,5)))

>>>  x.argmax(axis=1)

[2. 2.]
2 @cpu(0)>

>>> x.argmax(axis=0)

[1. 1. 1.]
3 @cpu(0)>

=======================================================
没学过python诶。为什么函数参数可以这么写
见:http://www.runoob.com/python/python-functions.html 中的关键字参数

=======================================================
所以 output.argmax(axis=1) 返回的是每行最大的index。

然后

output.argmax(axis=1)  ==   label.astype('float32')

是一个0 1 数组

最后计算mean就行了,其中asscalar()把结果变成标量

你可能感兴趣的:(mxnet)