TensorFlow(10)——tf.where()、tf.gather()、tf.squeeze()

本文主要介绍在使用TensorFlow编写损失函数时,经常会使用的TensorFlow函数。

文章目录

  • tf.where()
  • tf.gather()
  • tf.squeeze()
  • tf.less()、tf.greater()、tf.equal()等比较函数

tf.where()

tf.where(condition, x=None, y=None, name=None)

参数:

  • condition : 布尔型张量
  • x :与condition具有相同shape的张量;或当condition为1维时,x可以是高维张量,但x的第一个维度必须与condition的size相同。
  • y :与x具有相同shape的张量
  • name : A name of the operation (optional)

该函数具有元素筛选作用,下面介绍其具体用法

1. x 和 y 均不为None,且x、y、condition三者具有相同维度

tf.where(condition, x, y)的作用是: 将condition中的 True 位置元素替换为 x 中对应位置元素, False 位置元素替换为 y 中对应位置元素,以得到新的张量。

import tensorflow as tf
import numpy as np

condition = tf.less(np.array([[1, 3, 5], [2, 6, 4]]), 4)
x = tf.constant([[1, 1, 1], [2, 2, 2]])
y = tf.constant([[3, 3, 3], [4, 4, 4]])
new_tensor = tf.where(condition, x, y)
with tf.Session() as sess:
    print(sess.run(condition))
    print(sess.run(new_tensor))

输出结果:

[[ True  True False]
 [ True False False]]
[[1 1 3]
 [2 4 4]]

2. x 和 y 均不为None,且condition为1维,x和y为高维

tf.where(condition, x, y)的作用可以这样理解: 假设1维condition有m个元素,先将 x 和 y 分别按照第一维度拆分成m组,再将condition中的 True 位置元素替换为 x 拆分后中对应位置分组, False 位置元素替换为 y 拆分后中对应位置分组,以得到新的张量。多说无益,看完具体例子你就明白了。

import tensorflow as tf
import numpy as np

condition = tf.less(np.array([0, 1]), 1)
x = tf.constant([[1, 1, 1], [2, 2, 2]])
y = tf.constant([[3, 3, 3], [4, 4, 4]])
new_tensor = tf.where(condition, x, y)
with tf.Session() as sess:
    print(sess.run(condition))
    print(sess.run(new_tensor))

输出结果:

[True False]
[[1 1 1]
 [4 4 4]]

分析:condition中有两个元素,所以对 x 进行拆分并按照ndarray的记法,简记为:x[0] = [1, 1, 1]、x[1] = [2, 2, 2];同样将y拆分为:y[0] = [3, 3, 3]、y[1] = [4, 4, 4]。condition[0] =True,所以将其替换为x[0];condition[1] =False,所以将其替换为y[1],最终得到上述结果。

3. x 和 y 均为None

tf.where(condition)将返回condition中为True的元素的索引

import tensorflow as tf
import numpy as np
condition1 = tf.less(np.array([1, 3, 5]), 4)  # 一维
condition2 = tf.less(np.array([[1, 3, 5], [2, 6, 4]]), 4)  # 二维
condition3 = tf.less(np.array([[[1, 3, 5], [2, 6, 4]], [[7, 1, 3], [6, 0, 1]]]), 4)  # 三维
with tf.Session() as sess:
    print('condition1 =', sess.run(condition1))
    print('tf.where(condition1) =', sess.run(tf.where(condition1)))
    print('condition2 =', sess.run(condition2))
    print('tf.where(condition2) =', sess.run(tf.where(condition2)))
    print('condition3 =', sess.run(condition3))
    print('tf.where(condition3) =', sess.run(tf.where(condition3)))
    

输出结果:

condition1 = [ True  True False]
tf.where(condition1) = [[0]
 [1]]
condition2 = [[ True  True False]
 [ True False False]]
tf.where(condition2) = [[0 0]
 [0 1]
 [1 0]]
condition3 = [[[ True  True False]
  [ True False False]]

 [[False  True  True]
  [False  True  True]]]
tf.where(condition3) = [[0 0 0]
 [0 0 1]
 [0 1 0]
 [1 0 1]
 [1 0 2]
 [1 1 1]
 [1 1 2]]

分析:不管condition是多少维,tf.where(condition)返回的张量总是二维的,其行数为condition中True元素的个数,每一行对应一个True元素的索引。

tf.gather()

我们知道,ndarray和list都可以直接通过索引进行切片,但tensor却不行。不过TensorFlow提供了多个函数来进行张量切片,tf.gather()就是其中一种,其调用形式如下:

tf.gather(params, indices, validate_indices=None, name=None, axis=0)

Gather slices from params axis axis according to indices,即
从’params’的’axis’维根据’indices’的参数值获取切片。就是在axis维根据indices取某些值,最终得到新的tensor

主要参数:

  • params:要进行切片的ndarray或list或tensor等
  • indices:索引向量,其类型可以是ndarray、list、tensor等
  • axis : 对哪个轴进行切片

1. params 的维数为1

import tensorflow as tf
import numpy as np

# params = np.random.randint(1, 10, 5)
# params = [2, 3, 4, 5, 6, 7]
params = tf.constant([2, 3, 4, 5, 6, 7])

# indices = np.array([2, 1, 4, 2])
# indices = [2, 1, 4, 2]
indices = tf.constant([2, 1, 4, 2])

tensor1 = tf.gather(params, indices)
with tf.Session() as sess:
    # print(params)
    print(sess.run(params))
    print(sess.run(tensor1))

输出结果:

[2 3 4 5 6 7]
[4 3 6 4]

分析:根据indices逐一取出params中对应索引的元素,并组成新的张量。

2. params 的维数为2

import tensorflow as tf
import numpy as np

params = np.random.randint(1, 10, (4, 5))
indices = tf.constant([2, 1, 0, 2])
tensor0 = tf.gather(params, indices, axis=0)
tensor1 = tf.gather(params, indices, axis=1)
with tf.Session() as sess:
    print('params =', params)
    print('tensor0 =', sess.run(tensor0))
    print('tensor1 =', sess.run(tensor1))

输出结果:

params = [[5 1 4 7 2]
 		  [1 8 9 1 7]
 		  [2 1 8 7 2]
 		  [8 9 5 8 7]]
tensor0 = [[2 1 8 7 2]
 		   [1 8 9 1 7]
		   [5 1 4 7 2]
		   [2 1 8 7 2]]
tensor1 = [[4 1 5 4]
 		   [9 8 1 9]
 	 	   [8 1 2 8]
 		   [5 9 8 5]]

对于二维params,
当indices是标量且是张量时,得到的结果不会降维;
当indices是标量且是ndarray时,得到的结果会降维。

import tensorflow as tf
import numpy as np

params = np.random.randint(1, 10, (3, 4))
indices1 = tf.constant([2])
indices2 = 2
tensor1 = tf.gather(params, indices1, axis=0)
tensor2 = tf.gather(params, indices2, axis=0)
with tf.Session() as sess:
    print('params =', params)
    print('tensor1 =', sess.run(tensor1))
    print('tensor2 =', sess.run(tensor2))

输出结果:

params = [[9 2 1 7]
	      [7 8 2 3]
 		  [9 7 2 9]]
tensor1 = [[9 7 2 9]]
tensor2 = [9 7 2 9]

tf.squeeze()

tf.squeeze(input, axis=None, name=None, squeeze_dims=None)

该函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果。
axis可以用来指定要删掉的为1的维度,此处要注意指定的维度必须确保其是1,否则会报错。

import tensorflow as tf


input_tensor = tf.ones((2, 1, 1, 3, 2))
new_tensor1 = tf.squeeze(input_tensor)
new_tensor2 = tf.squeeze(input_tensor, [1])
with tf.Session() as sess:
    print(sess.run(tf.shape(input_tensor)))
    print(sess.run(tf.shape(new_tensor1)))
    print(sess.run(tf.shape(new_tensor2)))

输出结果:

[2 1 1 3 2]
[2 3 2]
[2 1 3 2]

实际上,tf.squeeze()和np.squeeze()的作用是一样的,都是去除指定的冗余维度(维度为1的维度)。这两个函数的区别是:作用对象不同,前者是张量,后者是ndarray。

tf.less()、tf.greater()、tf.equal()等比较函数

这几个函数用于逐元素比较两个张量的大小,并返回比较结果(True or False)构成的布尔型张量。下面以tf.less()为例:

tf.less(x, y, name=None

tf.less()返回了两个张量各元素比较(x

注意:

  • tf.less()支持broadcast机制;
  • tf.less(x, y)中的 x 和 y 可以是tensor、ndarray、list等。
x = tf.constant([[1, 2, 3], [4, 5, 6]])
y1 = tf.constant([[2, 1, 2], [2, 6, 7]])
y2 = tf.constant([3, 6, 9])
y3 = tf.constant([3])
with tf.Session() as sess:
    print(sess.run(tf.less(x, y1)))
    print(sess.run(tf.less(x, y2)))
    print(sess.run(tf.less(x, y3)))

输出结果:

[[ True False False]
 [False  True  True]]
[[ True  True  True]
 [False  True  True]]
[[ True  True False]
 [False False False]]

总结:

  • tf.less(x, y) —— x < y 为True
  • tf.equal(x, y) —— x == y 为True
  • tf.greater(x, y) —— x > y 为True
  • tf.greater_equal(x, y) —— x >= y 为True
  • tf.less_equal(x, y) —— x <= y 为True

你可能感兴趣的:(TensorFlow)