import "sort"
sort包中包含的文件
打开sort包我们能看到下面这些文件
这里先看最关键的sort.go。
sort文件中一开始就是最关键的Interface接口,一个满足sort.Interface接口的类型可以被sort包的函数进行排序。
type Interface interface {
Len() int // Len返回集合中元素的个数
Less(i, j int) bool // Less返回索引i的元素是否比索引j的元素小
Swap(i, j int) // Swap交换索引i和j的元素
}
Sort函数用来排序data。使用这个函数排序必须让自定义数据类型实现上面的Interface接口。
它调用1次data.Len确定长度,O(n*log(n))次data.Less和data.Swap排序。不能保证排序的稳定性。
func Sort(data Interface) {
n := data.Len()
quickSort(data, 0, n, maxDepth(n))
}
sort.go中已经为常用的三种数据类型 int、float64、string 提供了排序方法。这里以int为例。此时,如果我们需要为[]int排序,不需要手动实现Interface接口了,直接使用sort.Ints就可以了。
type IntSlice []int
func (p IntSlice) Len() int { return len(p) }
func (p IntSlice) Less(i, j int) bool { return p[i] < p[j] }
func (p IntSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p IntSlice) Sort() { Sort(p) }
func Ints(a []int) { Sort(IntSlice(a)) }
func Float64s(a []float64) { Sort(Float64Slice(a)) }
func Strings(a []string) { Sort(StringSlice(a)) }
下面三个函数使用来判断 int float64 string 三种类型的切片是否有序
func IntsAreSorted(a []int) bool { return IsSorted(IntSlice(a)) }
func Float64sAreSorted(a []float64) bool { return IsSorted(Float64Slice(a)) }
func StringsAreSorted(a []string) bool { return IsSorted(StringSlice(a)) }
func IsSorted(data Interface) bool {
n := data.Len()
for i := n - 1; i > 0; i-- {
if data.Less(i, i-1) {
return false
}
}
return true
}
sort也支持了逆序排序通过Reverse直接调用就可以,不过如果我们对自定义数据类型做排序,可以直接在Less函数中按自己的要求排序。
type reverse struct {
Interface
}
// 与Less相反
func (r reverse) Less(i, j int) bool {
return r.Interface.Less(j, i)
}
// 返回逆序的data
func Reverse(data Interface) Interface {
return &reverse{data}
}
上面说过使用sort.Sort排序是不稳定的,于是sort也实现了Stable来返回稳定的排序。具体stable内部函数还没有细看,待补充。
func Stable(data Interface) {
stable(data, data.Len())
}
分割线
常见的排序算法我们应该已经熟悉的不能再熟悉了。这里我们看一看sort中是使用的哪些方法进行排序处理的。
我们从上面的Sort函数可以看到,调用它后会进入 quickSort(data, 0, n, maxDepth(n)),那么我们看看quickSort是什么。
// quickSort就是真正实现排序的地方
func quickSort(data Interface, a, b, maxDepth int) {
// 如果切片的长度不大于12就用希尔排序(在下面)
// 否则就用下面的方法再判断怎么排序
for b-a > 12 {
// 外面的for循环是对数据的截取,如果深度为0就可以使用堆排序了
if maxDepth == 0 {
// 堆排序
heapSort(data, a, b)
return
}
// 每次循环都会减一层深度
maxDepth--
// 快排的核心,根据某值对数据拆成小于该值和大于该值的两半。
mlo, mhi := doPivot(data, a, b)
// 这里为了节省时间,让数量更多的那一半去做循环,小的那一半做递归。
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo
}
}
// 这部分就是数据集比较小的时候使用的希尔排序,以6为步长,然后就直接插入排序了,反正数也少。
if b-a > 1 {
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
里面有个maxDepth函数,是用来确定递归深度的。其实也就是一个阈值,用来确定什么时候由快排转换成堆排序。
func maxDepth(n int) int {
var depth int
for i := n; i > 0; i >>= 1 {
depth++
}
return depth * 2
}
快排的关键就是拆分数据,也就是上面的doPivot()
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
// 避免整形溢出
m := int(uint(lo+hi) >> 1)
// 下面这一小段是为了保证快排不会向着最坏时间复杂度越走越远(一直选最大或者最小的值来对数据拆分)
if hi-lo > 40 {
// Tukey's ``Ninther,'' median of three medians of three.
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)
// 下面的大部分工作是为了保证算法的高效,因为数据中可能包含大量重复的数字,
// 如果选择它为比较值就直接把所有的重复值放在一起,之后不会再判断该值。
// 下面的注释不再翻译。
// Invariants are:
// data[lo] = pivot (set up by ChoosePivot)
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
pivot := lo
a, c := lo+1, hi-1
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let be a bit more conservative, and set border to 5.
protect := hi-c < 5
if !protect && hi-c < (hi-lo)/4 {
// Lets test some points for equality to pivot
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
// Protect against a lot of duplicates
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
// Swap pivot into middle
data.Swap(pivot, b-1)
return b - 1, c
}
下面是插入排序,实现很简单。
func insertionSort(data Interface, a, b int) {
for i := a + 1; i < b; i++ {
for j := i; j > a && data.Less(j, j-1); j-- {
data.Swap(j, j-1)
}
}
}
紧跟着是堆排序。
// 堆排序——调整堆
func siftDown(data Interface, lo, hi, first int) {
root := lo
for {
child := 2*root + 1
if child >= hi {
break
}
if child+1 < hi && data.Less(first+child, first+child+1) {
child++
}
if !data.Less(first+root, first+child) {
return
}
data.Swap(first+root, first+child)
root = child
}
}
// 堆排序
func heapSort(data Interface, a, b int) {
first := a
lo := 0
hi := b - a
for i := (hi - 1) / 2; i >= 0; i-- {
siftDown(data, i, hi, first)
}
for i := hi - 1; i >= 0; i-- {
data.Swap(first, first+i)
siftDown(data, lo, i, first)
}
}
最后我们来举个栗子吧。
package main
import (
"fmt"
"sort"
)
type point struct {
x,y int
}
type ps []point
func Newps() ps {
slc := ps{{x:4,y:2},{x:3,y:3},{x:8,y:8},{x:2,y:5},{x:1,y:7},{x:3,y:2}}
return slc
}
func (s ps)Len() int {
return len(s)
}
func (s ps)Less(i,j int) bool {
if s[i].x == s[j].x {
return s[i].y <= s[j].y
}
return s[i].x < s[j].x
}
func (s ps)Swap(i,j int) {
s[i],s[j] = s[j],s[i]
}
func main() {
intslc := []int{4,3,6,8,2,1,9,5,7}
fmt.Println(intslc)
sort.Ints(intslc)
fmt.Println(intslc)
// [4 3 6 8 2 1 9 5 7]
// [1 2 3 4 5 6 7 8 9]
floatslc := []float64{4.4,3.2,6.6,5.1,1.9,8.7,2,0.2}
fmt.Println(floatslc)
sort.Float64s(floatslc)
fmt.Println(floatslc)
// [4.4 3.2 6.6 5.1 1.9 8.7 2 0.2]
// [0.2 1.9 2 3.2 4.4 5.1 6.6 8.7]
strslc := []string{"lazyboy", "chen7", "golang", "sort", "library"}
fmt.Println(strslc)
sort.Strings(strslc)
fmt.Println(strslc)
sort.Sort(sort.Reverse(sort.StringSlice(strslc)))
fmt.Println(strslc)
// [lazyboy chen7 golang sort library]
// [chen7 golang lazyboy library sort]
// [sort library lazyboy golang chen7]
if sort.IntsAreSorted(intslc) {
fmt.Println("intslc is sorted")
}
// intslc is sorted
arr := Newps()
fmt.Println(arr)
sort.Sort(arr)
fmt.Println(arr)
// [{4 2} {3 3} {8 8} {2 5} {1 7} {3 2}]
// [{1 7} {2 5} {3 2} {3 3} {4 2} {8 8}]
}
参考资料:
记录每天解决的小问题,积累起来去解决大问题