Python小技巧 - argmax

argmax 返回的是输入列表中最大值的位置,其重要性不必多言,但是据我所知 Python 自带的库中只提供 max 这个函数,并没有 argmax,但是实现起来不难。

Numpy 中的 argmax

首先我们先来看一下 Numpy 中提供 argmax 函数,它重要的特点就是在有多个最大值的情况下,只返回第一个出现的最大值的位置。

In [1]: import numpy as np
In [2]: a = [1, 2, 9, 2, 5, 6, 9]
In [3]: np.argmax(a)
Out[3]: 2

如果需要返回所有最大值的位置的话,还是要麻烦一下的:

In [4]: np.where(a == np.max(a))
Out[4]: (array([2, 6], dtype=int64),)

自己写

看来 Numpy 中要找所有最大值的位置也得小小麻烦一下。如果不能使用外带的库,其实自己写一下也是很简单的。

In [5]: [i for i, val in enumerate(a) if (val == max(a))]
Out[5]: [2, 6]

可以返回所有最大值的位置,得益于 Python 的灵活,很简单也很方便是吧。 而且这么做还有一个好处,如果有时候因为 Python 内部数值计算的原因,我们得到的结果是 0.999999999999 ,其实和最大值 1 没区别,但是如果使用 [5] 行中的代码就只会返回最大值 1 的坐标,怎么才能一同返回 0.999999999999 的坐标呢?

In [1]: import math
In [2]: a = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
In [3]: sum(a)
Out[3]: 0.9999999999999999
In [4]: b = [1, 0.1, -1, 0.9999999999999999, -1, 0.999, 0.996]
In [5]: [i for i, val in enumerate(b) if (val == max(b))]
Out[5]: [0]
In [6]: [i for i, val in enumerate(b) if math.isclose(max(b), val, rel_tol = 1e-08)]
Out[6]: [0, 3]

不过提醒一下大家,math.islcose() 只有 Python 3.5 及以上版本才有,使用之前看清楚自己的版本哦。

你可能感兴趣的:(Python)