归并排序 Merge sort
归并排序是 分治算法 的经典实现. 它将数组分成较小的数组并排序, 然后再将它们合并在一起, 得到的数组就是有序的了.
归并排序的步骤
默认实现的递归排序是自顶向下(top-down merge sort)的, 即将整个数组递归分隔.
- 分隔 divide: 将数组递归分成两部分子数组, 直到每部分只剩下一个元素为止
- 攻克 conquer: 使用分治算法排序每个子数组
- 合并 merge: 将排序好的子数组有序合并在一起
第一阶段: 将数组递归分隔 (partition) 成左右两部分:
第二阶段, 将子数组合并在一起:
归并排序的实现
#![allow(unused)] fn main() { /// 对于元素个数为 `N` 的数组, 自顶向下的归并排序 (top-down merge sort) /// 最多使用 `N log(N)` 次比较以及 `6N log(N)` 次元素访问操作. #[inline] pub fn topdown_merge_sort<T>(arr: &mut [T]) where T: PartialOrd + Clone, { if arr.is_empty() { return; } sort(arr, 0, arr.len() - 1); } /// 排序 `arr[low..=high]` 部分. fn sort<T>(arr: &mut [T], low: usize, high: usize) where T: PartialOrd + Clone, { if low >= high { return; } let middle = low + (high - low) / 2; // 递归排序左侧部分数组 sort(arr, low, middle); // 递归排序右侧部分数组 sort(arr, middle + 1, high); // 合并左右两侧部分数组 if arr[middle] > arr[middle + 1] { merge(arr, low, middle, high); } } /// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组. /// /// 它不是原地合并. #[allow(clippy::needless_range_loop)] fn merge<T>(arr: &mut [T], low: usize, middle: usize, high: usize) where T: PartialOrd + Clone, { // 辅助数组, 先将数组复制一份. let aux = arr[low..=high].to_vec(); // 再合并回原数组. let mut i = low; let mut j = middle + 1; for k in low..=high { if i > middle { arr[k] = aux[j - low].clone(); j += 1; } else if j > high { arr[k] = aux[i - low].clone(); i += 1; } else if aux[j - low] < aux[i - low] { arr[k] = aux[j - low].clone(); j += 1; } else { }
归并排序的特点
- 归并排序的时间复杂度是
O(n log(n))
, 空间复杂度是O(N)
元素较少时, 使用插入排序
在排序阶段, 如果数组元素较少时仍然使用递归的归并排序的话, 并不划算, 因为会有大量的递归分支被调用,
还可能导致栈溢出. 为此我们设置一个常量, CUTOFF=24
, 当数组元素个数小于它时, 直接使用插入排序.
另外, 我们还在递归调用之前, 创建了辅助数组 aux
, 这样就可以在合并时重用这个数组, 以减少内存的分配.
#![allow(unused)] fn main() { i += 1; } } } /// 对于元素数较少的数组, 使用插入排序 pub fn insertion_merge_sort<T>(arr: &mut [T]) where T: PartialOrd + Clone, { if arr.is_empty() { return; } let cutoff: usize = 24; let mut aux = arr.to_vec(); sort_cutoff_with_insertion(arr, 0, arr.len() - 1, cutoff, &mut aux); } /// 排序 `arr[low..=high]` 部分, 如果元数较少, 就使用插入排序. fn sort_cutoff_with_insertion<T>( arr: &mut [T], low: usize, high: usize, cutoff: usize, aux: &mut Vec<T>, ) where T: PartialOrd + Clone, { if low >= high { return; } if high - low <= cutoff { insertion_sort(&mut arr[low..=high]); return; } let middle = low + (high - low) / 2; // 递归排序左侧部分数组 sort_cutoff_with_insertion(arr, low, middle, cutoff, aux); // 递归排序右侧部分数组 sort_cutoff_with_insertion(arr, middle + 1, high, cutoff, aux); // 合并左右两侧部分数组 if arr[middle] > arr[middle + 1] { merge_with_aux(arr, low, middle, high, aux); } } /// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组. /// /// 它不是原地合并. #[allow(clippy::needless_range_loop)] fn merge_with_aux<T>(arr: &mut [T], low: usize, middle: usize, high: usize, aux: &mut [T]) where T: PartialOrd + Clone, { // 辅助数组, 先将数组复制一份. for index in low..=high { aux[index].clone_from(&arr[index]); } // 再合并回原数组. let mut i = low; let mut j = middle + 1; for k in low..=high { if i > middle { arr[k] = aux[j].clone(); j += 1; /// 其思路是, 先将前 i 个元素调整为增序的, 随着 i 从 0 增大到 n, 整个序列就变得是增序了. pub fn insertion_sort<T>(arr: &mut [T]) where T: PartialOrd, { let len = arr.len(); for i in 1..len { for j in (1..=i).rev() { if arr[j - 1] > arr[j] { arr.swap(j - 1, j); } else { break; } } } } }
元素较少时, 使用希尔排序
这个方法是基于以上方法, 用希尔排序来代替插入排序, 可以得到更好的性能. 而且 CUTOFF
值也可以更大一些.
经过几轮测试发现, 对于希尔排序来说, CUTOFF
的取值位于 [64..92]
之间时, 性能较好.
#![allow(unused)] fn main() { arr[k] = aux[i].clone(); i += 1; } else if aux[j] < aux[i] { arr[k] = aux[j].clone(); j += 1; } else { arr[k] = aux[i].clone(); i += 1; } } } /// 对于元素数较少的数组, 使用希尔排序 pub fn shell_merge_sort<T>(arr: &mut [T]) where T: PartialOrd + Clone, { if arr.is_empty() { return; } let cutoff: usize = 72; let mut aux = arr.to_vec(); sort_cutoff_with_shell(arr, 0, arr.len() - 1, cutoff, &mut aux); } /// 排序 `arr[low..=high]` 部分, 如果元数较少, 就使用希尔排序. fn sort_cutoff_with_shell<T>( arr: &mut [T], low: usize, high: usize, cutoff: usize, aux: &mut Vec<T>, ) where T: PartialOrd + Clone, { let middle = low + (high - low) / 2; // 递归排序左侧部分数组 sort_cutoff_with_insertion(arr, low, middle, cutoff, aux); // 递归排序右侧部分数组 sort_cutoff_with_insertion(arr, middle + 1, high, cutoff, aux); // 合并左右两侧部分数组 if arr[middle] > arr[middle + 1] { merge_with_aux(arr, low, middle, high, aux); } } /// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组. /// /// 它不是原地合并. #[allow(clippy::needless_range_loop)] fn merge_with_aux<T>(arr: &mut [T], low: usize, middle: usize, high: usize, aux: &mut [T]) where T: PartialOrd + Clone, { // 辅助数组, 先将数组复制一份. for index in low..=high { aux[index].clone_from(&arr[index]); } // 再合并回原数组. let mut i = low; let mut j = middle + 1; for k in low..=high { if i > middle { arr[k] = aux[j].clone(); j += 1; /// Shell sort is a simple extension to insertion sort that allows exchanging /// elements that far apart. /// /// It produces partially sorted array (h-sorted array). pub fn shell_sort<T>(arr: &mut [T]) where T: PartialOrd, { const FACTOR: usize = 3; let len = arr.len(); // 计算第一个 gap 的值, 大概是 len/3 let mut h = 1; while h < len / FACTOR { h = FACTOR * h + 1; } while h >= 1 { // 使用插入排序, 将 `arr[0..h]` 排序好 for i in h..len { let mut j = i; while j >= h && arr[j - h] > arr[j] { arr.swap(j - h, j); j -= h; } } h /= FACTOR; } } }
迭代形式实现的归并排序
迭代形式的归并排序, 又称为自下而上的归并排序 (bottom-up merge sort). 它的步骤如下:
- 将连续的 2 个元素比较并合并在一起
- 将连续的 4 个元素比较并合并在一起
- 重复以上过程, 直到所有元素合并在一起
下面的流程图展示了一个简单的操作示例:
对应的代码实现如下:
#![allow(unused)] fn main() { // 递归排序左侧部分数组 sort_cutoff_with_shell(arr, low, middle, cutoff, aux); // 递归排序右侧部分数组 sort_cutoff_with_shell(arr, middle + 1, high, cutoff, aux); // 合并左右两侧部分数组 if arr[middle] > arr[middle + 1] { merge_with_aux(arr, low, middle, high, aux); } } /// 迭代形式的归并排序, 自底向上 bottom-up merge sort pub fn bottom_up_merge_sort<T>(arr: &mut [T]) where T: PartialOrd + Clone, { let len = arr.len(); if len < 2 { return; } let mut aux = arr.to_vec(); // 开始排序的数组大小, 从 1 到 len / 2 // current_size 的取值是 1, 2, 4, 8, ... let mut current_size = 1; while current_size < len { // 归并排序的数组左侧索引 let mut left_start = 0; // 子数组的起始点不同, 这样就可以遍历整个数组. // left_start 的取值是 0, 2 * current_size, 4 * current_size, ... // right_end 的取值是 2 * current_size, 4 * current_size, 6 * current_size, ... while left_start < len - 1 { let middle = (left_start + current_size - 1).min(len - 1); sort_cutoff_with_insertion(arr, middle + 1, high, cutoff, aux); // 合并左右两侧部分数组 if arr[middle] > arr[middle + 1] { merge_with_aux(arr, low, middle, high, aux); } } /// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组. /// /// 它不是原地合并. #[allow(clippy::needless_range_loop)] fn merge_with_aux<T>(arr: &mut [T], low: usize, middle: usize, high: usize, aux: &mut [T]) where T: PartialOrd + Clone, { // 辅助数组, 先将数组复制一份. for index in low..=high { aux[index].clone_from(&arr[index]); } // 再合并回原数组. let mut i = low; let mut j = middle + 1; for k in low..=high { if i > middle { arr[k] = aux[j].clone(); j += 1; } else if j > high { arr[k] = aux[i].clone(); i += 1; } else if aux[j] < aux[i] { arr[k] = aux[j].clone(); }
三路归并排序 3-way merge sort
默认实现的归并排序, 是将数组分成左右两部分分别排序. 三路归并排序, 是将数组分成左中右三部分分别排序.
#![allow(unused)] fn main() { /// 三路归并排序 pub fn three_way_merge_sort<T>(arr: &mut [T]) where T: PartialOrd + Clone, { if arr.is_empty() { return; } let mut aux = arr.to_vec(); three_way_sort(arr, 0, arr.len() - 1, &mut aux); } /// 三路排序 `arr[low..=high]` fn three_way_sort<T>(arr: &mut [T], low: usize, high: usize, aux: &mut Vec<T>) where T: PartialOrd + Clone, { // 如果数组长度小于2, 就返回. if low + 1 > high { return; } // 将数组分成三部分 let middle1 = low + (high - low) / 3; let middle2 = low + 2 * ((high - low) / 3); // 递归排序各部分数组 three_way_sort(arr, low, middle1, aux); three_way_sort(arr, middle1 + 1, middle2, aux); three_way_sort(arr, middle2 + 1, high, aux); // 合并三部分数组 three_way_merge(arr, low, middle1, middle2, high, aux); } /// 合并 `arr[low..=middle1]`, `arr[middle1+1..=middle2]` 以及 `arr[middle2+1..=high]` 三个子数组. /// /// 它不是原地合并. #[allow(clippy::needless_range_loop)] fn three_way_merge<T>( arr: &mut [T], low: usize, middle1: usize, middle2: usize, high: usize, aux: &mut [T], ) where T: PartialOrd + Clone, { // 辅助数组, 先将数组复制一份. for index in low..=high { aux[index].clone_from(&arr[index]); } // 再合并回原数组. let mut i = low; let mut j = middle1 + 1; let mut k = middle2 + 1; let mut l = low; // 首先合并较小的子数组 while i <= middle1 && j <= middle2 && k <= high { let curr_index = if aux[i] < aux[j] && aux[i] < aux[k] { &mut i } else if aux[j] < aux[k] { &mut j } else { &mut k }; arr[l].clone_from(&aux[*curr_index]); *curr_index += 1; l += 1; } // 然后合并剩余部分的子数组 while i <= middle1 && j <= middle2 { let curr_index = if aux[i] < aux[j] { &mut i } else { &mut j }; arr[l].clone_from(&aux[*curr_index]); *curr_index += 1; l += 1; } while j <= middle2 && k <= high { let curr_index = if aux[j] < aux[k] { &mut j } else { &mut k }; arr[l].clone_from(&aux[*curr_index]); *curr_index += 1; l += 1; } while i <= middle1 && k <= high { let curr_index = if aux[i] < aux[k] { &mut i } else { &mut k }; arr[l].clone_from(&aux[*curr_index]); *curr_index += 1; l += 1; } while i <= middle1 { arr[l].clone_from(&aux[i]); i += 1; l += 1; } while j <= middle2 { arr[l].clone_from(&aux[j]); j += 1; l += 1; } while k <= high { arr[l].clone_from(&aux[k]); k += 1; l += 1; } } }
三路归并排序的特点:
- 时间复杂度是
O(n log_3(n))
, 空间复杂度是O(n)
- 但因为在
merge_xx()
函数中引入了更多的比较操作, 其性能可能更差
原地归并排序
原地归并排序, 是替代了辅助数组, 它使用类似插入排序的方式, 将后面较大的元素交换到前面合适的位置. 尽管省去了辅助数组, 但是因为移动元素的次数显著境多了, 其性能表现并不好.
下面的流程图展示了一个原地归并排序的示例:
#![allow(unused)] fn main() { /// 原地归并排序 /// /// 尽管它不需要辅助数组, 但它的性能差得多, 时间复杂度是 `O(N^2 Log(N))`, 而默认实现的归并排序的 /// 时间复杂度是 `O(N Log(N))`. pub fn in_place_merge_sort<T>(arr: &mut [T]) where T: PartialOrd, { if arr.is_empty() { return; } sort_in_place(arr, 0, arr.len() - 1); } /// 原地排序 `arr[low..=high]` fn sort_in_place<T>(arr: &mut [T], low: usize, high: usize) where T: PartialOrd, { if low >= high { return; } let middle = low + (high - low) / 2; sort_in_place(arr, low, middle); sort_in_place(arr, middle + 1, high); if arr[middle] > arr[middle + 1] { merge_in_place(arr, low, middle, high); } } /// 原地合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组. fn merge_in_place<T>(arr: &mut [T], mut low: usize, mut middle: usize, high: usize) where T: PartialOrd, { let mut low2 = middle + 1; debug_assert!(arr[middle] > arr[low2]); while low <= middle && low2 <= high { if arr[low] <= arr[low2] { low += 1; } else { // 将所有元素右移, 并将 arr[low2] 插入到 arr[low] 所在位置. 这一步很慢. for index in (low..low2).rev() { arr.swap(index, index + 1); } // 更新所有的索引 low += 1; middle += 1; low2 += 1; } } } }
原地归并排序的特点:
- 时间复杂度度是
O(N^2 Log(N))
, 空间复杂度是O(1)
- c++ 的标准库里有实现类似的算法, 参考 inplace_merge
优化原地归并排序
上面的原地归并排序, 每次只移动一个元素间隔. 类似于希尔排序, 我们可以增大移动元素的间隔 (gap), 来减少 移动元素的次数.
#![allow(unused)] fn main() { /// 对原地归并排序的优化 /// /// 它不需要辅助数组, 它参考了希尔排序, 通过调整元素间隔 gap 减少元素移动次数. pub fn in_place_shell_merge_sort<T>(arr: &mut [T]) where T: PartialOrd, { if arr.is_empty() { return; } sort_in_place_with_shell(arr, 0, arr.len() - 1); } /// 原地排序 `arr[low..=high]` fn sort_in_place_with_shell<T>(arr: &mut [T], low: usize, high: usize) where T: PartialOrd, { if low >= high { return; } let middle = low + (high - low) / 2; sort_in_place_with_shell(arr, low, middle); sort_in_place_with_shell(arr, middle + 1, high); merge_in_place_with_shell(arr, low, high); } /// 使用希尔排序的方式原地合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组. /// /// 时间复杂度 `O(N Log(N))`, 空间复杂度 `O(1)` fn merge_in_place_with_shell<T>(arr: &mut [T], low: usize, high: usize) where T: PartialOrd, { #[must_use] #[inline] const fn next_gap(gap: usize) -> usize { const FACTOR: usize = 2; if gap == 1 { 0 } else { gap.div_ceil(FACTOR) } } let len = high - low + 1; let mut gap = next_gap(len); while gap > 0 { for i in low..=(high - gap) { let j = i + gap; // 每次间隔多个元素进行比较和交换. if arr[i] > arr[j] { arr.swap(i, j); } } gap = next_gap(gap); } } }
- 时间复杂度度是
O(n log(n) log(n))
, 空间复杂度是O(1)