logits是网络的输出,logits.shape=(batch_size, w, h, 21),21类语义标签。
pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)#shape=(?, ?, ?, 1)
我们用numpy解释argmax和expand_dims这两个函数:
import numpy as np
#(2, 2, 4, 3)
x = np.array([[[[31,20,10],
[20,43,30],
[40,10,62],
[40,60,76]],
[[10,72,20],
[81,30,40],
[97,50,70],
[40,50,68]]],
[[[10,22,10],
[20,30,81],
[40,10,62],
[40,65,30]],
[[10,72,20],
[81,30,40],
[97,50,70],
[40,50,68]]]])
#(2, 2, 4)
y1 = np.argmax(x, axis=3)
'''[[[0 1 2 2]
[1 0 0 2]]
[[1 2 2 1]
[1 0 0 2]]]'''
#(2, 2, 4, 1)
y2 = np.expand_dims(y1, axis=3)
'''[[[[0]
[1]
[2]
[2]]
[[1]
[0]
[0]
[2]]]
[[[1]
[2]
[2]
[1]]
[[1]
[0]
[0]
[2]]]]'''
y2的值是每个一维列表的最大值的下标,如第一个值为0,是因为[31,20,10]中最大元素31的下标是0。
batch_size为1时的网络输出:
[[[31,20,10],
[20,43,30],
[40,10,62],
[40,60,76]],
[[10,72,20],
[81,30,40],
[97,50,70],
[40,50,68]]]
注意图像大小是2×4,而不是4×3或3×4。所以[31,20,10]是第一个像素被归类为第0类、第1类、第2类的概率。因为31最大,所以该像素的语义标签被归类为0。这样就可以解释y2:
[[[0]
[1]
[2]
[2]]
[[1]
[0]
[0]
[2]]]
batch_size=1时,它是指一个2×4的图像的第(0,0)个像素标签为0、第(0,1)个像素标签为1、... 、第(1,3)个像素标签为2。