tensorflow使用神奇索引(整数数组索引)

一、numpy中的神奇索引

神奇索引是Numpy中的术语,用于描述使用整数数组进行数据索引。神奇索引实现的是给定多个矩阵的索引位置(使用列表或数组指定),获得这些指定位置的值,该方法不同于普通的切片索引,能够离散的获得矩阵元素的值,神奇索引的具体描述可参考《利用python进行数据分析》一书的第四章。

1、获得指定的行构成的子集(顺序)

import numpy as np
arr = np.empty((8,4))
for i in range(8):
    arr[i] = i
print("arr:\n",arr)
print("arr[[2,1,3]]:\n",arr[[2,1,3]])

 

结果为:

arr:

 [[0. 0. 0. 0.]

 [1. 1. 1. 1.]

 [2. 2. 2. 2.]

 [3. 3. 3. 3.]

 [4. 4. 4. 4.]

 [5. 5. 5. 5.]

 [6. 6. 6. 6.]

 [7. 7. 7. 7.]]

arr[[2,1,3]]:

 [[2. 2. 2. 2.]

 [1. 1. 1. 1.]

 [3. 3. 3. 3.]]

2、获得指定的行构成的子集(逆序,从尾部进行选择)

import numpy as np
arr = np.empty((8,4))
for i in range(8):
    arr[i] = i
print("arr:\n",arr)
print("arr[[-2,-1,-3]]:\n",arr[[-2,-1,-3]])

结果为:

arr:

 [[0. 0. 0. 0.]

 [1. 1. 1. 1.]

 [2. 2. 2. 2.]

 [3. 3. 3. 3.]

 [4. 4. 4. 4.]

 [5. 5. 5. 5.]

 [6. 6. 6. 6.]

 [7. 7. 7. 7.]]

arr[[-2,-1,-3]]:

 [[6. 6. 6. 6.]

 [7. 7. 7. 7.]

 [5. 5. 5. 5.]]

3、获得指定的元素位置构成的子集

元素的位置有多个索引数组给出,其结果将返回一个一维数组

下面例子中,索引分别为[2,5] [1,3] [3,2]

 

import numpy as np
arr = np.arange(32).reshape((4,8))
print("arr:\n",arr)
print("arr[[2,1,3],[5,3,2]]:\n",arr[[2,1,3],[5,3,2]])

结果如下:

arr:

 [[ 0  1  2  3  4  5  6  7]

 [ 8  9 10 11 12 13 14 15]

 [16 17 18 19 20 21 22 23]

 [24 25 26 27 28 29 30 31]]

arr[[2,1,3],[5,3,2]]:

 [21 11 26]

下面例子中,索引分别为[0,2] [1,3] [2,4] [3,5]

import numpy as np
arr = np.arange(32).reshape((4,8))
print("arr:\n",arr)
print("arr[range(4),[2,3,4,5]]:\n",arr[range(4),[2,3,4,5]])

结果为:

arr:

 [[ 0  1  2  3  4  5  6  7]

 [ 8  9 10 11 12 13 14 15]

 [16 17 18 19 20 21 22 23]

 [24 25 26 27 28 29 30 31]]

arr[range(4),[2,3,4,5]]:

 [ 2 11 20 29]

 

二、tensorflow中的神奇索引

在下面的示例中,将在tensorflow中实现类似于arr[range(4),[2,3,4,5]](即arr[[0,1,2,3],[2,3,4,5]])的表达

首先,在tensorflow中可以实现单个索引值的索引,如下所示:

import  tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
 [11,12,13,14,15,16,17,18,19,20],
 [21,22,23,24,25,26,27,28,29,30],
 [31,32,33,34,35,36,37,38,39,40],])
newarr=arr[1,3]
with tf.Session() as sess:
    print(sess.run(newarr))

结果:

14

尝试使用多个索引值

import  tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
 [11,12,13,14,15,16,17,18,19,20],
 [21,22,23,24,25,26,27,28,29,30],
 [31,32,33,34,35,36,37,38,39,40],])
newarr=arr[tf.range(4),[2,3,4,5]]
with tf.Session() as sess:
    print(sess.run(newarr))

结果将报错

  File "F:\softwareInstall\program\anaconda\lib\site-packages\tensorflow\python\ops\array_ops.py", line 618, in _slice_helper

    _check_index(s)

  File "F:\softwareInstall\program\anaconda\lib\site-packages\tensorflow\python\ops\array_ops.py", line 516, in _check_index

    raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got

看来在tensorflow中不支持直接使用整数数组进行索引

查询到tf.gather_nd可以实现获得指定多个索引值

参考:https://blog.csdn.net/orangefly0214/article/details/81634310

import  tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
 [11,12,13,14,15,16,17,18,19,20],
 [21,22,23,24,25,26,27,28,29,30],
 [31,32,33,34,35,36,37,38,39,40],])
newarr=tf.gather_nd(arr,[[0,2],[1,3],[2,4],[3,5]])
with tf.Session() as sess:
    print(sess.run(newarr))

结果:

[ 3 14 25 36]

 尽管能够实现多个索引值,但是表达方式与所需不同,关键是如何将[0,1,2,3],[2,3,4,5]变为[[0,2],[1,3],[2,4],[3,5]],tensorflow中使用tf.stack和tf.unstack完成

参考:https://www.jianshu.com/p/25706575f8d4

import  tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
 [11,12,13,14,15,16,17,18,19,20],
 [21,22,23,24,25,26,27,28,29,30],
 [31,32,33,34,35,36,37,38,39,40],])
row=tf.range(4)
colum=tf.constant([2,3,4,5])
ss=tf.stack([row,colum],axis=0)#构成[0,1,2,3],[2,3,4,5]
indexs=tf.unstack(ss,axis=1)#构成[[0,2],[1,3],[2,4],[3,5]]
newarr=tf.gather_nd(arr,indexs)
with tf.Session() as sess:
    print(sess.run(newarr))
结果为:
[ 3 14 25 36]

你可能感兴趣的:(tensorflow使用神奇索引(整数数组索引))