import tensorflow as tf
data=tf.constant(
[[[[ 1.4179282 , 1.5703515 , -0.61169857],
[ 0.33743155, -0.40962902, 1.4966061 ],
[ 1.9049675 , 0.09397563, 1.0175595 ],
[-0.35192606, -0.75464 , 0.57980216],
[-0.43845135, 1.1044604 , 0.22715685]],
[[ 1.3074918 , 0.24285197, -0.31900284],
[ 0.3282182 , 0.59630245, 0.4908484 ],
[-0.22814241, -0.34542274, -1.8534657 ],
[ 0.6861197 , 0.987247 , 0.5381531 ],
[-1.0340213 , -0.89111334, -0.6680704 ]],
[[-0.0890333 , 0.65065485, -0.80708045],
[ 0.35345015, -0.15637785, -1.495281 ],
[ 0.30774483, 0.13599548, -0.08656104],
[ 0.15281916, -0.55360866, 1.001526 ],
[ 0.08779743, 1.4143836 , -0.32160324]],
[[-0.05001006, -0.56332934, 1.0221127 ],
[ 0.14122595, -0.9398476 , 0.9267709 ],
[ 0.46681744, -0.26380906, 0.43542302],
[-1.5392816 , 0.5043589 , -0.9014659 ],
[-0.5400769 , 1.0968751 , -0.09318246]]],
[[[-1.1827252 , -1.4649743 , -0.23750035],
[-1.5741025 , 0.63433784, 1.1028291 ],
[ 0.62873596, -0.43399343, 0.8999915 ],
[-0.5396441 , -0.8236998 , -0.3835167 ],
[-0.48167858, 1.3502644 , -0.03549745]],
[[ 0.91432065, 0.55631614, 0.9710358 ],
[ 0.45699129, -1.0502417 , 0.549892 ],
[-1.0907862 , 0.3600453 , -0.35341766],
[ 0.26944 , 0.4950551 , 0.44320667],
[-0.3407113 , 0.5147896 , 1.1087974 ]],
[[-0.02270273, -0.36483103, -0.6037729 ],
[ 0.08318438, -0.8092938 , 0.95797205],
[-0.728933 , -0.7125127 , 0.843991 ],
[-0.04112805, 0.66545516, 0.99063873],
[-0.82321507, -0.64202845, 0.46515438]],
[[-0.01119158, 0.03578063, 1.0805527 ],
[-1.2769296 , 0.45997906, -0.9196354 ],
[ 0.03248561, -0.36549515, -0.73439956],
[ 0.14434706, 1.3091575 , 0.4094675 ],
[-0.33830887, -1.4398551 , -0.4993919 ]]]])
# data = tf.random_normal([32,1024,40,16])
#data= tf.truncated_normal([2,4,5,3]) shape
batch_size = data.get_shape().as_list()[0]
dim2_size = data.get_shape().as_list()[1]
k=2 # 要取的个数
a2 = tf.reduce_sum(data,axis=-1) # 对每行数据进行求和
index = tf.nn.top_k(a2, k=k).indices # 这里只取出最大的两个向量,算出最大行的索引(2,4,2)相当于取16个数
index = tf.reshape(index,[-1,1]) #(16,1)
batch_idx = tf.range(0, batch_size) # 0-2
batch_idx = tf.reshape(batch_idx, (-1, 1))
batch_idx_tile = tf.tile(batch_idx, (1,dim2_size*k))
batch_idx_tile = tf.reshape(batch_idx_tile,[-1,1]) #(16,1)
dim2 = tf.range(0,dim2_size)
dim2 = tf.reshape(dim2, (dim2_size, 1)) #(4,1)
dim2 = tf.tile(dim2,[1,batch_size]) #(4,2)
dim2 = tf.reshape(dim2,[-1,1]) # (8,1)
dim2 = tf.concat([dim2]*k,axis=0) # (16,1)
new_index = tf.concat([batch_idx_tile,dim2, index], axis=-1)
new_index = tf.reshape(new_index,[batch_size,dim2_size,k,-1])
reordered = tf.gather_nd(data,new_index) # 在原始数据中取出最大的k个数据
with tf.Session() as sess:
print(sess.run(a2))
print(sess.run(dim2))
print(sess.run(index))
print(sess.run(reordered))
本实验先生成了[2,4,5,3]维固定的数据,然后在该数据上获取最大的k=2维数据,也就是[2,4,2,3]维数据,16个数据每个数据3维。 获取四维数据,需要一个三维索引,所以先采用tf.nn.top_k获取索引(2,4,2)。由于需要获取具体的数据,比如索引为[0,0,0]的数也就是[ 1.4179282 , 1.5703515 , -0.61169857]。所以我们需要[16,3]维索引。在这里,我把这个索引分成三部分,第一部分为batch_index,里面有8个0和8个1。第二维为
[[0]
[0]
[1]
[1]
[2]
[2]
[3]
[3]
[0]
[0]
[1]
[1]
[2]
[2]
[3]
[3]]
重复两次,第三维为求得的index,最后concat连接起来就是所求的索引,再用tf.gather_nd即可。
实验结果:
[[[ 2.3765812 1.4244087 3.0165029 -0.5267639 0.8931658 ]
[ 1.2313409 1.415369 -2.4270308 2.2115197 -2.593205 ]
[-0.2454589 -1.2982087 0.35717928 0.6007365 1.1805778 ]
[ 0.4087733 0.12814927 0.63843143 -1.9363886 0.46361572]]
[[-2.8851998 0.16306442 1.0947341 -1.7468606 0.8330884 ]
[ 2.4416726 -0.04335839 -1.0841585 1.2077018 1.2828758 ]
[-0.99130666 0.2318626 -0.5974546 1.6149659 -1.0000892 ]
[ 1.1051418 -1.736586 -1.067409 1.862972 -2.277556 ]]]
[[0]
[0]
[1]
[1]
[2]
[2]
[3]
[3]
[0]
[0]
[1]
[1]
[2]
[2]
[3]
[3]]
[[2]
[0]
[3]
[1]
[4]
[3]
[2]
[4]
[2]
[4]
[0]
[4]
[3]
[1]
[3]
[0]]
[[[[ 1.9049675 0.09397563 1.0175595 ]
[ 1.4179282 1.5703515 -0.61169857]]
[[ 0.6861197 0.987247 0.5381531 ]
[ 0.3282182 0.59630245 0.4908484 ]]
[[ 0.08779743 1.4143836 -0.32160324]
[ 0.15281916 -0.55360866 1.001526 ]]
[[ 0.46681744 -0.26380906 0.43542302]
[-0.5400769 1.0968751 -0.09318246]]]
[[[ 0.62873596 -0.43399343 0.8999915 ]
[-0.48167858 1.3502644 -0.03549745]]
[[ 0.91432065 0.55631614 0.9710358 ]
[-0.3407113 0.5147896 1.1087974 ]]
[[-0.04112805 0.66545516 0.99063873]
[ 0.08318438 -0.8092938 0.95797205]]
[[ 0.14434706 1.3091575 0.4094675 ]
[-0.01119158 0.03578063 1.0805527 ]]]]