Quick Sort and Quick Select

Quick sort and quick select are both built on the partition algorithm. Partitioning a list produces two sublists ($\mathcal{A},\mathcal{B}$) and one pivot value ($\gamma$). Every element in $\mathcal{A}$ is less than or equal to $\gamma$. Every element in $\mathcal{B}$ is strictly greater than $\gamma$. Note that $\mathcal{A}$ may contain elements equal to $\gamma$ but not the pivot $\gamma$ itself, and either $\mathcal{A}$ or $\mathcal{B}$ may be $\emptyset$. As a result, the partition function places $\gamma$ directly between $\mathcal{A}$ and $\mathcal{B}$, which is its final position in sorted order: $\gamma$ is at least as large as every element of $\mathcal{A}$ and smaller than every element of $\mathcal{B}$. Quick sort uses this property to sort the list. After partitioning, $\mathcal{A}$ and $\mathcal{B}$ are already ordered relative to each other: everything in $\mathcal{A}$ precedes $\gamma$, which precedes everything in $\mathcal{B}$. Therefore, quick sort simply recurses on $\mathcal{A}$ and $\mathcal{B}$. Quick select does a similar thing. Quick select is an algorithm to find the $n$-th smallest element in the list. If the rank of $\gamma$ equals the target rank, we are done. If the rank of $\gamma$ is less than the target, adjust the target rank to account for the elements discarded on the left and recurse into $\mathcal{B}$. If the rank of $\gamma$ is greater than the target, recurse into $\mathcal{A}$; the target rank is unchanged because nothing to its left was discarded. (When ranks are tracked as absolute indices into the original list, as in the code below, no adjustment is needed in either case.) It works like a partial quick sort — it sorts just enough to locate the desired element.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def partition(lst, lo, hi):
    """Partition ``lst[lo:hi]`` in place around its midpoint element.

    The element at index ``(lo + hi) // 2`` is used as the pivot. On return
    the pivot sits at the returned index ``p`` and
    ``lst[lo:p] <= lst[p] < lst[p+1:hi]`` holds. Only positions inside
    ``lst[lo:hi]`` are written to.

    Args:
        lst: Mutable sequence whose items support ``<=`` and ``>``.
        lo: Inclusive start index of the range to partition.
        hi: Exclusive end index of the range to partition.

    Returns:
        The final index of the pivot element.

    Raises:
        ValueError: If the range is empty (``hi <= lo``). Without this guard
            the final swap would touch ``lst[lo - 1]``, an element outside
            the requested range, and return an invalid pivot index.
    """
    if hi <= lo:
        raise ValueError("partition requires a non-empty range (lo < hi)")

    mid = (lo + hi) // 2
    # Park the pivot at the front so the scans below never move it.
    lst[lo], lst[mid] = lst[mid], lst[lo]
    pivot = lst[lo]

    left, right = lo + 1, hi - 1
    while True:
        # Skip items that already belong on the "<= pivot" side.
        while left <= right and lst[left] <= pivot:
            left += 1
        # Skip items that already belong on the "> pivot" side.
        while left <= right and lst[right] > pivot:
            right -= 1
        if left >= right:
            break
        # lst[left] > pivot and lst[right] <= pivot: exchange the misplaced pair.
        lst[left], lst[right] = lst[right], lst[left]
        left += 1
        right -= 1

    # ``right`` now indexes the last item that is <= pivot (or ``lo`` itself),
    # so swapping puts the pivot between the two sides.
    lst[lo], lst[right] = lst[right], lst[lo]
    return right


def qsort(lst, lo=0, hi=None):
    """Sort ``lst[lo:hi]`` in place with quick sort.

    Args:
        lst: Mutable sequence to sort; it is reordered in place.
        lo: Inclusive start index of the range to sort (default ``0``).
        hi: Exclusive end index of the range to sort; ``None`` means
            ``len(lst)``.

    Returns:
        None. The sorted result is left in ``lst``.

    Note:
        Only the smaller side of each partition is handled recursively; the
        larger side is processed by the loop. This keeps the recursion depth
        logarithmic in the range size even when the splits are maximally
        unbalanced (for example a long run of equal values, where the pivot
        always ends up last). Running time on such inputs is still quadratic.
    """
    if hi is None:
        hi = len(lst)

    while hi - lo > 1:
        mid = partition(lst, lo, hi)
        # lst[mid] is in its final position; sort the two sides around it.
        if mid - lo < hi - (mid + 1):
            qsort(lst, lo, mid)
            lo = mid + 1
        else:
            qsort(lst, mid + 1, hi)
            hi = mid


def qselect(lst, idx):
    """Return the item that would be at index ``idx`` if ``lst`` were sorted.

    The list is partially reordered in place: partitioning is repeated only
    on the side that contains ``idx`` until that position holds its final
    value.

    Args:
        lst: Mutable sequence to select from; it is reordered in place.
        idx: Zero-based rank of the wanted item (``0`` is the smallest).

    Returns:
        The item of rank ``idx``, or ``-1`` when ``idx`` is outside
        ``0 <= idx < len(lst)``. The ``-1`` sentinel is kept for
        compatibility with existing callers; note that it cannot be told
        apart from a genuine ``-1`` item.
    """
    # Reject negative ranks as well: they would otherwise index from the end
    # of a list that is not sorted there and yield an arbitrary item.
    if not 0 <= idx < len(lst):
        return -1

    # Iterative narrowing instead of recursion, so unbalanced partitions
    # (e.g. many equal values) cannot exhaust the call stack.
    lo, hi = 0, len(lst)
    while hi - lo > 1:
        mid = partition(lst, lo, hi)
        if idx < mid:
            hi = mid
        elif idx > mid:
            lo = mid + 1
        else:
            break  # the pivot landed exactly on the wanted rank
    return lst[idx]


# Test
import random

if __name__ == "__main__":
    # Randomised cross-check of qsort and qselect against the built-in sort.
    rounds, size = 1000, 1000
    for _round in range(rounds):
        sample = [random.randint(0, 100) for _ in range(size)]
        rank = random.randint(0, len(sample) - 1)

        # Sort a copy so the untouched sample can be fed to qselect below.
        ordered = sample.copy()
        qsort(ordered)

        expected = sorted(sample)
        assert expected == ordered, "qsort failed"

        picked = qselect(sample, rank)
        assert picked == ordered[rank], "qselect failed"

    print("All tests passed!")