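Fancy indexing is a NumPy term for indexing data with integer arrays. Given the index positions of several array elements (specified as lists or arrays), it returns the values at those positions. Unlike ordinary slice indexing, it can pick out elements that are not contiguous. For a fuller description, see chapter 4 of "Python for Data Analysis" (《利用python进行数据分析》).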
import numpy as np
arr = np.empty((8,4))
for i in range(8):
    arr[i] = i  # fill row i with the value i
print("arr:\n",arr)
# select rows 2, 1 and 3, in that order
print("arr[[2,1,3]]:\n",arr[[2,1,3]])
Output:
arr:
[[0. 0. 0. 0.]
[1. 1. 1. 1.]
[2. 2. 2. 2.]
[3. 3. 3. 3.]
[4. 4. 4. 4.]
[5. 5. 5. 5.]
[6. 6. 6. 6.]
[7. 7. 7. 7.]]
arr[[2,1,3]]:
[[2. 2. 2. 2.]
[1. 1. 1. 1.]
[3. 3. 3. 3.]]
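Negative indices select rows counting from the end:
import numpy as np
arr = np.empty((8,4))
for i in range(8):
    arr[i] = i  # fill row i with the value i
print("arr:\n",arr)
# -2, -1 and -3 refer to rows 6, 7 and 5
print("arr[[-2,-1,-3]]:\n",arr[[-2,-1,-3]])
Output: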
arr:
[[0. 0. 0. 0.]
[1. 1. 1. 1.]
[2. 2. 2. 2.]
[3. 3. 3. 3.]
[4. 4. 4. 4.]
[5. 5. 5. 5.]
[6. 6. 6. 6.]
[7. 7. 7. 7.]]
arr[[-2,-1,-3]]:
[[6. 6. 6. 6.]
[7. 7. 7. 7.]
[5. 5. 5. 5.]]
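As a quick check of the negative-index behaviour, the following sketch compares negative indices with their non-negative equivalents on the same 8-row array. The names `from_end` and `from_start` are chosen here for illustration and do not come from the original example.

```python
import numpy as np

# Build the same 8x4 array as above: row i is filled with the value i.
arr = np.empty((8, 4))
for i in range(8):
    arr[i] = i

# Negative indices count from the end, so -2, -1, -3 select rows 6, 7, 5.
from_end = arr[[-2, -1, -3]]
from_start = arr[[6, 7, 5]]

print(np.array_equal(from_end, from_start))  # True
```

For an array with 8 rows, index -k refers to row 8 - k, which is why the two selections match.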
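When the element positions are given by several one-dimensional index arrays, one per axis, the arrays are paired element by element and the result is a one-dimensional array.
In the example below, the selected positions are [2,5], [1,3] and [3,2].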
import numpy as np
arr = np.arange(32).reshape((4,8))
print("arr:\n",arr)
print("arr[[2,1,3],[5,3,2]]:\n",arr[[2,1,3],[5,3,2]])
Output:
arr:
[[ 0 1 2 3 4 5 6 7]
[ 8 9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29 30 31]]
arr[[2,1,3],[5,3,2]]:
[21 11 26]
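To make the pairing explicit, here is a minimal sketch showing that two index lists select one element per (row, column) pair rather than a rectangular block. The names `rows`, `cols`, `fancy` and `paired` are illustrative only.

```python
import numpy as np

arr = np.arange(32).reshape((4, 8))
rows = [2, 1, 3]
cols = [5, 3, 2]

# Fancy indexing with two lists pairs them element by element.
fancy = arr[rows, cols]

# The same selection written as an explicit loop over (row, col) pairs.
paired = np.array([arr[r, c] for r, c in zip(rows, cols)])

print(fancy)                          # [21 11 26]
print(np.array_equal(fancy, paired))  # True
```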
In the example below, the selected positions are [0,2], [1,3], [2,4] and [3,5].
import numpy as np
arr = np.arange(32).reshape((4,8))
print("arr:\n",arr)
print("arr[range(4),[2,3,4,5]]:\n",arr[range(4),[2,3,4,5]])
Output:
arr:
[[ 0 1 2 3 4 5 6 7]
[ 8 9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29 30 31]]
arr[range(4),[2,3,4,5]]:
[ 2 11 20 29]
The following examples build a TensorFlow equivalent of arr[range(4),[2,3,4,5]] (that is, arr[[0,1,2,3],[2,3,4,5]]).
First, TensorFlow does support indexing with a single index value per axis, as shown here:
import tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
[11,12,13,14,15,16,17,18,19,20],
[21,22,23,24,25,26,27,28,29,30],
[31,32,33,34,35,36,37,38,39,40],])
newarr=arr[1,3]
# tf.Session is the TensorFlow 1.x graph-mode API
with tf.Session() as sess:
    print(sess.run(newarr))
Output:
14
Now try several index values at once:
import tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
[11,12,13,14,15,16,17,18,19,20],
[21,22,23,24,25,26,27,28,29,30],
[31,32,33,34,35,36,37,38,39,40],])
newarr=arr[tf.range(4),[2,3,4,5]]
with tf.Session() as sess:
    print(sess.run(newarr))
This raises an error:
  File "F:\softwareInstall\program\anaconda\lib\site-packages\tensorflow\python\ops\array_ops.py", line 618, in _slice_helper
    _check_index(s)
  File "F:\softwareInstall\program\anaconda\lib\site-packages\tensorflow\python\ops\array_ops.py", line 516, in _check_index
    raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got
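So TensorFlow, at least in the version used here, does not support indexing directly with integer arrays.
A search turned up tf.gather_nd, which can fetch the values at several specified index positions.
Reference: https://blog.csdn.net/orangefly0214/article/details/81634310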
import tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
[11,12,13,14,15,16,17,18,19,20],
[21,22,23,24,25,26,27,28,29,30],
[31,32,33,34,35,36,37,38,39,40],])
newarr=tf.gather_nd(arr,[[0,2],[1,3],[2,4],[3,5]])
with tf.Session() as sess:
    print(sess.run(newarr))
Output:
[ 3 14 25 36]
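For comparison, the behaviour of tf.gather_nd in this 2-D case can be mimicked with NumPy fancy indexing. This is only a sketch of the semantics when every index row holds a full (row, column) coordinate. `gather_nd_2d` is a hypothetical helper written for this illustration, not a TensorFlow or NumPy function, and the array is a NumPy stand-in filled with 1..40.

```python
import numpy as np

def gather_nd_2d(params, indices):
    """Hypothetical helper: pick params[r, c] for each [r, c] row of indices."""
    indices = np.asarray(indices)
    # Split the (N, 2) index array into row and column vectors, then let
    # NumPy fancy indexing pair them element by element.
    return params[indices[:, 0], indices[:, 1]]

# Stand-in for the TensorFlow constant: a 4x10 array holding 1..40.
arr = np.arange(1, 41).reshape(4, 10)

result = gather_nd_2d(arr, [[0, 2], [1, 3], [2, 4], [3, 5]])
print(result)  # [ 3 14 25 36]
```

Read this way, each inner list passed to tf.gather_nd plays the role of one (row, column) pair in the NumPy expression arr[[0,1,2,3],[2,3,4,5]].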
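This does retrieve several values at once, but the indices are not written in the form we want. The key question is how to turn [0,1,2,3],[2,3,4,5] into [[0,2],[1,3],[2,4],[3,5]]. In TensorFlow this can be done with tf.stack and tf.unstack.
Reference: https://www.jianshu.com/p/25706575f8d4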
import tensorflow as tf
arr=tf.constant([[1,2,3,4,5,5,7,8,9,10],
[11,12,13,14,15,16,17,18,19,20],
[21,22,23,24,25,26,27,28,29,30],
[31,32,33,34,35,36,37,38,39,40],])
row=tf.range(4)
colum=tf.constant([2,3,4,5])
ss=tf.stack([row,colum],axis=0)  # shape (2,4): [[0,1,2,3],[2,3,4,5]]
indexs=tf.unstack(ss,axis=1)  # list of 4 tensors: [0,2], [1,3], [2,4], [3,5]
newarr=tf.gather_nd(arr,indexs)
with tf.Session() as sess:
    print(sess.run(newarr))
Output:
[ 3 14 25 36]
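The stack-then-unstack step amounts to pairing the two index vectors element by element. The NumPy sketch below illustrates this as an analogy, not as the TensorFlow code itself: stacking along axis 1 produces the pairs directly. tf.stack also takes an `axis` argument, so stacking the two index tensors with axis=1 should give the same index tensor in a single call, though that is an assumption not verified against the TensorFlow version used above. The names `rows`, `cols` and `pairs` are illustrative.

```python
import numpy as np

rows = np.arange(4)            # [0, 1, 2, 3]
cols = np.array([2, 3, 4, 5])

# Stacking along axis 1 pairs the two vectors element by element.
pairs = np.stack([rows, cols], axis=1)
print(pairs.tolist())  # [[0, 2], [1, 3], [2, 4], [3, 5]]

# Stand-in for the TensorFlow constant: a 4x10 array holding 1..40.
arr = np.arange(1, 41).reshape(4, 10)
selected = arr[pairs[:, 0], pairs[:, 1]]
print(selected)  # [ 3 14 25 36]
```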