一、简介
def top_k(input, k=1, sorted=True, name=None)
Finds values and indices of the k
largest entries for the last dimension.
If the input is a vector (rank=1), finds the k
largest entries in the vector and outputs their values and indices as vectors.Thus values[j]
is the j
-th largest entry in input
, and its index is indices[j]
.
For matrices (resp. higher rank input), computes the top k
entries in each row (resp. vector along the last dimension).Thus, values.shape = indices.shape = input.shape[:-1] + [k]
If two elements are equal, the lower-index element appears first.
翻译:
查找最后一个维度的前k
个最大条目的值和索引。
如果输入是向量(rank = 1),则在向量中找到前k
个最大条目,并将它们的值和索引作为向量输出。因此values[j]
是 input
中的第j
个最大条目,它的索引是indices [j]
。
对于矩阵(或更高级别的输入),计算每行中的顶部k
条目(沿最后一个维度的每一个向量)。因此,values.shape = indices.shape = input.shape [: - 1] + [K]
如果两个元素相等,则首先显示lower-index(索引值较低)元素。
注:该函数返回的数据包含两部分,第一部分是返回的value值,第二部分返回的是对应的索引值。可以通过索引[0]或者[1]进行访问。
二、参数
参数 | ||
---|---|---|
input |
1-D or higher Tensor with last dimension at least k . |
一个一维或者更高维度的Tensor,他的最后一维的数目至少为k。 |
k |
0-D int32 Tensor . Number of top elements to look for along the last dimension (along each row for matrices). |
一个整形Tensor,表示沿着最后一个维度去寻找的元素的数目。 |
sorted |
If true the resulting k elements will be sorted by the values in descending order. |
如果该值被设置为True,则返回的k个元素会被按照值的从大到小的顺序进行排序。默认为True。 |
name |
Optional name for the operation. | 可选参数,名称 |
三、代码
import tensorflow as tf
import numpy as np
# 建立一个长度为10的向量,内部数据随机生成。
a = tf.convert_to_tensor(np.random.random([10]))
# 取出前5个最大的数据,默认从大到小进行排序。
b = tf.nn.top_k(a, 5)
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(b))
print(sess.run(b[1]))
运行结果:
[0.09673178 0.2011694 0.77118243 0.20476724 0.3439558 0.69864978
0.2118251 0.32904677 0.87435634 0.47136589]
TopKV2(values=array([0.87435634, 0.77118243, 0.69864978, 0.47136589, 0.3439558 ]), indices=array([8, 2, 5, 9, 4]))
[8 2 5 9 4]
当传入更高维度的数据时:
import tensorflow as tf
import numpy as np
# 定义一个三维的矩阵,内部数据随机产生
a = tf.convert_to_tensor(np.random.random([20, 20, 10]))
# 按照最后一个维度取出前5个最大的数据,默认从大到小进行排序。
b = tf.nn.top_k(a, 5)
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(b))
print(sess.run(b)[1].shape)
运行结果: 最值得注意的是最后一个返回值的shape,只有最后一个维度有所区别。
[[[0.8498202 0.05195572 0.8849565 ... 0.66397947 0.54824224 0.74318886]
[0.49996231 0.91040108 0.21483549 ... 0.04122947 0.64088468 0.32510497]
[0.90725498 0.68344152 0.43061874 ... 0.39102586 0.12769082 0.66023738]
...
[0.60666856 0.56439855 0.28063549 ... 0.93124743 0.89449678 0.66979802]
[0.76200935 0.06834749 0.85145249 ... 0.67836563 0.01516219 0.01993689]
[0.47049275 0.50707521 0.36991098 ... 0.88998056 0.12763079 0.09845498]]
[[0.57171017 0.15238957 0.08806684 ... 0.02480321 0.48453851 0.85199458]
[0.35878106 0.30580091 0.22070303 ... 0.42346321 0.22950292 0.18906091]
[0.90136589 0.41240145 0.52366428 ... 0.69907391 0.26080453 0.19672214]
...
[0.39987234 0.93231962 0.02967131 ... 0.38570163 0.52938515 0.89505879]
[0.66779964 0.62346695 0.84506223 ... 0.57041431 0.12558373 0.75406602]
[0.04802938 0.96657687 0.07476398 ... 0.93957134 0.88229134 0.48934519]]
[[0.03006041 0.0136604 0.75244466 ... 0.65651256 0.39410724 0.83654045]
[0.71498666 0.56440115 0.95761964 ... 0.02704624 0.51868975 0.44324936]
[0.41980744 0.63474661 0.58030962 ... 0.20945427 0.29488566 0.07749595]
...
[0.11727653 0.9169551 0.02627972 ... 0.8763961 0.36451567 0.96754857]
[0.28255761 0.22505311 0.74507012 ... 0.23504345 0.20330998 0.04071097]
[0.73204599 0.50676066 0.0524236 ... 0.74684682 0.93345544 0.83705093]]
...
[[0.64496108 0.66815738 0.17245006 ... 0.43895167 0.89021163 0.65442853]
[0.8690804 0.44297673 0.48261915 ... 0.71620392 0.28584558 0.60172575]
[0.31634969 0.39460366 0.25693086 ... 0.93440372 0.50671148 0.2486601 ]
...
[0.71044313 0.32806087 0.70054147 ... 0.80219637 0.96946221 0.76465067]
[0.35188569 0.83711553 0.01343541 ... 0.28523762 0.45159021 0.81395335]
[0.52934446 0.23226338 0.28012356 ... 0.13028752 0.9962975 0.44482207]]
[[0.89439131 0.60870675 0.21073087 ... 0.62333398 0.52917202 0.69767772]
[0.94700397 0.14408882 0.96524112 ... 0.75613067 0.76415524 0.22070657]
[0.58182603 0.63138273 0.24297734 ... 0.01150216 0.91135157 0.56416608]
...
[0.73974793 0.93020208 0.82434553 ... 0.73215145 0.42041154 0.34463405]
[0.59814222 0.49599991 0.4764923 ... 0.27145421 0.87418982 0.70327742]
[0.61134091 0.96387942 0.31842696 ... 0.38037157 0.51440121 0.94851797]]
[[0.22655945 0.05248473 0.47943931 ... 0.45506608 0.32513959 0.04213444]
[0.33406586 0.34820628 0.59872586 ... 0.01636161 0.34377442 0.4370155 ]
[0.98888032 0.62710205 0.92201311 ... 0.27882558 0.46042077 0.4403413 ]
...
[0.49680129 0.41594056 0.93365285 ... 0.87372742 0.70665113 0.15976358]
[0.48933501 0.31931995 0.92455068 ... 0.76884526 0.3875951 0.12877622]
[0.16327613 0.35248604 0.90702435 ... 0.33775252 0.60606198 0.05021601]]]
TopKV2(values=array([[[0.8849565 , 0.8498202 , 0.74318886, 0.66397947, 0.54824224],
[0.91040108, 0.88671867, 0.64088468, 0.6227422 , 0.55252928],
[0.90725498, 0.80000614, 0.68344152, 0.66023738, 0.43061874],
...,
[0.93124743, 0.89449678, 0.84382658, 0.69857909, 0.66979802],
[0.9937199 , 0.85145249, 0.79742674, 0.76200935, 0.67836563],
[0.88998056, 0.82158075, 0.63521181, 0.63428801, 0.50707521]],
[[0.85199458, 0.62084884, 0.58125283, 0.57171017, 0.48453851],
[0.42346321, 0.4149812 , 0.41135642, 0.35878106, 0.30580091],
[0.90136589, 0.82570664, 0.77936049, 0.69907391, 0.65195422],
...,
[0.93231962, 0.89505879, 0.70673449, 0.65217635, 0.61258099],
[0.84506223, 0.77861116, 0.75406602, 0.71340457, 0.66779964],
[0.96657687, 0.93957134, 0.9157956 , 0.88229134, 0.71152801]],
[[0.83654045, 0.82614004, 0.75244466, 0.65651256, 0.5979959 ],
[0.97347234, 0.95761964, 0.74878137, 0.71498666, 0.56440115],
[0.69586341, 0.66677399, 0.63474661, 0.58030962, 0.50256754],
...,
[0.96754857, 0.9654071 , 0.9169551 , 0.8763961 , 0.75280918],
[0.89399932, 0.74507012, 0.72503987, 0.70364816, 0.30463687],
[0.93345544, 0.84302607, 0.83705093, 0.74684682, 0.73204599]],
...,
[[0.89021163, 0.87103578, 0.66815738, 0.65442853, 0.64496108],
[0.98479821, 0.8690804 , 0.76023822, 0.71620392, 0.60994919],
[0.99097209, 0.93440372, 0.50671148, 0.39460366, 0.31634969],
...,
[0.96946221, 0.80219637, 0.76465067, 0.71044313, 0.70137228],
[0.92060929, 0.83711553, 0.81395335, 0.45159021, 0.35812735],
[0.9962975 , 0.93294538, 0.59146199, 0.52934446, 0.44482207]],
[[0.89439131, 0.80978689, 0.78357729, 0.69767772, 0.62333398],
[0.97394143, 0.96524112, 0.94700397, 0.76415524, 0.75613067],
[0.99029969, 0.91135157, 0.63266259, 0.63138273, 0.58182603],
...,
[0.96092679, 0.93020208, 0.82434553, 0.79132076, 0.73974793],
[0.87418982, 0.7208272 , 0.70327742, 0.59814222, 0.49599991],
[0.96387942, 0.96196898, 0.94851797, 0.61134091, 0.51440121]],
[[0.56481086, 0.47943931, 0.45506608, 0.32513959, 0.31256481],
[0.74507118, 0.72854902, 0.59872586, 0.4370155 , 0.35918261],
[0.98888032, 0.92201311, 0.90066467, 0.81044143, 0.70324162],
...,
[0.97444993, 0.93365285, 0.87372742, 0.83351393, 0.80496823],
[0.92455068, 0.76884526, 0.62140043, 0.55610083, 0.48933501],
[0.90702435, 0.9067023 , 0.68597384, 0.60606198, 0.35248604]]]), indices=array([[[2, 0, 9, 7, 8],
[1, 4, 8, 6, 3],
[0, 4, 1, 9, 2],
...,
[7, 8, 6, 3, 9],
[3, 2, 5, 0, 7],
[7, 5, 3, 4, 1]],
[[9, 3, 5, 0, 8],
[7, 5, 6, 0, 1],
[0, 4, 6, 7, 5],
...,
[1, 9, 6, 3, 4],
[2, 5, 9, 4, 0],
[1, 7, 5, 8, 4]],
[[9, 3, 2, 7, 4],
[5, 2, 3, 0, 1],
[6, 5, 1, 2, 4],
...,
[9, 6, 1, 7, 4],
[6, 2, 3, 4, 5],
[8, 3, 9, 7, 0]],
...,
[[8, 5, 1, 9, 0],
[6, 0, 4, 7, 5],
[3, 7, 8, 1, 0],
...,
[8, 7, 9, 0, 6],
[4, 1, 9, 8, 3],
[8, 5, 3, 0, 9]],
[[0, 4, 5, 9, 7],
[5, 2, 0, 8, 7],
[5, 8, 6, 1, 0],
...,
[6, 1, 2, 5, 0],
[8, 5, 9, 0, 1],
[1, 5, 9, 0, 8]],
[[6, 2, 7, 8, 4],
[3, 6, 2, 9, 4],
[0, 2, 5, 3, 4],
...,
[3, 2, 7, 6, 5],
[2, 7, 6, 4, 0],
[2, 5, 3, 8, 1]]]))
(20, 20, 5)