归并排序 Merge sort

归并排序是 分治算法 的经典实现. 它将数组分成较小的数组并排序, 然后再将它们合并在一起, 得到的数组就是有序的了.

归并排序的步骤

默认实现的递归排序是自顶向下(top-down merge sort)的, 即将整个数组递归分隔.

  1. 分隔 divide: 将数组递归分成两部分子数组, 直到每部分只剩下一个元素为止
  2. 攻克 conquer: 使用分治算法排序每个子数组
  3. 合并 merge: 将排序好的子数组有序合并在一起

第一阶段: 将数组递归分隔 (partition) 成左右两部分:

merge sort partition pass

第二阶段, 将子数组合并在一起:

merge sort merge pass

归并排序的实现

#![allow(unused)]

fn main() {
/// 对于元素个数为 `N` 的数组, 自顶向下的归并排序 (top-down merge sort)
/// 最多使用 `N log(N)` 次比较以及 `6N log(N)` 次元素访问操作.
#[inline]
pub fn topdown_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd + Clone,
{
    if arr.is_empty() {
        return;
    }
    sort(arr, 0, arr.len() - 1);
}

/// 排序 `arr[low..=high]` 部分.
fn sort<T>(arr: &mut [T], low: usize, high: usize)
where
    T: PartialOrd + Clone,
{
    if low >= high {
        return;
    }

    let middle = low + (high - low) / 2;

    // 递归排序左侧部分数组
    sort(arr, low, middle);
    // 递归排序右侧部分数组
    sort(arr, middle + 1, high);

    // 合并左右两侧部分数组
    if arr[middle] > arr[middle + 1] {
        merge(arr, low, middle, high);
    }
}

/// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组.
///
/// 它不是原地合并.
#[allow(clippy::needless_range_loop)]
fn merge<T>(arr: &mut [T], low: usize, middle: usize, high: usize)
where
    T: PartialOrd + Clone,
{
    // 辅助数组, 先将数组复制一份.
    let aux = arr[low..=high].to_vec();

    // 再合并回原数组.
    let mut i = low;
    let mut j = middle + 1;

    for k in low..=high {
        if i > middle {
            arr[k] = aux[j - low].clone();
            j += 1;
        } else if j > high {
            arr[k] = aux[i - low].clone();
            i += 1;
        } else if aux[j - low] < aux[i - low] {
            arr[k] = aux[j - low].clone();
            j += 1;
        } else {
}

归并排序的特点

  • 归并排序的时间复杂度是 O(n log(n)), 空间复杂度是 O(N)

元素较少时, 使用插入排序

在排序阶段, 如果数组元素较少时仍然使用递归的归并排序的话, 并不划算, 因为会有大量的递归分支被调用, 还可能导致栈溢出. 为此我们设置一个常量, CUTOFF=24, 当数组元素个数小于它时, 直接使用插入排序.

另外, 我们还在递归调用之前, 创建了辅助数组 aux, 这样就可以在合并时重用这个数组, 以减少内存的分配.

#![allow(unused)]
fn main() {
            i += 1;
        }
    }
}

/// 对于元素数较少的数组, 使用插入排序
pub fn insertion_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd + Clone,
{
    if arr.is_empty() {
        return;
    }
    let cutoff: usize = 24;
    let mut aux = arr.to_vec();
    sort_cutoff_with_insertion(arr, 0, arr.len() - 1, cutoff, &mut aux);
}

/// 排序 `arr[low..=high]` 部分, 如果元数较少, 就使用插入排序.
fn sort_cutoff_with_insertion<T>(
    arr: &mut [T],
    low: usize,
    high: usize,
    cutoff: usize,
    aux: &mut Vec<T>,
) where
    T: PartialOrd + Clone,
{
    if low >= high {
        return;
    }

    if high - low <= cutoff {
        insertion_sort(&mut arr[low..=high]);
        return;
    }

    let middle = low + (high - low) / 2;

    // 递归排序左侧部分数组
    sort_cutoff_with_insertion(arr, low, middle, cutoff, aux);
    // 递归排序右侧部分数组
    sort_cutoff_with_insertion(arr, middle + 1, high, cutoff, aux);

    // 合并左右两侧部分数组
    if arr[middle] > arr[middle + 1] {
        merge_with_aux(arr, low, middle, high, aux);
    }
}

/// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组.
///
/// 它不是原地合并.
#[allow(clippy::needless_range_loop)]
fn merge_with_aux<T>(arr: &mut [T], low: usize, middle: usize, high: usize, aux: &mut [T])
where
    T: PartialOrd + Clone,
{
    // 辅助数组, 先将数组复制一份.
    for index in low..=high {
        aux[index].clone_from(&arr[index]);
    }

    // 再合并回原数组.
    let mut i = low;
    let mut j = middle + 1;

    for k in low..=high {
        if i > middle {
            arr[k] = aux[j].clone();
            j += 1;

/// 其思路是, 先将前 i 个元素调整为增序的, 随着 i 从 0 增大到 n, 整个序列就变得是增序了.
pub fn insertion_sort<T>(arr: &mut [T])
where
    T: PartialOrd,
{
    let len = arr.len();
    for i in 1..len {
        for j in (1..=i).rev() {
            if arr[j - 1] > arr[j] {
                arr.swap(j - 1, j);
            } else {
                break;
            }
        }
    }
}
}

元素较少时, 使用希尔排序

这个方法是基于以上方法, 用希尔排序来代替插入排序, 可以得到更好的性能. 而且 CUTOFF 值也可以更大一些. 经过几轮测试发现, 对于希尔排序来说, CUTOFF 的取值位于 [64..92] 之间时, 性能较好.

#![allow(unused)]
fn main() {
            arr[k] = aux[i].clone();
            i += 1;
        } else if aux[j] < aux[i] {
            arr[k] = aux[j].clone();
            j += 1;
        } else {
            arr[k] = aux[i].clone();
            i += 1;
        }
    }
}

/// 对于元素数较少的数组, 使用希尔排序
pub fn shell_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd + Clone,
{
    if arr.is_empty() {
        return;
    }

    let cutoff: usize = 72;
    let mut aux = arr.to_vec();
    sort_cutoff_with_shell(arr, 0, arr.len() - 1, cutoff, &mut aux);
}

/// 排序 `arr[low..=high]` 部分, 如果元数较少, 就使用希尔排序.
fn sort_cutoff_with_shell<T>(
    arr: &mut [T],
    low: usize,
    high: usize,
    cutoff: usize,
    aux: &mut Vec<T>,
) where
    T: PartialOrd + Clone,
{

    let middle = low + (high - low) / 2;

    // 递归排序左侧部分数组
    sort_cutoff_with_insertion(arr, low, middle, cutoff, aux);
    // 递归排序右侧部分数组
    sort_cutoff_with_insertion(arr, middle + 1, high, cutoff, aux);

    // 合并左右两侧部分数组
    if arr[middle] > arr[middle + 1] {
        merge_with_aux(arr, low, middle, high, aux);
    }
}

/// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组.
///
/// 它不是原地合并.
#[allow(clippy::needless_range_loop)]
fn merge_with_aux<T>(arr: &mut [T], low: usize, middle: usize, high: usize, aux: &mut [T])
where
    T: PartialOrd + Clone,
{
    // 辅助数组, 先将数组复制一份.
    for index in low..=high {
        aux[index].clone_from(&arr[index]);
    }

    // 再合并回原数组.
    let mut i = low;
    let mut j = middle + 1;

    for k in low..=high {
        if i > middle {
            arr[k] = aux[j].clone();
            j += 1;

/// Shell sort is a simple extension to insertion sort that allows exchanging
/// elements that far apart.
///
/// It produces partially sorted array (h-sorted array).
pub fn shell_sort<T>(arr: &mut [T])
where
    T: PartialOrd,
{
    const FACTOR: usize = 3;
    let len = arr.len();

    // 计算第一个 gap 的值, 大概是 len/3
    let mut h = 1;
    while h < len / FACTOR {
        h = FACTOR * h + 1;
    }

    while h >= 1 {
        // 使用插入排序, 将 `arr[0..h]` 排序好
        for i in h..len {
            let mut j = i;
            while j >= h && arr[j - h] > arr[j] {
                arr.swap(j - h, j);
                j -= h;
            }
        }

        h /= FACTOR;
    }
}
}

迭代形式实现的归并排序

迭代形式的归并排序, 又称为自下而上的归并排序 (bottom-up merge sort). 它的步骤如下:

  • 将连续的 2 个元素比较并合并在一起
  • 将连续的 4 个元素比较并合并在一起
  • 重复以上过程, 直到所有元素合并在一起

下面的流程图展示了一个简单的操作示例:

bottom-up merge sort

对应的代码实现如下:

#![allow(unused)]
fn main() {
    // 递归排序左侧部分数组
    sort_cutoff_with_shell(arr, low, middle, cutoff, aux);
    // 递归排序右侧部分数组
    sort_cutoff_with_shell(arr, middle + 1, high, cutoff, aux);

    // 合并左右两侧部分数组
    if arr[middle] > arr[middle + 1] {
        merge_with_aux(arr, low, middle, high, aux);
    }
}

/// 迭代形式的归并排序, 自底向上 bottom-up merge sort
pub fn bottom_up_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd + Clone,
{
    let len = arr.len();
    if len < 2 {
        return;
    }

    let mut aux = arr.to_vec();

    // 开始排序的数组大小, 从 1 到 len / 2
    // current_size 的取值是 1, 2, 4, 8, ...
    let mut current_size = 1;

    while current_size < len {
        // 归并排序的数组左侧索引
        let mut left_start = 0;

        // 子数组的起始点不同, 这样就可以遍历整个数组.
        // left_start 的取值是 0, 2 * current_size, 4 * current_size, ...
        // right_end 的取值是 2 * current_size, 4 * current_size, 6 * current_size, ...
        while left_start < len - 1 {
            let middle = (left_start + current_size - 1).min(len - 1);

    sort_cutoff_with_insertion(arr, middle + 1, high, cutoff, aux);

    // 合并左右两侧部分数组
    if arr[middle] > arr[middle + 1] {
        merge_with_aux(arr, low, middle, high, aux);
    }
}

/// 合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组.
///
/// 它不是原地合并.
#[allow(clippy::needless_range_loop)]
fn merge_with_aux<T>(arr: &mut [T], low: usize, middle: usize, high: usize, aux: &mut [T])
where
    T: PartialOrd + Clone,
{
    // 辅助数组, 先将数组复制一份.
    for index in low..=high {
        aux[index].clone_from(&arr[index]);
    }

    // 再合并回原数组.
    let mut i = low;
    let mut j = middle + 1;

    for k in low..=high {
        if i > middle {
            arr[k] = aux[j].clone();
            j += 1;
        } else if j > high {
            arr[k] = aux[i].clone();
            i += 1;
        } else if aux[j] < aux[i] {
            arr[k] = aux[j].clone();
}

三路归并排序 3-way merge sort

默认实现的归并排序, 是将数组分成左右两部分分别排序. 三路归并排序, 是将数组分成左中右三部分分别排序.

#![allow(unused)]
fn main() {
/// 三路归并排序
pub fn three_way_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd + Clone,
{
    if arr.is_empty() {
        return;
    }
    let mut aux = arr.to_vec();
    three_way_sort(arr, 0, arr.len() - 1, &mut aux);
}

/// 三路排序 `arr[low..=high]`
fn three_way_sort<T>(arr: &mut [T], low: usize, high: usize, aux: &mut Vec<T>)
where
    T: PartialOrd + Clone,
{
    // 如果数组长度小于2, 就返回.
    if low + 1 > high {
        return;
    }

    // 将数组分成三部分
    let middle1 = low + (high - low) / 3;
    let middle2 = low + 2 * ((high - low) / 3);

    // 递归排序各部分数组
    three_way_sort(arr, low, middle1, aux);
    three_way_sort(arr, middle1 + 1, middle2, aux);
    three_way_sort(arr, middle2 + 1, high, aux);

    // 合并三部分数组
    three_way_merge(arr, low, middle1, middle2, high, aux);
}

/// 合并 `arr[low..=middle1]`, `arr[middle1+1..=middle2]` 以及 `arr[middle2+1..=high]` 三个子数组.
///
/// 它不是原地合并.
#[allow(clippy::needless_range_loop)]
fn three_way_merge<T>(
    arr: &mut [T],
    low: usize,
    middle1: usize,
    middle2: usize,
    high: usize,
    aux: &mut [T],
) where
    T: PartialOrd + Clone,
{
    // 辅助数组, 先将数组复制一份.
    for index in low..=high {
        aux[index].clone_from(&arr[index]);
    }

    // 再合并回原数组.
    let mut i = low;
    let mut j = middle1 + 1;
    let mut k = middle2 + 1;
    let mut l = low;

    // 首先合并较小的子数组
    while i <= middle1 && j <= middle2 && k <= high {
        let curr_index = if aux[i] < aux[j] && aux[i] < aux[k] {
            &mut i
        } else if aux[j] < aux[k] {
            &mut j
        } else {
            &mut k
        };
        arr[l].clone_from(&aux[*curr_index]);
        *curr_index += 1;
        l += 1;
    }

    // 然后合并剩余部分的子数组
    while i <= middle1 && j <= middle2 {
        let curr_index = if aux[i] < aux[j] {
            &mut i
        } else {
            &mut j
        };
        arr[l].clone_from(&aux[*curr_index]);
        *curr_index += 1;
        l += 1;
    }

    while j <= middle2 && k <= high {
        let curr_index = if aux[j] < aux[k] {
            &mut j
        } else {
            &mut k
        };
        arr[l].clone_from(&aux[*curr_index]);
        *curr_index += 1;
        l += 1;
    }

    while i <= middle1 && k <= high {
        let curr_index = if aux[i] < aux[k] {
            &mut i
        } else {
            &mut k
        };
        arr[l].clone_from(&aux[*curr_index]);
        *curr_index += 1;
        l += 1;
    }

    while i <= middle1 {
        arr[l].clone_from(&aux[i]);
        i += 1;
        l += 1;
    }
    while j <= middle2 {
        arr[l].clone_from(&aux[j]);
        j += 1;
        l += 1;
    }
    while k <= high {
        arr[l].clone_from(&aux[k]);
        k += 1;
        l += 1;
    }
}
}

三路归并排序的特点:

  • 时间复杂度是 O(n log_3(n)), 空间复杂度是 O(n)
  • 但因为在 merge_xx() 函数中引入了更多的比较操作, 其性能可能更差

原地归并排序

原地归并排序, 是替代了辅助数组, 它使用类似插入排序的方式, 将后面较大的元素交换到前面合适的位置. 尽管省去了辅助数组, 但是因为移动元素的次数显著境多了, 其性能表现并不好.

下面的流程图展示了一个原地归并排序的示例:

in place merge sort merge pass

#![allow(unused)]
fn main() {
/// 原地归并排序
///
/// 尽管它不需要辅助数组, 但它的性能差得多, 时间复杂度是 `O(N^2 Log(N))`, 而默认实现的归并排序的
/// 时间复杂度是 `O(N Log(N))`.
pub fn in_place_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd,
{
    if arr.is_empty() {
        return;
    }
    sort_in_place(arr, 0, arr.len() - 1);
}

/// 原地排序 `arr[low..=high]`
fn sort_in_place<T>(arr: &mut [T], low: usize, high: usize)
where
    T: PartialOrd,
{
    if low >= high {
        return;
    }

    let middle = low + (high - low) / 2;
    sort_in_place(arr, low, middle);
    sort_in_place(arr, middle + 1, high);

    if arr[middle] > arr[middle + 1] {
        merge_in_place(arr, low, middle, high);
    }
}

/// 原地合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组.
fn merge_in_place<T>(arr: &mut [T], mut low: usize, mut middle: usize, high: usize)
where
    T: PartialOrd,
{
    let mut low2 = middle + 1;
    debug_assert!(arr[middle] > arr[low2]);

    while low <= middle && low2 <= high {
        if arr[low] <= arr[low2] {
            low += 1;
        } else {
            // 将所有元素右移, 并将 arr[low2] 插入到 arr[low] 所在位置. 这一步很慢.
            for index in (low..low2).rev() {
                arr.swap(index, index + 1);
            }

            // 更新所有的索引
            low += 1;
            middle += 1;
            low2 += 1;
        }
    }
}

}

原地归并排序的特点:

  • 时间复杂度度是 O(N^2 Log(N)), 空间复杂度是 O(1)
  • c++ 的标准库里有实现类似的算法, 参考 inplace_merge

优化原地归并排序

上面的原地归并排序, 每次只移动一个元素间隔. 类似于希尔排序, 我们可以增大移动元素的间隔 (gap), 来减少 移动元素的次数.

#![allow(unused)]
fn main() {
/// 对原地归并排序的优化
///
/// 它不需要辅助数组, 它参考了希尔排序, 通过调整元素间隔 gap 减少元素移动次数.
pub fn in_place_shell_merge_sort<T>(arr: &mut [T])
where
    T: PartialOrd,
{
    if arr.is_empty() {
        return;
    }
    sort_in_place_with_shell(arr, 0, arr.len() - 1);
}

/// 原地排序 `arr[low..=high]`
fn sort_in_place_with_shell<T>(arr: &mut [T], low: usize, high: usize)
where
    T: PartialOrd,
{
    if low >= high {
        return;
    }

    let middle = low + (high - low) / 2;
    sort_in_place_with_shell(arr, low, middle);
    sort_in_place_with_shell(arr, middle + 1, high);

    merge_in_place_with_shell(arr, low, high);
}

/// 使用希尔排序的方式原地合并 `arr[low..=middle]` 以及 `arr[middle+1..=high]` 两个子数组.
///
/// 时间复杂度 `O(N Log(N))`, 空间复杂度 `O(1)`
fn merge_in_place_with_shell<T>(arr: &mut [T], low: usize, high: usize)
where
    T: PartialOrd,
{
    #[must_use]
    #[inline]
    const fn next_gap(gap: usize) -> usize {
        const FACTOR: usize = 2;
        if gap == 1 {
            0
        } else {
            gap.div_ceil(FACTOR)
        }
    }
    let len = high - low + 1;
    let mut gap = next_gap(len);

    while gap > 0 {
        for i in low..=(high - gap) {
            let j = i + gap;
            // 每次间隔多个元素进行比较和交换.
            if arr[i] > arr[j] {
                arr.swap(i, j);
            }
        }
        gap = next_gap(gap);
    }
}
}
  • 时间复杂度度是 O(n log(n) log(n)), 空间复杂度是 O(1)