Keras对多维Tensor的argmax()解析

基础理论

argmax中的axis参数表示在该维度上比较各元素。并且,张量各维度对换,不影响在该维度取argmax()的结果。

a = tf.constant([[[1, 2, 3], [3, 2, 2]], [[10, 11, 12], [4, 5, 6]]])  # a是个2*2*3的tensor
b = tf.argmax(a, axis=1, output_type=tf.int32)
at = tf.transpose(a, [0, 2, 1])  # 将DIM1和DIM2对换,at变成了2*3*2
c = tf.argmax(at, axis=2, output_type=tf.int32)

with tf.Session() as sess:
    print(sess.run(b))
    print(sess.run(c))
print("")

输出结果

[[1 0 0]
 [0 0 0]]
[[1 0 0]
 [0 0 0]]

tf.argmax(a, axis=1)相当于是在a的DIM1上比较,也就是1和3,2和2,3和2,以及10和4,11和5,12和6比较。如果改成tf.argmax(a, axis=0),相当于是a在DIM0上比较,也就是1和10,2和11,3和12,以此类推。

应用场景

比如,目前有分子特征张量input,维度为SampNum × AtomNum × FeatNum,那么,argmax(input, axis=1)将得到维度为SampNum × FeatNum的Tensor,其元素表示各样本分子的各种向量值表征、同种向量的最大者所对应的原子id。
同样的,再来一个,argmax(input, axis=2)将得到维度为SampNum × AtomNum的Tensor,其元素表示各样本分子的各原子的FeatNum种特征中,最大的特征值所对应的特征id。

你可能感兴趣的:(Deep,Learning)