tensorflow多维数据排序及tf.gather_nd函数使用

import tensorflow as tf

data=tf.constant(
         [[[[ 1.4179282 ,  1.5703515 , -0.61169857],
         [ 0.33743155, -0.40962902,  1.4966061 ],
         [ 1.9049675 ,  0.09397563,  1.0175595 ],
         [-0.35192606, -0.75464   ,  0.57980216],
         [-0.43845135,  1.1044604 ,  0.22715685]],

        [[ 1.3074918 ,  0.24285197, -0.31900284],
         [ 0.3282182 ,  0.59630245,  0.4908484 ],
         [-0.22814241, -0.34542274, -1.8534657 ],
         [ 0.6861197 ,  0.987247  ,  0.5381531 ],
         [-1.0340213 , -0.89111334, -0.6680704 ]],

        [[-0.0890333 ,  0.65065485, -0.80708045],
         [ 0.35345015, -0.15637785, -1.495281  ],
         [ 0.30774483,  0.13599548, -0.08656104],
         [ 0.15281916, -0.55360866,  1.001526  ],
         [ 0.08779743,  1.4143836 , -0.32160324]],

        [[-0.05001006, -0.56332934,  1.0221127 ],
         [ 0.14122595, -0.9398476 ,  0.9267709 ],
         [ 0.46681744, -0.26380906,  0.43542302],
         [-1.5392816 ,  0.5043589 , -0.9014659 ],
         [-0.5400769 ,  1.0968751 , -0.09318246]]],


       [[[-1.1827252 , -1.4649743 , -0.23750035],
         [-1.5741025 ,  0.63433784,  1.1028291 ],
         [ 0.62873596, -0.43399343,  0.8999915 ],
         [-0.5396441 , -0.8236998 , -0.3835167 ],
         [-0.48167858,  1.3502644 , -0.03549745]],

        [[ 0.91432065,  0.55631614,  0.9710358 ],
         [ 0.45699129, -1.0502417 ,  0.549892  ],
         [-1.0907862 ,  0.3600453 , -0.35341766],
         [ 0.26944   ,  0.4950551 ,  0.44320667],
         [-0.3407113 ,  0.5147896 ,  1.1087974 ]],

        [[-0.02270273, -0.36483103, -0.6037729 ],
         [ 0.08318438, -0.8092938 ,  0.95797205],
         [-0.728933  , -0.7125127 ,  0.843991  ],
         [-0.04112805,  0.66545516,  0.99063873],
         [-0.82321507, -0.64202845,  0.46515438]],

        [[-0.01119158,  0.03578063,  1.0805527 ],
         [-1.2769296 ,  0.45997906, -0.9196354 ],
         [ 0.03248561, -0.36549515, -0.73439956],
         [ 0.14434706,  1.3091575 ,  0.4094675 ],
         [-0.33830887, -1.4398551 , -0.4993919 ]]]])
# data = tf.random_normal([32,1024,40,16])
#data= tf.truncated_normal([2,4,5,3]) shape
batch_size = data.get_shape().as_list()[0]
dim2_size = data.get_shape().as_list()[1]
k=2 # 要取的个数

a2 = tf.reduce_sum(data,axis=-1)  # 对每行数据进行求和
index = tf.nn.top_k(a2, k=k).indices  # 这里只取出最大的两个向量,算出最大行的索引(2,4,2)相当于取16个数
index = tf.reshape(index,[-1,1]) #(16,1)

batch_idx = tf.range(0, batch_size) #  0-2
batch_idx = tf.reshape(batch_idx, (-1, 1))
batch_idx_tile = tf.tile(batch_idx, (1,dim2_size*k))
batch_idx_tile = tf.reshape(batch_idx_tile,[-1,1])  #(16,1)

dim2 = tf.range(0,dim2_size)
dim2 = tf.reshape(dim2, (dim2_size, 1)) #(4,1)
dim2 = tf.tile(dim2,[1,batch_size]) #(4,2)
dim2 = tf.reshape(dim2,[-1,1]) # (8,1)
dim2 = tf.concat([dim2]*k,axis=0) # (16,1)


new_index = tf.concat([batch_idx_tile,dim2, index], axis=-1)
new_index = tf.reshape(new_index,[batch_size,dim2_size,k,-1])

reordered = tf.gather_nd(data,new_index)  # 在原始数据中取出最大的k个数据

with tf.Session() as sess:

    print(sess.run(a2))
    print(sess.run(dim2))
    print(sess.run(index))
    print(sess.run(reordered))

本实验先生成了[2,4,5,3]维固定的数据,然后在该数据上获取最大的k=2维数据,也就是[2,4,2,3]维数据,16个数据每个数据3维。 获取四维数据,需要一个三维索引,所以先采用tf.nn.top_k获取索引(2,4,2)。由于需要获取具体的数据,比如索引为[0,0,0]的数也就是[ 1.4179282 , 1.5703515 , -0.61169857]。所以我们需要[16,3]维索引。在这里,我把这个索引分成三部分,第一部分为batch_index,里面有8个0和8个1。第二维为
[[0]
[0]
[1]
[1]
[2]
[2]
[3]
[3]
[0]
[0]
[1]
[1]
[2]
[2]
[3]
[3]]
重复两次,第三维为求得的index,最后concat连接起来就是所求的索引,再用tf.gather_nd即可。

实验结果:
[[[ 2.3765812   1.4244087   3.0165029  -0.5267639   0.8931658 ]
  [ 1.2313409   1.415369   -2.4270308   2.2115197  -2.593205  ]
  [-0.2454589  -1.2982087   0.35717928  0.6007365   1.1805778 ]
  [ 0.4087733   0.12814927  0.63843143 -1.9363886   0.46361572]]

 [[-2.8851998   0.16306442  1.0947341  -1.7468606   0.8330884 ]
  [ 2.4416726  -0.04335839 -1.0841585   1.2077018   1.2828758 ]
  [-0.99130666  0.2318626  -0.5974546   1.6149659  -1.0000892 ]
  [ 1.1051418  -1.736586   -1.067409    1.862972   -2.277556  ]]]
[[0]
 [0]
 [1]
 [1]
 [2]
 [2]
 [3]
 [3]
 [0]
 [0]
 [1]
 [1]
 [2]
 [2]
 [3]
 [3]]
[[2]
 [0]
 [3]
 [1]
 [4]
 [3]
 [2]
 [4]
 [2]
 [4]
 [0]
 [4]
 [3]
 [1]
 [3]
 [0]]
[[[[ 1.9049675   0.09397563  1.0175595 ]
   [ 1.4179282   1.5703515  -0.61169857]]

  [[ 0.6861197   0.987247    0.5381531 ]
   [ 0.3282182   0.59630245  0.4908484 ]]

  [[ 0.08779743  1.4143836  -0.32160324]
   [ 0.15281916 -0.55360866  1.001526  ]]

  [[ 0.46681744 -0.26380906  0.43542302]
   [-0.5400769   1.0968751  -0.09318246]]]


 [[[ 0.62873596 -0.43399343  0.8999915 ]
   [-0.48167858  1.3502644  -0.03549745]]

  [[ 0.91432065  0.55631614  0.9710358 ]
   [-0.3407113   0.5147896   1.1087974 ]]

  [[-0.04112805  0.66545516  0.99063873]
   [ 0.08318438 -0.8092938   0.95797205]]

  [[ 0.14434706  1.3091575   0.4094675 ]
   [-0.01119158  0.03578063  1.0805527 ]]]]

你可能感兴趣的:(深度学习)