通过 TOP K 问题,探究堆和优先队列的秘密

简介: 通过 TOP K 问题,探究堆和优先队列的秘密


楔子




在工作中我们经常会遇到这样一个需求,就是获取可迭代对象中的前 K 个最大或最小的元素。我们之前介绍过排序,所以一个最简单的办法就是先排序,排完了再选择前 K 个元素即可。

data = [3, 1, 2, 5, 4]
# 选择前 3 个最大的元素
# 和前 3 个最小的元素
sorted_data = sorted(data)
print(f"TOP 3 MAX:{sorted_data[-3:]}")
print(f"TOP 3 MIN:{sorted_data[: 3]}")
"""
TOP 3 MAX:[3, 4, 5]
TOP 3 MIN:[1, 2, 3]
"""
# 或者这么做
sorted_data = sorted(data, reverse=True)
print(f"TOP 3 MAX:{sorted_data[: 3]}")
print(f"TOP 3 MIN:{sorted_data[-3:]}")
"""
TOP 3 MAX:[5, 4, 3]
TOP 3 MIN:[3, 2, 1]
"""

显然这是一种解决办法,但如果列表的长度非常大,排序就会带来不小的开销。而且有时我们只要前几个元素即可,比如长度为 10000 的列表,我们只想要前三个最大或最小的元素,那么此时对整个列表进行排序显然会存在性能上的浪费。

sorted 函数的时间复杂度是 O(NlogN)

所以接下来我们要介绍一个模块叫 heapq,通过该模块我们能快速地获取前 K 个元素。

import random
import heapq
data = [random.randint(10, 10000)
        for _ in range(1000)]
# 获取前 3 个最大的元素
print(heapq.nlargest(3, data))
print(sorted(data, reverse=True)[: 3])
"""
[9972, 9972, 9966] 
[9972, 9972, 9966] 
"""
# 获取前 3 个最小的元素
print(heapq.nsmallest(3, data))
print(sorted(data)[: 3])
"""
[17, 21, 31] 
[17, 21, 31]
"""

得到的结果是一样的,但是性能差异如何呢?我们来测一下:


aac012e2e6b8fab93b81bc86c19d273f.png

可以看到性能差异还是蛮大的,并且列表长度越大,性能差距越明显。而根本原因就在于 sorted 会对列表进行全局排序,而 heapq 没有。


因此在获取前 K 个元素、并且 K 和列表长度差距比较大的时候,不妨使用 heapq 的 nsmallest 和 nlargest 函数,性能会有明显提升。

但如果 K 和列表长度相差不大,那么先 sorted 排序,再使用切片的方式会更好一些。

nsmallest 和 nlargest 这两个函数都接收 3 个参数,第一个参数表示要获取前多少个元素、第二个参数表示可迭代对象(一般是列表)、第三个参数是 key(和 sorted 函数里面的 key 含义相同)。

import random
import heapq
data = [{"number": random.randint(1, 10000)}
        for _ in range(1000)]
# data 内部都是字典
# 获取前 3 个 number 字段的值最大的字典
print(heapq.nlargest(3, data, key=lambda x: x["number"]))
"""
[{'number': 9991}, {'number': 9984}, {'number': 9970}]
"""
print(sorted(data, key=lambda x: x["number"], reverse=True)[: 3])
"""
[{'number': 9991}, {'number': 9984}, {'number': 9970}]
"""

特别提示,如果 K 为 1,那么使用内置函数 min 和 max 是最佳选择。

data = [{"number": random.randint(1, 10000)}
        for _ in range(1000)]
print(
    heapq.nlargest(1, data, key=lambda x: x["number"])
)  # [{'number': 9979}]
print(
    sorted(data, key=lambda x: x["number"])[-1]
)  # {'number': 9979}
print(
   max(data, key=lambda x: x["number"])
)  # {'number': 9979}

所以结论如下,当获取最大或最小元素的个数为 K,列表(可迭代对象)长度为 L 时:


  • K 等于 1,使用内置函数 max 或 min;
  • K 不等于 1、且远小于 L,使用 heapq 模块的 nlargest 或 nsmallest 函数;
  • K 和 L 差别不大,使用 sorted 先全局排序、然后再通过切片方式截取;


当然啦,以上都属于基础知识,比较简单。其实选择前 K 个元素就是我们常说的 TOP K 问题,如果只是单纯地想解决 TOP K 问题的话,上面已经给出了方案。这里我主要是想通过 TOP K 来引出一种数据结构,也就是


堆是一种非常高效的数据结构,我们可以用它实现优先队列,堆实现的优先队列在元素入队、出队的时间复杂度上均为 O(logN)。





什么是堆?




首先堆本身是一树,如果这树是二叉树,那么实现的堆就被称为二叉堆。当然除了二叉堆,还有三叉堆等等,只不过二叉堆是一种最主流的堆的实现方式。因此,堆(二叉堆)就是一满足一些特殊性质的二叉树,那么问题来了,它都满足哪些性质呢?


  • 堆是一棵完全二叉树;
  • 堆里的每一个节点都大于等于(或小于等于)它的孩子节点;
  • 如果每个节点都大于等于它的孩子节点,或者说每个节点都不大于它的父节点,那么这个堆就是大根堆;
  • 如果每个节点都小于等于它的孩子节点,或者说每个节点都不小于它的父节点,那么这个堆就是小根堆;

52982ce29b5143f7c1add66a85bbc6bf.png


注意:堆要求的是每个节点和其孩子节点之间要满足相应的大小关系,如果两个节点之间没有父子关系,那么它们谁大谁小无关紧要。比如图上的大根堆,第三层的最后一个节点是 13,可第四层的节点却都比它大,但它们之间没有父子关系,所以我们当前这个堆是成立的。


正是因为堆的这个性质,我们可以使用数组来表示堆,直接按照层序遍历的方式将每一层的元素放在数组中即可,比如:

[62, 41, 30, 28, 16, 22, 13, 19, 17, 15]


很明显,堆顶(数组索引为 0)的元素永远是值最大或最小的元素,如果构建的是大根堆,堆顶元素最大;构建的是小根堆,堆顶元素最小。


但是问题来了,如果我有一个节点,要如何找到它的父节点或者孩子节点呢?结论如下,假设当前节点所在的索引为 i:


  • 父节点的索引:(i - 1) / 2;
  • 左孩子节点的索引:2 * i + 1;
  • 右孩子节点的索引:2 * i + 2;


我们以索引为 3 这个元素(值为 28)为例,它父节点的索引是 (3 - 1) / 2 = 1,也就是 41 这个元素;左孩子节点的索引是 2 * 3 + 1 = 7,也就是 19 这个元素;右孩子节点的索引显然是 8,也就是 17 这个元素。可以对照上图,检验一下是否有误,或者你也可以创建一个更大的堆,自己测试一下,但前提必须是完全二叉树才具备这个性质。

显然通过这种方式,我们就不需要两个指针来维持节点之间的父子关系了,使用数组索引即可,并且通过索引定位元素的速度也会更快。





向堆中添加元素(Sift Up)




我们来看看如何往堆中添加元素,首先堆是一个完全二叉树,往堆中添加一个元素,从树的层面来看,就是往最后一层的最右端添加一个元素,如果最后一层已经满了,那么就新加一层。如果从数组的层面来看,就相当于 append 一个元素。


假设我们添加一个 52,那么堆的示意图就会变成如下这样:


a266c169ce7eaad48b8ff819c418b8ff.png

添加的过程非常简单,因为往堆里面添加一个节点,就是往数组里面 append 一个元素,但显然还没有结束。因为堆有两个性质,虽然我们添加元素之后仍然满足是一棵完全二叉树,但是不满足子节点都不大于它的父节点(这里我们构建的是大根堆),因为 52 明显大于它的父节点 16。


所以我们还要进行调整,将新添加的节点放到属于它的位置,具体过程也很简单:将该节点和它的父节点进行比较,如果比它的父节点大,那么就进行交换;交换之后再和它新的父节点进行比较,如果还大于新的父节点则继续交换,直到不大于为止。


所以从尾部添加的节点,一直向上浮动,直到找到属于它的位置,因此这个过程也被称为 Sift Up(上浮),具体示意图如下:


20bfcf31ade75aecad1527bb83822cca.png


当交换之后,发现不大于它的父节点,那么该节点就可以停下来了。可能有人问,它父节点之上的节点该怎么办?比如爷爷节点。答案是不需要关心,因为大根堆的性质就是每个子节点不大于父节点。所以当新添加的节点不大于它的父节点时,也更不可能大于父节点之上的爷爷节点。


下面我们就编写代码实现一下:

class BinaryHeap:
    """
    大根堆
    """
    def __init__(self):
        # 通过数组来模拟堆,为避免直接修改堆
        # 这个数组不对外暴露,而是专门提供一个接口
        self.__data = []
    def show_heap(self):
        return self.__data
    @staticmethod
    def get_parent(i: int):
        # 根据节点的索引找到其父节点的索引
        return (i - 1) // 2
    def heappush(self, item: int):
        # 往堆中添加一个节点,对于数组而言,直接 append 即可
        self.__data.append(item)
        # 但是还没有结束,添加完之后不满足堆的性质
        # 我们还要对堆进行调整,由 sift_up 函数负责,它接收一个索引
        # 表示对指定索引的节点进行上浮,显然这里是最后一个
        self.sift_up(len(self.__data) - 1)
    def sift_up(self, i: int):
        # 对指定索引位置的节点进行上浮
        while i > 0:
            parent = self.get_parent(i)
            # 当该元素不是根节点的时候,将其和父节点进行比较
            # 如果大于父节点,两者进行交换
            if self.__data[i] > self.__data[parent]:
                self.__data[i], self.__data[parent] = self.__data[parent], self.__data[i]
                # 交换之后该节点成为了父节点,然后将 parent 赋值为 i
                # 因为它还要继续作为新的子节点和新的父节点比较
                i = parent
            else:
                # 如果不大于父节点,说明该元素已经找到属于它的位置了
                # 直接将循环结束掉即可
                break
heap = BinaryHeap()
for item in [62, 41, 30, 28, 16, 22, 13, 19, 17, 15]:
    heap.heappush(item)
print(heap.show_heap())
"""
[62, 41, 30, 28, 16, 22, 13, 19, 17, 15]
                        62
            41                       30
      28          16           22          13
  19      17  15
"""
# 这个时候再添加一个元素 52
heap.heappush(52)
print(heap.show_heap())
"""
[62, 52, 30, 28, 41, 22, 13, 19, 17, 15, 16]
                        62
            52                       30
      28          41           22          13
  19      17  15      16
"""

可以看到结果是没有问题的,以上我们添加元素就成功了,下面我们再来看看如何从堆中取出元素。





从堆中取出元素(Sift Down)



正如添加节点从堆底添加,取出节点只能从堆顶取出(也就是只能取根节点),不能取其它位置的节点。


d533bdba1a2d89d24f9cd3843d97875e.png


但问题是,如果直接将堆顶的节点取走的话,就会形成两个独立的堆,两个堆的根节点分别是它的左右节点。于是我们还要手动将两个堆合并在一起,会比较麻烦,所以我们可以换个思路,将堆顶和堆底的元素进行交换。交换之后,弹出堆底的元素,这样就得到了最大值。


但该做法同时也破坏了堆的第二个性质,因为之前的堆底元素现在跑到了堆顶,肯定不满足父节点和子节点之间的大小关系,所以我们还要进行调整。


对于大根堆而言,将该节点和左右子节点中大的那一个进行比较,如果比子节点小,那么进行交换。交换之后再和它新的子节点进行比较,如果还小于新的子节点则继续交换,直到不小于为止。


所以从顶部的节点,一直向下沉,直到找到属于它的位置,因此这个过程也被称为 Sift Down(下沉),具体示意图如下:


8c4359fb16764e70c1102805ac8fc00e.png


注意:堆顶节点和堆底节点交换之后,就被弹出了,所以图中的 62 不再是堆节点,因此我们颜色刻意画的淡了一些。


下面完善一下之前的代码:

import random
class BinaryHeap:
    """
    大根堆
    """
    def __init__(self):
        # 通过数组来模拟堆,为避免直接修改堆
        # 这个数组不对外暴露,而是专门提供一个接口
        self.__data = []
    def show_heap(self):
        return self.__data
    @staticmethod
    def get_parent(i: int):
        # 根据节点的索引找到其父节点的索引
        return (i - 1) // 2
    @staticmethod
    def get_left_child(i: int):
        # 根据节点的索引找到左孩子节点的索引
        return 2 * i + 1
    @staticmethod
    def get_right_child(i: int):
        # 根据节点的索引找到右孩子节点的索引
        return 2 * i + 2
    def heappush(self, item: int):
        self.__data.append(item)
        self.sift_up(len(self.__data) - 1)
    def sift_up(self, i: int):
        while i > 0:
            parent = self.get_parent(i)
            if self.__data[i] > self.__data[parent]:
                self.__data[i], self.__data[parent] = self.__data[parent], self.__data[i]
                i = parent
            else:
                break
    def heappop(self):
        # 弹出堆顶元素
        if len(self.__data) == 0:
            raise ValueError("heap is empty")
        # 只需要将第一个元素和最后一个元素进行交换,然后返回即可
        self.__data[0], self.__data[-1] = self.__data[-1], self.__data[0]
        result = self.__data.pop()
        # 不过在返回之前,记得调整一下堆,由 sift_down 函数负责
        # 此函数接收一个索引,表示对指定节点的索引进行下沉
        # 显然这里是第一个
        self.sift_down(0)
        return result
    def sift_down(self, i: int):
        # 对索引为 i 的节点进行下沉,这里需要判断孩子节点是否存在的情况
        # 如果左孩子节点的索引越界,说明该节点已经是叶子节点了
        while self.get_left_child(i) < len(self.__data):
            left_child = self.get_left_child(i)
            right_child = self.get_right_child(i)
            # 获取子节点大的那一个,注意:需要考虑右节点是否存在的情况
            child = (right_child
                     if right_child < len(self.__data) and
                        self.__data[left_child] < self.__data[right_child]
                     else left_child)
            # 将该节点和孩子节点进行比较,如果比孩子节点小,那么交换位置
            # 继续和新的孩子节点进行比较
            if self.__data[i] < self.__data[child]:
                self.__data[i], self.__data[child] = self.__data[child], self.__data[i]
                i = child
            # 否则直接跳出循环
            else:
                break
heap = BinaryHeap()
data = [random.randint(1, 20) for _ in range(10)]
print(data)
"""
[3, 9, 3, 12, 12, 14, 5, 18, 20, 11]
"""
# 依次添加到堆中
for item in data:
    heap.heappush(item)
# 从堆中弹出,由于每次都会弹出最大值
# 所以得到的新列表是降序排序的
sorted_data = [heap.heappop() for _ in range(10)]
print(sorted_data)
"""
[20, 18, 14, 12, 12, 11, 9, 5, 3, 3]
"""

显然是没有问题的,因此我们这里就实现了一个堆排序,只不过这个堆排序还不太完美,不完美之处有两个地方:


  • 1. 默认是从大到小排序的,应该提供一个参数供外界选择究竟是从大到小还是从小到大;
  • 2. 这里开辟了一个额外的数组,合适的做法应该是接收一个数组,然后原地排序;


那么下面我们完善一下堆排序。

import random
def get_left_child(i: int):
    return 2 * i + 1
def get_right_child(i: int):
    return 2 * i + 2
def sift_down_large(data, i: int, length: int):
    # 大根堆下沉,但是参数多了一个 length,这是为啥呢?
    # 首先我们之前是将堆顶和堆底的元素交换之后,就将堆底的元素弹出去了
    # 以至于我们需要单独开辟一个数组去接收
    # 但很明显,我们这里要求原地排序,那么交换之后的元素在堆底不可以动
    # 因此每 sift_down 一次,length 要减去 1
    while get_left_child(i) < length:
        left_child = get_left_child(i)
        right_child = get_right_child(i)
        # 判断是否有右孩子,如果有右孩子
        # 那么选择值较大的那一个孩子节点
        child = (right_child
                 if right_child < length and
                    data[left_child] < data[right_child]
                 else left_child)
        # 如果比孩子节点的值小,那么两者进行交换
        # 因为大根堆要求父节点不小于子节点
        if data[i] < data[child]:
            data[i], data[child] = data[child], data[i]
            i = child
        else:
            break
def sift_down_small(data, i: int, length: int):
    # 小根堆下沉
    while get_left_child(i) < length:
        left_child = get_left_child(i)
        right_child = get_right_child(i)
        # 判断是否有右孩子,如果有右孩子
        # 那么选择值较小的那一个孩子节点
        child = (right_child
                 if right_child < length and
                    data[left_child] > data[right_child]
                 else left_child)
        # 如果比孩子节点的值大,那么两者进行交换
        # 因为大根堆要求父节点不大于子节点
        if data[i] > data[child]:
            data[i], data[child] = data[child], data[i]
            i = child
        else:
            break
def heapify_large(data):
    # 将一个数组整理成大根堆的形状
    # 从最后一个非叶子节点进行 sift_down 即可
    for i in range((len(data) - 1) >> 1, -1, -1):
        sift_down_large(data, i, len(data))
def heapify_small(data):
    # 将一个数组整理成小根堆的形状
    for i in range((len(data) - 1) >> 1, -1, -1):
        sift_down_small(data, i, len(data))
def heap_sort(data, reverse=False):
    # 堆排序
    # 首先将其整理成堆的形状
    if reverse:
        heapify_small(data)
    else:
        heapify_large(data)
    # i 从最后一个元素开始
    for i in range(len(data) - 1, -1, -1):
        # 交换完之后的元素就不可以动了
        data[0], data[i] = data[i], data[0]
        # 并且也不能再参与后续的 sift_down
        # 因此依旧调整堆,但是范围变了
        # 比如第一次交换,那么最后一个元素为最大值
        # sift_down 的时候,整个范围就是 [0: len(data) - 1]
        # 同理第二次 sift_down 的时候,范围就是 [0: len(data) - 2]
        if reverse:
            sift_down_small(data, 0, i)
        else:
            sift_down_large(data, 0, i)
data = [random.randint(1, 20) for _ in range(10)]
print(data)
"""
[17, 16, 10, 3, 13, 15, 11, 9, 12, 9]
"""
heap_sort(data)
print(data)
"""
[3, 9, 9, 10, 11, 12, 13, 15, 16, 17]
"""
data = [random.randint(1, 20) for _ in range(10)]
print(data)
"""
[3, 1, 14, 20, 1, 10, 7, 8, 3, 15]
"""
heap_sort(data, reverse=True)
print(data)
"""
[20, 15, 14, 10, 8, 7, 3, 3, 1, 1]
"""

以上我们就实现了堆排序,那么问题来了,你觉得 heapq 模块里的 nlargest 和 nsmallest 是怎么实现的呢?


假设我们要选取 k 个最小的元素,那么首先我们可以从数组中截取前 k 个元素,构建一个大根堆。然后从第 k + 1 个元素开始遍历数组,如果当前元素大于等于堆顶元素,那么它肯定就不是前 k 小的元素,如果当前元素小于堆顶的元素,那么两者进行交换,然后进行一次 Sift Down 操作。当数组遍历完毕之后,堆中的 k 个元素就是最小的前 k 个元素。同理,如果想选择前 k 个最大的元素,那么就构建一个小根堆。

或者将整个数组构建成一个堆,然后heappop k 次即可,这样也能选择前 k 个元素。




优先队列




其实在排序的时候,堆排序不是效率最高的排序,它比三路快排要慢一些。但是堆存在的目的绝不仅仅是为了排序,由于它可以动态添加元素、删除元素,并且时间复杂度都为 O(logN) 级别,所以堆的强大之处就在于非常适合实现优先队列。


事实上 heapq 也已经为我们提供了堆的相关操作:

"""
heapq.heapify(data)
    将数组 data 整理成堆的形状,只支持小根堆
heapq.heappush(data, item)
    向堆中添加元素,并维护堆的形状
    要求 data 已经是一个小根堆
    
heapq.heappop(data, item)
    从堆中弹出元素,并维护堆的形状
    要求 data 已经是一个小根堆     
"""

而 Python 的优先队列,底层就是借助于 heapq 实现的,我们看一下:


bc762daeec276426e5eaddacc768935a.png


里面的 item 是一个元组,第一个元素是优先级(值越小、优先级越高),第二个元素是具体的数据,这就是优先队列,是不是比你想象中的要简单许多呢?




小结




堆是一种非常高效的数据结构,它可以动态地添加、删除元素,并且时间复杂度均为 O(logN) 级别。这个特性就决定了它非常适合实现优先队列,维护一个堆,在往堆中添加元素的时候,只需要加一个优先级即可,也就是将优先级和数据组合成一个元组添加到堆中。如果构建的是小根堆,值越小、优先级越高;构建的是大根堆,值越大,优先级越高。


另外我们说,当获取最大值或最小值时,推荐使用内置函数 max 和 min。但如果数组 data 一直在动态变化,并且要随时获取里面的最大值或最小值,那么相比使用内置函数 max、min,更好的做法是将 data 维护成一个堆。然后添加元素使用 heappush,获取元素直接 data[0] 即可。因为这整体是一个 O(logN) 的操作,而是 min、max 是一个 O(N) 的操作。


最后,堆可以用来实现排序,效率也很高,但相比三路快排还差了那么一点。但堆存在的目的不在于排序,而在于它的动态性。优先队列就不必说了,还有 TOP K,三路快排和堆都可以实现 TOP K,但前者要求数据必须一次性全部给出,而堆则没有这个要求,换句话说堆可以满足对流式数据的处理。


比如 1T 的文件,一行就是一串数字,如果想在 16G 内存的机器上查找最大的 100 个数字,用快排是无法实现的,因为无法将文件一次性加载到内存中。


但堆可以实现这个需求,先读取 100 行维护一个小根堆,然后从 101 行继续读取,依次和堆顶进行比较。如果小于堆顶元素,那么它一定不是前 100 个最大的数字;如果大于堆顶元素,那么就替换掉,然后 sift_down,维护堆的形状。这样总有一刻,能够选出最大的 100 个数字。


所以当数组 data 不断地变化时,将其维护成一个堆,然后通过 heappush 添加元素、heappop 弹出堆顶元素、data[0] 获取堆顶元素,往往是最佳选择。并且添加和弹出都是 logN 级别的时间复杂度,也正是这个特性,它适合优先队列以及流式数据(数据无法一次性全部给出)的处理。

相关文章
|
5月前
|
存储 算法
算法编程(九):最小栈
算法编程(九):最小栈
32 0
|
12月前
堆的经典TOP K问题
堆的经典TOP K问题
|
5月前
|
存储 机器学习/深度学习 算法
数据结构与算法⑬(第四章_中_续二)堆解决Topk问题+堆的概念选择题
数据结构与算法⑬(第四章_中_续二)堆解决Topk问题+堆的概念选择题
46 3
|
5月前
|
存储 算法 C++
数据结构第十一弹---堆
数据结构第十一弹---堆
|
5月前
|
算法 搜索推荐
数据结构第十二弹---堆的应用
数据结构第十二弹---堆的应用
|
存储
【数据结构】优先级队列(堆)重点知识汇总(附有代码)
【数据结构】优先级队列(堆)重点知识汇总(附有代码)
109 0
|
存储 算法
【数据结构】--- 博主拍了拍你并向你扔了一“堆”二叉树(堆的概念+结构+代码实现)
数据结构学习第十三弹——二叉树——堆的概念+结构+代码实现
|
算法 Java 编译器
Java数据结构与算法分析(四)栈
栈又叫堆栈,是一种运算受限制的线性表,限定只能在一端进行插入和删除操作,该端称为栈顶(Top),相对的另一端叫栈底(Bottom)。
89 0
|
存储
什么是堆,如何实现?(附堆排序,TOP-K问题)
什么是堆,如何实现?(附堆排序,TOP-K问题)
114 0
什么是堆,如何实现?(附堆排序,TOP-K问题)
|
算法 JavaScript 前端开发
日拱算法:用两个栈实现队列&包含min函数的栈
本篇带来【剑指offer】的两道初级算法题:冲~~