TensorFlow函数:tf.where()

focal loss中要实现交叉熵函数,需要将label图像中不同class对应的位置取出,经查询需要用到tf.where函数,随即记录如下。

Tensorflow 官方文档介绍

Returns locations of true values in a boolean tensor.This operation returns the coordinates of true elements in input. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.

For example:

# 'input' tensor is [[True, False]
#                    [True, False]]
# 'input' has two true values, so output has two coordinates.
# 'input' has rank of 2, so coordinates have two indices.
where(input) ==> [[0, 0],
                  [1, 0]]

# `input` tensor is [[[True, False]
#                     [True, False]]
#                    [[False, True]
#                     [False, True]]
#                    [[False, False]
#                     [False, True]]]
# 'input' has 5 true values, so output has 5 coordinates.
# 'input' has rank of 3, so coordinates have three indices.
where(input) ==> [[0, 0, 0],
                  [0, 1, 0],
                  [1, 0, 1],
                  [1, 1, 1],
                  [2, 1, 1]]

Args:

  • input: A Tensor of type bool.

  • name: A name for the operation (optional).

Returns:

  • A Tensor of type int64.

用法二

官方文档中只有tf.where(input, name=None)一种用法,在实际应用中发现了另外一种使用方法tf.where(input, a,b),其中a,b均为尺寸一致的tensor,作用是将a中对应inputtrue的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值。举例:

import tensorflow as tf
import numpy as np
sess=tf.Session()

a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])
index = tf.where(a)

print(sess.run(index))

print(sess.run(tf.equal(a,1)))
print(sess.run(tf.where(tf.equal(a,1),a1,1-a1)))

print(sess.run(tf.equal(a,0)))
print(sess.run(tf.where(tf.equal(a,0),a1,1-a1)))

输出分别为

[[0 0]
 [1 1]
 [1 2]]

[[ True False False]
 [False  True  True]]
 
[[ 3 -1 -2]
 [-3  5  6]]
 
[[False  True  True]
 [ True False False]]
 
[[-2  2  3]
 [ 4 -4 -5]]

参考:https://blog.csdn.net/a_a_ron/article/details/79048446

你可能感兴趣的:(Tensorflow)