Numpy 的一些以 arg 开头的函数

简介: Numpy 的一些以 arg 开头的函数


楔子



Numpy 里面有一些以 arg 开头的函数,非常的有意思,因为它们返回的不是元素、而是元素的索引,下面来看一下用法。


np.argmax



我们知道 np.max 是获取最大的元素,那么np.argmax是做什么的呢?

import numpy as np
arr = np.array([3, 22, 4, 11, 2, 44, 9])
print(np.max(arr))  # 44
print(np.argmax(arr))  # 5

结论很清晰,np.argmax 是获取最大元素所在的索引。

同理还有 np.argmin,np.min 是获取数组中最小的元素,显然是2;np.argmin 是获取数组中最小元素对应的索引,显然是 4。

import numpy as np
arr = np.array([3, 22, 4, 11, 2, 44, 9])
print(np.min(arr))  # 2
print(np.argmin(arr))  # 4

那么问题来了,如果我们想自己实现一个 argmax 和 argmin 该怎么做呢?

def argmax(lst):
    lst = [
        # 将每个元素和对应的索引组合起来
        (item, idx) for idx, item in enumerate(lst)
    ]
    return max(lst)[1]
print(
    argmax([3, 22, 4, 11, 2, 44, 9])
)  # 5

还是很简单的,如果实现 argmin 的话,只需将里面的 max 函数换成 min 即可。


np.argwhere



np.where 还是用的挺频繁的,先来复习一下它的用法吧。

import numpy as np
arr = np.array(
    [1, 2, 3, 4, 5, 6, 7, 8])
# 如果元素大于 4,那么减去 10,否则乘以 10
print(
    np.where(arr > 4, arr - 10, arr * 10)
)  # [10 20 30 40 -5 -4 -3 -2]
# 如果元素大于4, 那么保持不变, 否则变成 4
print(
    np.where(arr > 4, arr, 4)
)  # [4 4 4 4 5 6 7 8]

和 np.where 作用类似的还有一个 np.clip,来看一下。

import numpy as np
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])
# 小于 2 的换成 2, 大于 6 的换成 6
# 一般在设置上下限的时候非常有用
print(
    np.clip(arr, 2, 6)
)  # [2 2 3 4 5 6 6 6]

那么 np.argwhere 是做啥的呢?首先这个函数只接受一个参数,找出满足条件的元素对应的索引。

import numpy as np
arr = np.array([3, 4, 5, 6, 7])
# 显然 3、5、7 对 2 取模不等于 0
# 它们对应的索引分别是 0、2、4
print(np.argwhere(arr % 2 != 0))
"""
[[0]
 [2]
 [4]]
"""
# 需要扁平化
print(
    np.argwhere(arr % 2 != 0).flatten()
)  # [0 2 4]


显然元素 3、5、7 在 %2 之后不为 0,所以会筛选出它们的索引,因此是 [0 2 4]。只不过默认不是一个一维数组,我们需要再调用一下 flatten,将其扁平化。


np.argsort



np.sort 是用来排序的,类似于内置函数 sorted。

import numpy as np
arr = np.array([44, 22, 33, 66, 55, 11])
print(
    np.sort(arr)
)  # [11 22 33 44 55 66]

sort 很容易,再来看看 argsort。

import numpy as np
arr = np.array([44, 22, 33, 66, 55, 11])
print(
    np.sort(arr)
)  # [11 22 33 44 55 66] 
print(
    np.argsort(arr)
)  # [5 1 2 0 4 3]

所以 argsort 返回的是元素排序后对应的索引,光说可能不好理解,我们画一张图就清晰了。

因此 sort 是对元素进行排序,而 argsort 是按照元素的大小对索引进行排序。比如我们想选择数组中最大的 3 个元素,就有两种做法。

import numpy as np
arr = np.array([44, 22, 33, 66, 55, 11])
# 选择前三个最大元素,可以先全局排序
# 然后选择前三个
print(
    np.sort(arr)[-3:]
)  # [44 55 66]
# 使用 argsort,找到最大的三个元素的索引
# 然后通过 arr 进行筛选
print(
    arr[np.argsort(arr)[-3:]]
)  # [44 55 66]

另外 argsort 还可以实现一个特殊的需求,那就是排名统计,举个例子:[44, 22, 33, 66, 55, 11],如果按照从小到大排名,

  • 44 在 array 中排名第 4;
  • 22 在 array 中排名第 2;
  • 33 在 array 中排名第 3;
  • 66 在 array 中排名第 6;
  • 55 在 array 中排名第 5;
  • 11 在 array 中排名第 1;

所以需要实现一个功能,将 arr 传进去之后,返回 [4 2 3 6 5 1]。如果是你的话,你会怎么做呢?最容易想到的办法就是遍历:

lst = [44, 22, 33, 66, 55, 11]
# 先排序
sorted_lst = sorted(lst)
# 依次查看 lst 里面的每个元素,在排序后列表中的位置
# 但索引从 0 开始,因此需要再加上 1
rank = [sorted_lst.index(item) + 1
        for item in lst]
print(rank)  # [4, 2, 3, 6, 5, 1]

但是这段代码有缺陷,那就是 lst 里面的元素不能重复,这里我们假设不重复。如果用 argsort 该怎么做呢?

import numpy as np
arr = np.array([44, 22, 33, 66, 55, 11])
print(
    np.argsort(arr)
)  # [5 1 2 0 4 3]
print(
    np.argsort(np.argsort(arr))
)  # [3 1 2 5 4 0]
print(
    np.argsort(np.argsort(arr)) + 1
)  # [4 2 3 6 5 1]
# np.argsort(arr) 等价于 arr.argsort()

使用 Numpy 的话两次 argsort 即可完成,这个过程可能有点绕,为了方便理解,我们画一张图:

可以对照着图多理解一下。


np.argpartition



介绍 argpartiiton 之前,先来说一下 partition,它和 np.sort 类似,都是用于排序,但 partiton 是局部排序。

import numpy as np
arr = np.array([66, 15, 27, 33, 19, 13, 10])
"""
np.partition(arr, k) 返回一个新数组
并且索引为 k 的元素比它左边的元素都大,比它右边的元素都小
"""
print(
    np.partition(arr, 3)
)  # [15 13 10 19 27 33 66]
# 返回的数组中索引为 3 的元素是 19
# 比它左边的元素都大,比它右边的元素都小
# 至于左右两边的顺序则没有要求

当我们想实现 topK 的时候,这个方法非常适合。

虽然我们可以使用 np.sort 排序之后再通过切片截取, 但 sort 是全局排序。如果数组非常大, 我们只希望选择最小的 10 个元素,那么全局排序就有点浪费了。而最好的方式是通过 np.partition(arr, 9),然后将前 10 个元素截取出来进行排序即可,而无需对整个大数组进行排序。

结果是一样的,但是两者的效率如何呢?我们来对比一下:

我们看到效率差的不是一点点,原因就在于 partition 是一个时间复杂度为 O(N) 算法。因此在处理 topK 问题时,如果 K 远小于长度 N,那么使用 partition 是最佳选择。

另外,partition 的第二个参数还可以为负数:

import numpy as np
arr = np.array([66, 15, 27, 33, 19, 13, 10])
# 让数组中索引为 -2 的元素比它左边的元素都大,
# 比它右边的元素都小,显然这是用来筛选前 K 个最大的元素
print(
    np.partition(arr, -2)
)  # [15 13 10 19 27 33 66]
# 比如筛选前 3 个最大的元素
print(
    np.partition(arr, -3)[-3:]
)  # [27 33 66]

partition 函数的用法我们就知道了,而 argpartition 也和 argsort 类似,只不过返回的是对应的索引。可以自己编写代码测试一下。

以上就是 arg 开头的几个函数的用法。

相关文章
|
1月前
|
Python
NumPy 教程 之 NumPy 统计函数 9
NumPy提供了多种统计函数,如计算数组中的最小值、最大值、百分位数、标准差及方差等。其中,标准差是一种衡量数据平均值分散程度的指标,它是方差的算术平方根。例如,对于数组[1,2,3,4],其标准差可通过计算各值与均值2.5的差的平方的平均数的平方根得出,结果为1.1180339887498949。示例代码如下: ```python import numpy as np print(np.std([1,2,3,4])) ``` 运行输出即为:1.1180339887498949。
112 50
|
2月前
|
Python
NumPy 教程 之 NumPy 算术函数 1
本教程介绍NumPy中的基本算术函数,如加(add())、减(subtract())、乘(multiply())及除(divide())。示例展示了两个数组(一个3x3矩阵与一数组[10,10,10])间的运算。值得注意的是,参与运算的数组需有相同形状或可按照NumPy的广播规则进行扩展。此外Numpy还提供了许多其他的算术函数以满足复杂计算需求。
37 7
|
2月前
|
Python
NumPy 教程 之 NumPy 算术函数 2
NumPy 教程 之 NumPy 算术函数 2
30 3
|
2月前
|
Python
NumPy 教程 之 NumPy 数学函数 4
NumPy提供了丰富的数学函数,如三角函数、算术函数及复数处理等。本教程聚焦于舍入函数中的`numpy.ceil()`应用。该函数用于返回大于或等于输入值的最小整数(向上取整)。例如,对数组`[-1.7, 1.5, -0.2, 0.6, 10]`使用`np.ceil()`后,输出为`[-1., 2., -0., 1., 10.]`。
33 1
|
2月前
|
Python
NumPy 教程 之 NumPy 数学函数 3
本教程详细介绍了NumPy中的数学函数,特别是舍入函数`numpy.floor()`的使用方法。该函数可以返回小于或等于输入的最大整数,实现向下取整的功能。例如,对于数组`a = np.array([-1.7, 1.5, -0.2, 0.6, 10])`,应用`np.floor(a)`后,输出结果为`[-2., 1., -1., 0., 10.]`。这在处理包含浮点数的数据时非常有用。
28 0
|
1月前
|
Python
NumPy 教程 之 NumPy 统计函数 10
NumPy统计函数,包括查找数组中的最小值、最大值、百分位数、标准差和方差等。方差表示样本值与平均值之差的平方的平均数,而标准差则是方差的平方根。例如,`np.var([1,2,3,4])` 的方差为 1.25。
95 48
|
27天前
|
机器学习/深度学习 搜索推荐 算法
NumPy 教程 之 NumPy 排序、条件筛选函数 8
NumPy提供了多种排序方法,包括快速排序、归并排序及堆排序,各有不同的速度、最坏情况性能、工作空间和稳定性特点。此外,NumPy还提供了`numpy.extract()`函数,可以根据特定条件从数组中抽取元素。例如,在一个3x3数组中,通过定义条件选择偶数元素,并使用该函数提取这些元素。示例输出为:[0., 2., 4., 6., 8.]。
21 8
|
1月前
|
机器学习/深度学习 搜索推荐 算法
NumPy 教程 之 NumPy 排序、条件筛选函数 2
介绍NumPy` 中的排序方法与条件筛选函数。通过对比快速排序、归并排序及堆排序的速度、最坏情况性能、工作空间需求和稳定性,帮助读者选择合适的排序算法。此外,还深入讲解了 `numpy.argsort()` 的使用方法,并通过具体实例展示了如何利用该函数获取数组值从小到大的索引值,并据此重构原数组,使得其变为有序状态。对于学习 `NumPy` 排序功能来说,本教程提供了清晰且实用的指导。
22 7
|
28天前
|
机器学习/深度学习 搜索推荐 算法
NumPy 教程 之 NumPy 排序、条件筛选函数 5
NumPy中的排序方法及特性对比,包括快速排序、归并排序与堆排序的速度、最坏情况性能、工作空间及稳定性分析。并通过`numpy.argmax()`与`numpy.argmin()`函数演示了如何获取数组中最大值和最小值的索引,涵盖不同轴方向的操作,并提供了具体实例与输出结果,便于理解与实践。
19 4
|
2月前
|
Python
NumPy 教程 之 NumPy 统计函数 4
这段内容介绍了NumPy库中的统计函数,特别是`numpy.percentile()`函数的应用。该函数用于计算数组中的百分位数,即一个值之下所包含的观测值的百分比。通过实例展示了如何使用此函数来计算不同轴上的百分位数,并保持输出的维度不变。
30 5