查找数组中第 k 大的元素:快速选择算法

字数: 550

快速选择算法基于快速排序的分治策略,通过递归分治的方式降低时间复杂度。

题目

力扣 :215. 数组中的第K个最大元素

题解

考虑先排序后查表的方式,但是这样过于低效。如果我们将数组取一个随机值以此按大小分成三部分可以发现:
如果大于的部分小于 k 的情况下大于的部分和等于的部分 >= k,此时该随机值就是目标。

对此,可以先分门别类:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
    std::vector<int> left;
    std::vector<int> mid;
    std::vector<int> right;

    int pivot = list[rand() % list.size()];

    for (auto iter : list)
    {
        if (iter > pivot)
            left.emplace_back(iter);
        else if (iter == pivot)
            mid.emplace_back(iter);
        else
            right.emplace_back(iter);
    }

再进行分治:

1
2
3
4
5
6
    if (left.size() >= k)
        return findKthLargest(left, k);
    else if (left.size() + mid.size() >= k)
        return pivot;
    else
        return findKthLargest(right, k - left.size() - mid.size());

left 部分是大于 pivot 的,当该部分大于 k 则可以说我们目标值在 left 部分,还是取第 k 大的数。
而当 len(left) + len(mid) 小于 pivot ,可以得出目标值在 right 部分,此时取 right 部分但是 k 要减去 len + mid 部分的长度。

c++ 代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
int findKthLargest(std::vector<int> &list, int k)
{
    std::vector<int> left;
    std::vector<int> mid;
    std::vector<int> right;

    int pivot = list[rand() % list.size()];
    // 记得在 main 函数 srand(time(0)) 初始化随机数

    for (auto iter : list)
    {
        if (iter > pivot)
            left.emplace_back(iter);
        else if (iter == pivot)
            mid.emplace_back(iter);
        else
            right.emplace_back(iter);
    }
    if (left.size() >= k)
        return findKthLargest(left, k);
    else if (left.size() + mid.size() >= k)
        return pivot;
    else
        return findKthLargest(right, k - left.size() - mid.size());
}

python 代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
import random

def findKthLargest(nums:list, k:int):
    pivot = random.choice(nums)
    
    left = [x for x in nums if x > pivot]
    mid = [x for x in nums if x == pivot]
    right = [x for x in nums if x < pivot]

    if k <= len(left):
        return findKthLargest(left, k)
    elif k <= len(left) + len(mid):
        return pivot
    else:
        return findKthLargest(right, k - len(left) - len(mid))

arr = [4, 5, 9, 12, 9, 22, 45, 7]

print(findKthLargest(arr, 4))  # 9

参考资料

  1. 数组中的第k个最大元素 | 编程面试必刷题