手撸golang 基本数据结构与算法 归并排序

缘起

最近阅读<<我的第一本算法书>>(【日】石田保辉;宫崎修一)
本系列笔记拟采用golang练习之

归并排序

归并排序算法会把序列分成长度相同的两个子序列,
当无法继续往下分时(也就是每个子序列中只有一个数据时),
就对子序列进行归并。

归并指的是把两个排好序的子序列合并成一个有序序列。
该操作会一直重复执行,直到所有子序列都归并为一个整体为止。

总的运行时间为O(nlogn),这与前面讲到的堆排序相同。

摘自 <<我的第一本算法书>> 【日】石田保辉;宫崎修一

流程

  1. 给定待排序数组data[N]
  2. 创建缓冲区buffer[N]
  3. 设定初始的归并步长(已排序的子序列长度), span=1, 也就是将每个元素视作已排序的一个子序列
  4. 根据当前步长, 将所有已排序的子序列两两合并到缓冲区

    1. 设定q1为子序列1的头部下标, q2为子序列2的头部下标, p指向缓冲区的对应区域
    2. 比较q1与q2位置的值大小
    3. 如果q1的值小, 则取q1的值, 放入p, 然后q1和p都加1
    4. 如果q2的值小, 则取q2的值, 放入p, 然后q2和p都加1
    5. 如果子序列已经取完, 则复制尚未取完的子序列到p
  5. 交换data与buffer的指针, 将data视为缓冲区, buffer视为待归并数据
  6. span = span*2
  7. 重复步骤3-6, 直到span>N

设计

  • ISorter: 定义排序器接口. 定义值比较函数以兼容任意数值类型, 通过调整比较函数实现倒序排序
  • tMergeSort: 归并排序器, 实现ISorter接口.

单元测试

merge_sort_test.go. 归并排序是比较快的, 因此设定测试规模为10万元素.

package sorting

import (
    "fmt"
    "learning/gooop/sorting"
    "learning/gooop/sorting/merge_sort"
    "math/rand"
    "testing"
    "time"
)

func Test_MergeSort(t *testing.T) {
    fnAssertTrue := func(b bool, msg string) {
        if !b {
            t.Fatal(msg)
        }
    }

    reversed := false
    fnCompare := func(a interface{}, b interface{}) sorting.CompareResult {
        i1 := a.(int)
        i2 := b.(int)

        if i1 < i2 {
            if reversed {
                return sorting.GREATER
            } else {
                return sorting.LESS
            }
        } else if i1 == i2 {
            return sorting.EQUAL
        } else {
            if reversed {
                return sorting.LESS
            } else {
                return sorting.GREATER
            }
        }
    }

    fnTestSorter := func(sorter sorting.ISorter) {
        reversed = false

        // test simple array
        samples := []interface{} { 2,3,1,5,4,7,6 }
        samples = sorter.Sort(samples, fnCompare)
        fnAssertTrue(fmt.Sprintf("%v", samples) == "[1 2 3 4 5 6 7]",  "expecting 1,2,3,4,5,6,7")
        t.Log("pass sorting [2 3 1 5 4 7 6] >> [1 2 3 4 5 6 7]")

        // test 10000 items sorting
        rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
        for plus := 0;plus < 3;plus++ {
            sampleCount := 100 * 1000 + plus
            t.Logf("prepare large array with %v items", sampleCount)
            samples = make([]interface{}, sampleCount)
            for i := 0; i < sampleCount; i++ {
                samples[i] = rnd.Intn(sampleCount * 10)
            }

            t.Logf("sorting large array with %v items", sampleCount)
            t0 := time.Now().UnixNano()
            samples = sorter.Sort(samples, fnCompare)
            cost := time.Now().UnixNano() - t0
            for i := 1; i < sampleCount; i++ {
                fnAssertTrue(fnCompare(samples[i-1], samples[i]) != sorting.GREATER, "expecting <=")
            }
            t.Logf("end sorting large array, cost = %v ms", cost/1000000)
        }

        // test 0-20
        sampleCount := 20
        t.Log("sorting 0-20")
        samples = make([]interface{}, sampleCount)
        for i := 0;i < sampleCount;i++ {
            for {
                p := rnd.Intn(sampleCount)
                if samples[p] == nil {
                    samples[p] = i
                    break
                }
            }
        }
        t.Logf("unsort = %v", samples)

        samples = sorter.Sort(samples, fnCompare)
        t.Logf("sorted = %v", samples)
        fnAssertTrue(fmt.Sprintf("%v", samples) == "[0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]", "expecting 0-20")
        t.Log("pass sorting 0-20")

        // test special
        samples = []interface{} {}
        samples = sorter.Sort(samples, fnCompare)
        t.Log("pass sorting []")

        samples = []interface{} { 1 }
        samples = sorter.Sort(samples, fnCompare)
        t.Log("pass sorting [1]")

        samples = []interface{} { 3,1 }
        samples = sorter.Sort(samples, fnCompare)
        fnAssertTrue(fmt.Sprintf("%v", samples) == "[1 3]",  "expecting 1,3")
        t.Log("pass sorting [1 3]")

        reversed = true
        samples = []interface{} { 2, 3,1 }
        samples = sorter.Sort(samples, fnCompare)
        fnAssertTrue(fmt.Sprintf("%v", samples) == "[3 2 1]",  "expecting 3,2,1")
        t.Log("pass sorting [3 2 1]")
    }

    t.Log("\ntesting MergeSort")
    fnTestSorter(merge_sort.MergeSort)
}

测试输出

  • 归并排序相当的快, 比堆排序还快1倍左右
  • 比冒泡,选择,插入等有指数级的提升, 符合理论分析
  • 代价与堆排序一致, 需要额外分配大小为N的空间, 用做归并缓冲区
  • 比堆排序快的主要原因, 推测是因为堆排序初始化时, 批量push操作, 可能引发多次数组扩容.
$ go test -v merge_sort_test.go 
=== RUN   Test_MergeSort
    merge_sort_test.go:111: 
        testing MergeSort
    merge_sort_test.go:48: pass sorting [2 3 1 5 4 7 6] >> [1 2 3 4 5 6 7]
    merge_sort_test.go:54: prepare large array with 100000 items
    merge_sort_test.go:60: sorting large array with 100000 items
    merge_sort_test.go:67: end sorting large array, cost = 35 ms
    merge_sort_test.go:54: prepare large array with 100001 items
    merge_sort_test.go:60: sorting large array with 100001 items
    merge_sort_test.go:67: end sorting large array, cost = 36 ms
    merge_sort_test.go:54: prepare large array with 100002 items
    merge_sort_test.go:60: sorting large array with 100002 items
    merge_sort_test.go:67: end sorting large array, cost = 33 ms
    merge_sort_test.go:72: sorting 0-20
    merge_sort_test.go:83: unsort = [6 10 8 9 14 1 12 4 19 7 11 16 15 17 0 2 18 3 5 13]
    merge_sort_test.go:86: sorted = [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
    merge_sort_test.go:88: pass sorting 0-20
    merge_sort_test.go:93: pass sorting []
    merge_sort_test.go:97: pass sorting [1]
    merge_sort_test.go:102: pass sorting [1 3]
    merge_sort_test.go:108: pass sorting [3 2 1]
--- PASS: Test_MergeSort (0.12s)
PASS
ok      command-line-arguments  0.127s

ISorter.go

定义排序器接口. 定义值比较函数以兼容任意数值类型, 通过调整比较函数实现倒序排序

package sorting

type ISorter interface {
    Sort(data []interface{}, comparator CompareFunction) []interface{}
}

type CompareFunction func(a interface{}, b interface{}) CompareResult

type CompareResult int
const LESS CompareResult = -1
const EQUAL CompareResult = 0
const GREATER CompareResult = 1

tMergeSort.go

归并排序器, 实现ISorter接口.

package merge_sort

import (
    "learning/gooop/sorting"
)

type tMergeSort struct {}

func newMergeSort() sorting.ISorter {
    return &tMergeSort{}
}

func (me *tMergeSort) Sort(data []interface{}, comparator sorting.CompareFunction) []interface{} {
    if data == nil {
        return nil
    }

    size := len(data)
    if size <= 1 {
        return data
    }

    var result []interface{} = nil
    buffer := make([]interface{}, size)
    for span := 1; span <= size;span *= 2 {
        for i := 0;i < size;i += span * 2 {
            merge(size, data, i, i + span, span, buffer, i, comparator)
        }

        result = buffer
        data, buffer = buffer, data
    }

    if result == nil {
        result = data
    }

    return result
}

// 合并data数组中的两个子序列: [q1:q1+span), [q2:q2+span), 到目标数组result的offset位置
func merge(size int, data []interface{}, q1 int, q2 int, span int, result []interface{}, offset int, comparator sorting.CompareFunction) {
    e1 := min(q1 + span, size)
    e2 := min(q2 + span, size)
    j := -1
    k := -1

    for i := 0;i < span*2;i++ {
        if q1 >= e1 {
            j = q2
            k = e2

        } else if q2 >= e2 {
            j = q1
            k = e1
        }

        if j >= 0 {
            for p := j;p < k;p++ {
                result[offset] = data[p]
                offset++
            }
            break
        }

        v1 := data[q1]
        v2 := data[q2]

        if lessEqual(v1, v2, comparator) {
            result[offset] = v1
            q1++
        } else {
            result[offset] = v2
            q2++
        }
        offset++
    }
}

func lessEqual(v1 interface{}, v2 interface{}, comparator sorting.CompareFunction) bool {
    return comparator(v1, v2) != sorting.GREATER
}

func min(a,b int) int {
    if a <= b {
        return a
    } else {
        return b
    }
}

var MergeSort = newMergeSort()

(end)

你可能感兴趣的:(手撸golang 基本数据结构与算法 归并排序)