计数排序 Counting Sort

计数排序不是基于比较值的排序算法.

计数排序的步骤

计数排序的实现分几个阶段:

  • 首先遍历输入数组, 计算元素的取值范围
  • 生成计数数组, 其元素个数基于元素的取值范围确定
  • 遍历输入数组, 根据每个元素与最小元素的差值作为索引, 更新计数数组
  • 更新计数数组, 使之成为前缀和数组
  • 初始化输出数组
  • 从最后一个元素开始遍历输入数组, 每个元素都存放 num
    • 计算当前元素与最小元素的差值 delta_index
    • 从计数数组中取得该元素的索引值 let num_index = count_arr[delta_index]
    • 更新输出数组, 将 num 放到相应的位置, arr[num_index - 1] = num
    • 并更新计数数组, 将里面的索引值减1, count_arr[delta_index] -= 1

计数排序的实现

下面的算法限制了输入元素是 i32:

#![allow(unused)]
fn main() {
}

#[allow(clippy::cast_sign_loss)]
pub fn counting_sort(arr: &mut [i32]) {
    if arr.is_empty() {
        return;
    }
    let min_num: i32 = arr.iter().min().copied().unwrap_or_default();
    let max_num: i32 = arr.iter().max().copied().unwrap_or_default();
    // 计算数值范围
    let range: i32 = max_num - min_num;
    let size: usize = range as usize + 1;

    // 构造计数数组
    let mut count_arr = vec![0_usize; size];

    // 遍历输入数组, 更新计数数组
    for &num in arr.iter() {
        let delta: i32 = num - min_num;
        let index: usize = delta as usize;
        count_arr[index] += 1;
    }

    // 生成累积数组, prefix sum array
    for i in 1..size {
        count_arr[i] += count_arr[i - 1];
    }

    // 构造输入数组, 只读的
    let input_arr: Vec<i32> = arr.to_vec();

    // 从输入数组的右侧向左侧遍历, 这样实现的是稳定排序.
    for &num in input_arr.iter().rev() {
        // 计算当前值与最小值的差.
        let delta: i32 = num - min_num;
        let delta_index = delta as usize;
        // 从 count_arr 里取出该数值的相对位置
        let num_index: usize = count_arr[delta_index];
        // 把 num 放在对应的位置
        arr[num_index - 1] = num;

        // 同时更新 count_arr, 使之计数减1, 这样的话下一个相同数值的元素的索引值就被左移了一位.
        count_arr[delta_index] -= 1;
}

下面的代码对计数排序加入了泛型的支持, 注意它的类型 T 有很多限制:

#![allow(unused)]
fn main() {
use std::collections::BTreeMap;
use std::ops::Sub;

pub fn counting_sort_generic<T>(arr: &mut [T])
where
    T: Copy + Default + Ord + Sub<Output=T> + TryInto<usize>,
{
    if arr.is_empty() {
        return;
    }
    let min_num: T = arr.iter().min().copied().unwrap_or_default();
    let max_num: T = arr.iter().max().copied().unwrap_or_default();
    // 计算数值范围
    let range: T = max_num - min_num;
    let size: usize = range.try_into().unwrap_or_default() + 1;

    // 构造计数数组
    let mut count_arr = vec![0_usize; size];

    // 遍历数组, 更新计数数组
    for num in arr.iter() {
        let delta: T = *num - min_num;
        let index: usize = delta.try_into().unwrap_or_default();
        count_arr[index] += 1;
    }

    // 生成累积数组, prefix sum array
    for i in 1..size {
        count_arr[i] += count_arr[i - 1];
    }

    // 构造输入数组, 只读的
    let input_arr = arr.to_vec();

    for &num in input_arr.iter().rev() {
        let delta: T = num - min_num;
        let delta_index: usize = delta.try_into().unwrap_or_default();
        // 从 count_arr 里取出该数值的相对位置
        let num_index: usize = count_arr[delta_index];
        // 把 num 放在对应的位置
        arr[num_index - 1] = num;

        // 同时更新 count_arr, 使之计数减1, 这样的话下一个相同数值的元素的索引值就被左移了一位.
        count_arr[delta_index] -= 1;
    }
}

计数排序的特点

  • 空间复杂度是 O(n + m), n 是输入数组的大小, m 是计数数组的大小, 也就是元素的数值范围
  • 时间复杂度是 O(n + m)
  • 计数排序是稳定排序, 但不是原地排序 (in-place sorting)
  • 如果数组中的元素值所处的范围比较大的话, 计数排序的效率就比较低
  • 它需要较多的额外空间来存储中间值
  • 计数排序要比归并排序和快速排序等基于比较元素值的排序算法都要快
  • 计数排序不惧怕有重复的元素, 但是如果元素的取值范围比较大的话, 其效率就很低

使用 map 作为计数数组的容器

上面实现的计数排序, 其计数数组对于元素的取值范围很敏感, 甚至计数数组中可能有很多的值都是0, 它们 都被浪费掉了.

对此, 我们可以做一些优化, 使用 map 来存储计数数组中的值.

#![allow(unused)]
fn main() {
use std::collections::BTreeMap;
use std::ops::Sub;

pub fn counting_sort_generic<T>(arr: &mut [T])
where
    T: Copy + Default + Ord + Sub<Output=T> + TryInto<usize>,
{
    if arr.is_empty() {
        return;
    }
    let min_num: T = arr.iter().min().copied().unwrap_or_default();
    let max_num: T = arr.iter().max().copied().unwrap_or_default();
    // 计算数值范围
    let range: T = max_num - min_num;
    let size: usize = range.try_into().unwrap_or_default() + 1;

    // 构造计数数组
    let mut count_arr = vec![0_usize; size];

    // 遍历数组, 更新计数数组
    for num in arr.iter() {
        let delta: T = *num - min_num;
        let index: usize = delta.try_into().unwrap_or_default();
        count_arr[index] += 1;
    }

    // 生成累积数组, prefix sum array
    for i in 1..size {
        count_arr[i] += count_arr[i - 1];
    }

    // 构造输入数组, 只读的
    let input_arr = arr.to_vec();

    for &num in input_arr.iter().rev() {
        let delta: T = num - min_num;
        let delta_index: usize = delta.try_into().unwrap_or_default();
        // 从 count_arr 里取出该数值的相对位置
        let num_index: usize = count_arr[delta_index];
        // 把 num 放在对应的位置
        arr[num_index - 1] = num;

        // 同时更新 count_arr, 使之计数减1, 这样的话下一个相同数值的元素的索引值就被左移了一位.
        count_arr[delta_index] -= 1;
    }
}

#[allow(clippy::cast_sign_loss)]
pub fn counting_sort(arr: &mut [i32]) {
    if arr.is_empty() {
        return;
    }
    let min_num: i32 = arr.iter().min().copied().unwrap_or_default();
    let max_num: i32 = arr.iter().max().copied().unwrap_or_default();
    // 计算数值范围
    let range: i32 = max_num - min_num;
    let size: usize = range as usize + 1;

    // 构造计数数组
    let mut count_arr = vec![0_usize; size];

    // 遍历输入数组, 更新计数数组
    for &num in arr.iter() {
        let delta: i32 = num - min_num;
        let index: usize = delta as usize;
        count_arr[index] += 1;
    }

    // 生成累积数组, prefix sum array
    for i in 1..size {
        count_arr[i] += count_arr[i - 1];
    }

    // 构造输入数组, 只读的
    let input_arr: Vec<i32> = arr.to_vec();

    // 从输入数组的右侧向左侧遍历, 这样实现的是稳定排序.
    for &num in input_arr.iter().rev() {
        // 计算当前值与最小值的差.
        let delta: i32 = num - min_num;
        let delta_index = delta as usize;
        // 从 count_arr 里取出该数值的相对位置
        let num_index: usize = count_arr[delta_index];
        // 把 num 放在对应的位置
        arr[num_index - 1] = num;

        // 同时更新 count_arr, 使之计数减1, 这样的话下一个相同数值的元素的索引值就被左移了一位.
        count_arr[delta_index] -= 1;
    }
}

#[allow(clippy::cast_sign_loss)]
pub fn counting_sort_with_map(arr: &mut [i32]) {
    if arr.is_empty() {
        return;
    }

    // 构造字典, 存储元素的频率
    let mut freq_map: BTreeMap<i32, usize> = BTreeMap::new();
    // 遍历输入数组, 更新计数数组
    for &num in arr.iter() {
        *freq_map.entry(num).or_default() += 1;
    }

    // 遍历字典
    let mut i = 0;
    for (num, freq) in freq_map {
        for _j in 0..freq {
            arr[i] = num;
            i += 1;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{counting_sort, counting_sort_generic, counting_sort_with_map};

    #[test]
    fn test_counting_sort() {
        let mut list = [0, 5, 3, 2, 2];
        counting_sort(&mut list);
        assert_eq!(list, [0, 2, 2, 3, 5]);

        let mut list = [-2, -5, -45];
        counting_sort(&mut list);
        assert_eq!(list, [-45, -5, -2]);

        let mut list = [
            -998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -983_833,
            -987_905, -980_069, -977_640,
        ];
        counting_sort(&mut list);
        assert_eq!(
            list,
            [
                -998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -987_905,
                -983_833, -980_069, -977_640,
            ]
        );
    }

    #[test]
    fn test_counting_sort_generic() {
        let mut list = [0, 5, 3, 2, 2];
        counting_sort_generic(&mut list);
        assert_eq!(list, [0, 2, 2, 3, 5]);

        let mut list = [-2, -5, -45];
        counting_sort_generic(&mut list);
        assert_eq!(list, [-45, -5, -2]);

        let mut list = [
            -998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -983_833,
            -987_905, -980_069, -977_640,
        ];
        counting_sort_generic(&mut list);
        assert_eq!(
            list,
            [
                -998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -987_905,
                -983_833, -980_069, -977_640,
            ]
        );
    }

    #[test]
    fn test_counting_sort_with_map() {
        let mut list = [0, 5, 3, 2, 2];
        counting_sort_with_map(&mut list);
        assert_eq!(list, [0, 2, 2, 3, 5]);

        let mut list = [-2, -5, -45];
        counting_sort_with_map(&mut list);
        assert_eq!(list, [-45, -5, -2]);

        let mut list = [
            -998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -983_833,
            -987_905, -980_069, -977_640,
        ];
        counting_sort_with_map(&mut list);
        assert_eq!(
            list,
            [
                -998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -987_905,
                -983_833, -980_069, -977_640,
            ]
        );
    }
}
}

该算法的特点是

  • 时间复杂度是 O(n log(n)), 空间复杂度是 O(n)
  • 即使输入数组的取值范围较大, 也不成问题