最近在研究Focal Loss的Keras实现过程中,由于要实现类似交叉熵的函数,需要将label图像中class=1和class=0的位置拿出来,以寻找在该位置对应的CNN预测值。经过百度查找,发现有人使用tf.where
来实现这个功能,但看官方文档看来好久(下午+晚上)才看明白这个函数是如何使用的,特在此记录下。
之后会专门写一篇有关Focal Loss的Keras实现的博客
TensorFlow中文社区tf.where
tf.where(input, name=None)`
Returns locations of true values in a boolean tensor.
This operation returns the coordinates of true elements in input. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.
For example:
# 'input' tensor is [[True, False] # [True, False]]
# 'input' has two true values, so output has two coordinates.
# 'input' has rank of 2, so coordinates have two indices.
where(input) ==> [[0, 0], [1, 0]]
# `input` tensor is [[[True, False] # [True, False]]
# [[False, True] # [False, True]]
# [[False, False] # [False, True]]]
# 'input' has 5 true values, so output has 5 coordinates.
# 'input' has rank of 3, so coordinates have three indices.
where(input) ==> [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1], [2, 1, 1]]
官方API的解释是返回输入矩阵中true
的位置,但看到第2个例子就有点看不懂了。
返回tensor的第一个维度为input中true的数量,这个比较好理解。第二个维度为矩阵的阶,这里我在看的时候出现了理解误差,我以为是矩阵数学里面的阶,后来才发现Tensorflow中的阶是指矩阵的维数,所以第二个例子的返回值是3列(input的维数为3)。
再看返回tensor的值:
对于第一个例子来说,返回的刚好就是true所在的位置(二维坐标),对于第二个例子就不怎么好理解了。经过各种实验,发现是这样解读的,对于返回值的第一个[0,0,0]表示在第一个维度(编号从0开始,input第一个维度为[[true,false],[true,false]],注意,这是个二维矩阵),true的坐标为[0,0],第三条数据表示在第二个维度且坐标为[0,1]的位置为true。
折腾了这么久,终于搞明白这个函数了!文档写的有点模糊。
官方文档中只有tf.where(input, name=None)
一种用法,在实际应用中发现了另外一种使用方法tf.where(input, a,b)
,其中a,b
均为尺寸一致的tensor
,作用是将a中对应input中true的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值,下面使用代码来说明:
import tensorflow as tf
import numpy as np
sess=tf.Session()
a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])
print(sess.run(tf.equal(a,1)))
print(sess.run(tf.where(tf.equal(a,1),a1,1-a1)))
print(sess.run(tf.where(tf.equal(a,0),a1,1-a1)))
对比两行代码的不同可以发现该函数的作用。
不同之处为tf.equal(a,0)和tf.equal(a,1)
tf.equal()返回tensor中满足条件的位置
参考资料:
tensorflow 关于张量 shape 数组
Tensorflow一些常用基本概念与函数