楔子
Numpy 里面有一些以 arg 开头的函数,非常的有意思,因为它们返回的不是元素、而是元素的索引,下面来看一下用法。
np.argmax
我们知道 np.max 是获取最大的元素,那么np.argmax是做什么的呢?
import numpy as np arr = np.array([3, 22, 4, 11, 2, 44, 9]) print(np.max(arr)) # 44 print(np.argmax(arr)) # 5
结论很清晰,np.argmax 是获取最大元素所在的索引。
同理还有 np.argmin,np.min 是获取数组中最小的元素,显然是2;np.argmin 是获取数组中最小元素对应的索引,显然是 4。
import numpy as np arr = np.array([3, 22, 4, 11, 2, 44, 9]) print(np.min(arr)) # 2 print(np.argmin(arr)) # 4
那么问题来了,如果我们想自己实现一个 argmax 和 argmin 该怎么做呢?
def argmax(lst): lst = [ # 将每个元素和对应的索引组合起来 (item, idx) for idx, item in enumerate(lst) ] return max(lst)[1] print( argmax([3, 22, 4, 11, 2, 44, 9]) ) # 5
还是很简单的,如果实现 argmin 的话,只需将里面的 max 函数换成 min 即可。
np.argwhere
np.where 还是用的挺频繁的,先来复习一下它的用法吧。
import numpy as np arr = np.array( [1, 2, 3, 4, 5, 6, 7, 8]) # 如果元素大于 4,那么减去 10,否则乘以 10 print( np.where(arr > 4, arr - 10, arr * 10) ) # [10 20 30 40 -5 -4 -3 -2] # 如果元素大于4, 那么保持不变, 否则变成 4 print( np.where(arr > 4, arr, 4) ) # [4 4 4 4 5 6 7 8]
和 np.where 作用类似的还有一个 np.clip,来看一下。
import numpy as np arr = np.array([1, 2, 3, 4, 5, 6, 7, 8]) # 小于 2 的换成 2, 大于 6 的换成 6 # 一般在设置上下限的时候非常有用 print( np.clip(arr, 2, 6) ) # [2 2 3 4 5 6 6 6]
那么 np.argwhere 是做啥的呢?首先这个函数只接受一个参数,找出满足条件的元素对应的索引。
import numpy as np arr = np.array([3, 4, 5, 6, 7]) # 显然 3、5、7 对 2 取模不等于 0 # 它们对应的索引分别是 0、2、4 print(np.argwhere(arr % 2 != 0)) """ [[0] [2] [4]] """ # 需要扁平化 print( np.argwhere(arr % 2 != 0).flatten() ) # [0 2 4]
显然元素 3、5、7 在 %2 之后不为 0,所以会筛选出它们的索引,因此是 [0 2 4]。只不过默认不是一个一维数组,我们需要再调用一下 flatten,将其扁平化。
np.argsort
np.sort 是用来排序的,类似于内置函数 sorted。
import numpy as np arr = np.array([44, 22, 33, 66, 55, 11]) print( np.sort(arr) ) # [11 22 33 44 55 66]
sort 很容易,再来看看 argsort。
import numpy as np arr = np.array([44, 22, 33, 66, 55, 11]) print( np.sort(arr) ) # [11 22 33 44 55 66] print( np.argsort(arr) ) # [5 1 2 0 4 3]
所以 argsort 返回的是元素排序后对应的索引,光说可能不好理解,我们画一张图就清晰了。
因此 sort 是对元素进行排序,而 argsort 是按照元素的大小对索引进行排序。比如我们想选择数组中最大的 3 个元素,就有两种做法。
import numpy as np arr = np.array([44, 22, 33, 66, 55, 11]) # 选择前三个最大元素,可以先全局排序 # 然后选择前三个 print( np.sort(arr)[-3:] ) # [44 55 66] # 使用 argsort,找到最大的三个元素的索引 # 然后通过 arr 进行筛选 print( arr[np.argsort(arr)[-3:]] ) # [44 55 66]
另外 argsort 还可以实现一个特殊的需求,那就是排名统计,举个例子:[44, 22, 33, 66, 55, 11],如果按照从小到大排名,
- 44 在 array 中排名第 4;
- 22 在 array 中排名第 2;
- 33 在 array 中排名第 3;
- 66 在 array 中排名第 6;
- 55 在 array 中排名第 5;
- 11 在 array 中排名第 1;
所以需要实现一个功能,将 arr 传进去之后,返回 [4 2 3 6 5 1]。如果是你的话,你会怎么做呢?最容易想到的办法就是遍历:
lst = [44, 22, 33, 66, 55, 11] # 先排序 sorted_lst = sorted(lst) # 依次查看 lst 里面的每个元素,在排序后列表中的位置 # 但索引从 0 开始,因此需要再加上 1 rank = [sorted_lst.index(item) + 1 for item in lst] print(rank) # [4, 2, 3, 6, 5, 1]
但是这段代码有缺陷,那就是 lst 里面的元素不能重复,这里我们假设不重复。如果用 argsort 该怎么做呢?
import numpy as np arr = np.array([44, 22, 33, 66, 55, 11]) print( np.argsort(arr) ) # [5 1 2 0 4 3] print( np.argsort(np.argsort(arr)) ) # [3 1 2 5 4 0] print( np.argsort(np.argsort(arr)) + 1 ) # [4 2 3 6 5 1] # np.argsort(arr) 等价于 arr.argsort()
使用 Numpy 的话两次 argsort 即可完成,这个过程可能有点绕,为了方便理解,我们画一张图:
可以对照着图多理解一下。
np.argpartition
介绍 argpartiiton 之前,先来说一下 partition,它和 np.sort 类似,都是用于排序,但 partiton 是局部排序。
import numpy as np arr = np.array([66, 15, 27, 33, 19, 13, 10]) """ np.partition(arr, k) 返回一个新数组 并且索引为 k 的元素比它左边的元素都大,比它右边的元素都小 """ print( np.partition(arr, 3) ) # [15 13 10 19 27 33 66] # 返回的数组中索引为 3 的元素是 19 # 比它左边的元素都大,比它右边的元素都小 # 至于左右两边的顺序则没有要求
当我们想实现 topK 的时候,这个方法非常适合。
虽然我们可以使用 np.sort 排序之后再通过切片截取, 但 sort 是全局排序。如果数组非常大, 我们只希望选择最小的 10 个元素,那么全局排序就有点浪费了。而最好的方式是通过 np.partition(arr, 9),然后将前 10 个元素截取出来进行排序即可,而无需对整个大数组进行排序。
结果是一样的,但是两者的效率如何呢?我们来对比一下:
我们看到效率差的不是一点点,原因就在于 partition 是一个时间复杂度为 O(N) 算法。因此在处理 topK 问题时,如果 K 远小于长度 N,那么使用 partition 是最佳选择。
另外,partition 的第二个参数还可以为负数:
import numpy as np arr = np.array([66, 15, 27, 33, 19, 13, 10]) # 让数组中索引为 -2 的元素比它左边的元素都大, # 比它右边的元素都小,显然这是用来筛选前 K 个最大的元素 print( np.partition(arr, -2) ) # [15 13 10 19 27 33 66] # 比如筛选前 3 个最大的元素 print( np.partition(arr, -3)[-3:] ) # [27 33 66]
partition 函数的用法我们就知道了,而 argpartition 也和 argsort 类似,只不过返回的是对应的索引。可以自己编写代码测试一下。
以上就是 arg 开头的几个函数的用法。