快速排序
2022-11-12 15:50:33 # 算法 # 排序

1. 说在前面

快速排序使用了分治的基本算法思想。它的大致思路是,设置一个枢轴值作为基准,把数组中小于基准的放在左边,剩下的放在右边,然后只需递归地对小于基准的部分和大于等于基准的部分完成排序即可。

2. 经典的快速排序

算法

经典的快速排序就不再多说了,重点是它的partition过程,即选定一个枢轴以后,怎么把数组整理成小于枢轴的放左边,剩下的放右边呢?其实有很多方法可以做到这一点,这里介绍一种方法。其实我们只要把所有小于枢轴的元素挪在一起放左边,那么剩下的就自然到右边了。那么怎么挪呢?

我们设置一个指针less,它始终指向小于枢轴部分的最后一个元素,我们从左到右遍历数组,如果发现一个小于枢轴的元素,我们只需要把它和less的下一个交换即可,然后less++,如下图所示:

根据这个过程,我们可以写出经典快速排序的代码如下:

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public class QuickSort {

public static void sort(int[] arr) {
int len = 0;
if(arr != null) len = arr.length;
if(len < 2) return;
sort(arr, 0, len - 1);
}

private static void sort(int[] arr, int lo, int hi) {
if(lo >= hi) return;
int p = partition(arr, lo, hi);
sort(arr, lo, p - 1);
sort(arr, p + 1, hi);
}

private static int partition(int[] arr, int lo, int hi) {
int pivot = arr[lo], i = lo, less = lo;// 这里默认选择第一个作为枢轴元素
while(++i <= hi)
// 如果发现当前元素小于枢轴,那么当前元素和less的下一个位置的元素交换
if(arr[i] < pivot) swap(arr, i, ++less);
swap(arr, less, lo);
return less;
}

private static void swap(int[] arr, int i, int j) {
if(i != j && arr[i] != arr[j]) {
int t = arr[i];
arr[i] = arr[j];
arr[j] = t;
}
}
}

优化

这个方法有什么问题呢?其实还是有点小问题的,

  1. 我们选择枢轴的方法是默认选择第一个元素的,这样当原数组原本就是从小到大有序的时候,这样会使得我们选择的枢轴打的很偏,这样,经过我们每一次的划分以后,大量的元素都被划分到一组了,划分失去了意义,时间复杂度也会变成$O(N^2)$;

  2. 我们把小于枢轴的元素划分为一组,剩下的元素(大于等于枢轴)划分为另一组,这样当数组中有很多元素和枢轴元素相同的时候,而小于枢轴的元素没几个,这样的话,使得小于枢轴的元素的那一部分很少,因此划分又几乎失去了意义了。

    举个极端的例子,当数组中所以的元素都是一个数的时候,那么每一次划分,都会使得左半部分没有数据,这样划分就完全没有意义了。

对于第一个问题,我们可以使用三数取中法,或者随机化的方法来选择枢轴元素,这样每一次都选择一个“很偏”的枢轴的概率就会变得很小了;对于第二个问题,我们可以使用另一种partition的方法,使得等于枢轴的元素尽量分散到两组中,就是下面介绍的二路快排。

3. 二路快速排序

算法

二路快排主要是partition上的改进,我们设置两个指针less和more,分别从数组的开头和末尾位置开始遍历,只要当前元素小于枢轴,那么less++,直到某一次遍历当前元素不小于枢轴;然后more指针开始走,只要当前元素大于枢轴,more--,直到某一次遍历当前元素不大于枢轴,然后交换less和more位置的元素,然后less++,more--,周而复始,直到less大于more为止。如下图所示:

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import java.util.Random;

public class QuickSort2Ways {

public static void sort(int[] arr) {
int len = 0;
if(arr != null) len = arr.length;
if(len < 2) return;
sort(arr, 0, len - 1);
}

private static void sort(int[] arr, int lo, int hi) {
if(lo >= hi) return;
int p = partition(arr, lo, hi);
sort(arr, lo, p - 1);
sort(arr, p + 1, hi);
}

private static int partition(int[] arr, int lo, int hi) {
Random r = new Random();

// 随机取[lo,hi]中的一个位置作为枢轴元素下标
int pivotIndex = r.nextInt(hi - lo + 1) + lo;
int pivot = arr[pivotIndex];

// 先把枢轴放在lo位置,便于操作
swap(arr, lo, pivotIndex);
int less = lo + 1, more = hi;
while(less <= more) {
// 只要less不越界,且less位置的元素小于枢轴,那么less向右
while(less <= more && arr[less] < pivot) ++less;

// 只要more不越界,且more位置的元素大于枢轴,那么more向左
while(less <= more && arr[more] > pivot) --more;

// 当二指针都停下且不越界的时候,交换二者的元素,less向右,more向左
if(less <= more) swap(arr, less++, more--);
}

// 把枢轴元素放回more位置
swap(arr, lo, more);
return more;
}

private static void swap(int[] arr, int i, int j) {
if(i != j && arr[i] != arr[j]) {
int t = arr[i];
arr[i] = arr[j];
arr[j] = t;
}
}
}

优化

二路快速排序中其实是考虑到等于枢轴的元素比较多的时候,把等于枢轴的元素尽量平均的分散到左右两组中,使得左右两组长度相差不要太大。我们知道,数组排好序后,相等的元素是相邻的,这些等于枢轴的元素在排好序以后一定是和枢轴元素相邻。而一次partition以后,我们只需要对枢轴左边和右边的部分进行排序,枢轴元素不参与后续的排序了,其实枢轴元素在这次partition以后已经在了最终的位置。

那么我们在partition的时候,能否把等于枢轴的元素和枢轴紧挨放置,这样等于枢轴的所有元素包括枢轴自己,都不需要参与后续的排序过程,只需要对小于,大于枢轴的部分进行排序即可。 这对于大部分都是相等的元素的数组排序有很大的性能提升。如下图所示:

下面介绍的三路快排,就能完成这样的操作。

4. 三路快速排序

算法

其实三路快速排序也是对partition的过程进行的优化,它把枢轴分成了三部分,小于枢轴部分,等于枢轴部分和大于枢轴的部分。然后只需对小于和大于枢轴的部分进行排序即可。那么问题是这个划分过程怎么实现?这个问题也叫做荷兰国旗问题,下面给出解决这个问题的方法。

设置两个指针less和more,分别代表小于枢轴部分的右边界和大于枢轴部分的左边界,我们从左到右遍历数组,当前位置为i,less初始化为枢轴开头的前一个位置,more初始化为末尾的后一个位置,i初始化为枢轴开头下标,只要i小于more,进行下面的过程:

  • 如果i位置的元素小于枢轴,那么i位置的元素和less的下一个位置的元素交换,less++,i++;
  • 如果i位置的元素等于枢轴,那么i++;
  • 如果i位置的元素大于枢轴,那么i位置的元素和more的前一个位置交换,more--,i不动

根据less和more的定义,只需要对开头到less和more到末尾这两部分排序即可。下图展示了partition的过程:

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import java.util.Random;

public class QuickSort3Ways {

private static int less, more;

public static void sort(int[] arr) {
int len = 0;
if(arr != null) len = arr.length;
if(len < 2) return;
sort(arr, 0, len - 1);
}

private static void sort(int[] arr, int lo, int hi) {
if(lo >= hi) return;
partition(arr, lo, hi);
sort(arr, lo, less);
sort(arr, more, hi);
}

private static void partition(int[] arr, int lo, int hi) {
Random r = new Random();

// 随机取[lo,hi]中的一个位置作为枢轴元素下标
int pivotIndex = r.nextInt(hi - lo + 1) + lo;

/*
当然这里也可以使用三数取中法获取枢轴下标
int pivotIndex;
if(hi - lo + 1 == 2) {
pivotIndex = lo;
}else {
int m = (lo + hi) >>> 1;
pivotIndex = arr[lo] >= arr[m] ?
(arr[m] >= arr[hi] ? m : (arr[lo] >= arr[hi] ? hi : lo))
:
(arr[lo] >= arr[hi] ? lo : (arr[m] >= arr[hi] ? hi : m));
}
*/
int pivot = arr[pivotIndex];
less = lo - 1;
more = hi + 1;
int i = lo;
while(i < more) {// 只要当前位置小于more位置,就继续

// 当前位置的数小于pivot,当前位置的数和less的下一个位置的数交换,less++, i++
if(arr[i] < pivot) swap(arr, ++less, i++);

// 当前位置的数大于pivot,当前位置的数和more的前一个位置的数交换,more--, i不动
else if(arr[i] > pivot) swap(arr, --more, i);

// 当前位置的数等于pivot,i++
else ++i;
}
}

private static void swap(int[] arr, int i, int j) {
if(i != j && arr[i] != arr[j]) {
int t = arr[i];
arr[i] = arr[j];
arr[j] = t;
}
}
}

5. 复杂度分析

5.1. 时间复杂度

假设数组长度为$N$:

最好情况下,当所有元素都一样的时候,三路快排一次扫描就可以完成排序,此时时间复杂度$O(N)$;

最坏情况下,partition每一次都选中当前区域的最大值或者最小值作为枢轴,使得左组或者右组的长度为0,那么此时就需要进行$N$次partition,而每一次partition至少需要扫描数组一遍,此时时间复杂度$O(N^2)$,即使这种情况出现的概率很低,但不可否认,也是存在的;

平均情况下,假设每一次partition后,小于枢轴的部分和大于枢轴的部分是等长的,假设时间复杂度$T(N)$,那么显然:$T(N) = 2*T(\frac{N}{2})+O(N)$,根据master公式,时间复杂度为$O(N\log_2{N})$。

5.2. 空间复杂度

最好情况下,当所有元素都一样的时候,三路快排一次扫描就可以完成排序,此时空间复杂度$O(1)$;

最坏情况下,partition每一次都选中当前区域的最大值或者最小值作为枢轴,使得左组或者右组的长度为0,那么此时就至少需要进行$N$次partition,因此空间复杂度$O(N)$;

平均情况下,假设每一次partition后,小于枢轴的部分和大于枢轴的部分是等长的,那么每一次partition以后都会使得当前排序的部分长度减少一半,因此经过$\log_2{N}$次partition以后,当前排序部分的长度是1,于是空间复杂度$O(\log_2{N})$;

为了便于理解空间复杂度,这里给出快速排序的迭代写法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import java.util.Stack;

public class QuickSort3WaysIteratively {

private static int less, more;

public static void sort(int[] arr) {
int len = 0;
if(arr != null) len = arr.length;
if(len < 2) return;
sort(arr, 0, len - 1);
}

private static void sort(int[] arr, int lo, int hi) {
Stack<Task> st = new Stack<Task>(); // 模拟系统栈
st.push(new Task(lo, hi)); // 压入初始任务
while(!st.isEmpty()) {
Task curTask = st.pop();
lo = curTask.start;
hi = curTask.end;
if(lo >= hi) continue;
partition(arr, lo, hi);
st.push(new Task(lo, less));// 压入子任务
st.push(new Task(more, hi));// 压入子任务
}
}

private static void partition(int[] arr, int lo, int hi) {
int pivotIndex;
if(hi - lo + 1 == 2) {
pivotIndex = lo;
}else {// 使用三数取中法获取枢轴下标
int m = (lo + hi) >>> 1;
pivotIndex = arr[lo] >= arr[m] ?
(arr[m] >= arr[hi] ? m : (arr[lo] >= arr[hi] ? hi : lo))
:
(arr[lo] >= arr[hi] ? lo : (arr[m] >= arr[hi] ? hi : m));
}
int pivot = arr[pivotIndex];
less = lo - 1;
more = hi + 1;
int i = lo;
while(i < more) {
if(arr[i] < pivot) swap(arr, ++less, i++);
else if(arr[i] > pivot) swap(arr, --more, i);
else ++i;
}
}

private static void swap(int[] arr, int i, int j) {
if(i != j && arr[i] != arr[j]) {
int t = arr[i];
arr[i] = arr[j];
arr[j] = t;
}
}

static class Task {
int start, end;
Task(int s, int e) {
start = s;
end = e;
}
}
}

6. 测试程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import java.util.Random;
import java.util.Arrays;

public class QuickSortTest {

public static void main(String[] args) {
int time = 100_0000, maxArraySize = 10;
while(time-- > 0) {
Random r = new Random();
int size = r.nextInt(maxArraySize) + 1;
int[] array1 = new int[size], array2 = new int[size], array3 = new int[size];
for(int i = 0; i < size; ++i) array1[i] = array2[i] = array3[i] = r.nextInt(201) - 100;
QuickSort3Ways.sort(array1);
Arrays.sort(array2);
for(int i = 0; i < size; ++i) {
if(array1[i] != array2[i]) {
System.out.println("Oops! wrong answer");
System.out.println(Arrays.toString(array1));
System.out.println(Arrays.toString(array2));
System.out.println("Origin array:");
System.out.println(Arrays.toString(array3));
return;
}
}
System.out.println("Testcase:" + time + " done!");
}
System.out.println("Done successfully!!");
}
}