java sort排序源码分析(TimSort排序)

入口:

default void sort(Comparator c) {
        Object[] a = this.toArray();
        Arrays.sort(a, (Comparator) c);
        ListIterator i = this.listIterator();
        for (Object e : a) {
            i.next();
            i.set((E) e);
        }
    }

java排序方法调用的Arrays.sort ,传入两个参数,数据数组和comparator对象 

public static  void sort(T[] a, Comparator c) {
        if (c == null) {
            sort(a);
        } else {
            if (LegacyMergeSort.userRequested)
                legacyMergeSort(a, c);
            else
                TimSort.sort(a, 0, a.length, c, null, 0, 0);
        }
    }

在sort方法中,有两种排序算法,传统排序,和TimSort

LegacyMergeSort.userRequested是使用jdk5的传统排序方法。

TimSort是改进后的归并排序,对归并排序在已经反向排好序的输入时表现为O(n^2)的特点做了特别优化。对已经正向排好序的输入减少回溯。对两种情况(一会升序,一会降序)的输入处理比较好(摘自百度百科)。

这里主要讲解TimSort排序

TimSort.sort(a, 0, a.length, c, null, 0, 0);
static  void sort(T[] a, int lo, int hi, Comparator c,
                         T[] work, int workBase, int workLen)

这里传入很多参数:a:数据数组,lo数据第一个元素索引,hi最后一个元素索引,c比较器对象,work工作空间数组,workBase工作空间可用空间,workLen工作集合的大小。

 static  void sort(T[] a, int lo, int hi, Comparator c,
                         T[] work, int workBase, int workLen) {
        //断言错误情况
        assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;
        
        //判断数组长度是否小于2 如果是只有0或1,这种数组通常已经被排序
        int nRemaining  = hi - lo;
        if (nRemaining < 2)
            return;  // Arrays of size 0 and 1 are always sorted
        
        //如果数组长度小于MIN_MERGE(32)则使用二分排序
        // If array is small, do a "mini-TimSort" with no merges
        if (nRemaining < MIN_MERGE) {
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        }

长度小于32时的二分排序

1.数组从头开始寻找顺序片段,直到不满足要求;如果倒序也一直查找直到不满足要求,然后反转。

private static  int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                    Comparator c) {
        assert lo < hi;
        int runHi = lo + 1;
        if (runHi == hi)
            return 1;
        //寻找数组中有序队列
        // Find end of run, and reverse range if descending
        if (c.compare(a[runHi++], a[lo]) < 0) { // Descending
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
                runHi++;
            reverseRange(a, lo, runHi);
        } else {                              // Ascending
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
                runHi++;
        }
        //返回有序片段长度
        return runHi - lo;
    }

2.使用二分查找来排序

private static  void binarySort(T[] a, int lo, int hi, int start,
                                       Comparator c) {
        assert lo <= start && start <= hi;
        if (start == lo)
            start++;
        for ( ; start < hi; start++) {
            T pivot = a[start];

            // Set left (and right) to the index where a[start] (pivot) belongs
            int left = lo;
            int right = start;
            assert left <= right;
            //查找到所需插入位置索引
            while (left < right) {
                int mid = (left + right) >>> 1;
                if (c.compare(pivot, a[mid]) < 0)
                    right = mid;
                else
                    left = mid + 1;
            }
            assert left == right;
            //进行插入(插入位置是1或2时优化)
            int n = start - left;  // The number of elements to move
            // Switch is just an optimization for arraycopy in default case
            switch (n) {
                case 2:  a[left + 2] = a[left + 1];
                case 1:  a[left + 1] = a[left];
                         break;
                default: System.arraycopy(a, left, a, left + 1, n);
            }
            a[left] = pivot;
        }
    }

这个相当于未分片的TimSort 

 长度大于32位时TimSort排序

1.计算出最小分片长度

        /**
         * March over the array once, left to right, finding natural runs,
         * extending short natural runs to minRun elements, and merging runs
         * to maintain stack invariant.
         */
        TimSort ts = new TimSort<>(a, c, work, workBase, workLen);
        int minRun = minRunLength(nRemaining);
        do {
            // Identify next run
            int runLen = countRunAndMakeAscending(a, lo, hi, c);

            // If run is short, extend to min(minRun, nRemaining)
            if (runLen < minRun) {
                int force = nRemaining <= minRun ? nRemaining : minRun;
                binarySort(a, lo, lo + force, lo + runLen, c);
                runLen = force;
            }

            // Push run onto pending-run stack, and maybe merge
            ts.pushRun(lo, runLen);
            ts.mergeCollapse();

            // Advance to find next run
            lo += runLen;
            nRemaining -= runLen;
        } while (nRemaining != 0);

        // Merge all remaining runs to complete sort
        assert lo == hi;
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;

计算出minRun,当n>=32时除2,直到小于32,(如果n为2的N幂,计算出来为16,否则保留最后五位加最后一次移位的r)

private static int minRunLength(int n) {
        assert n >= 0;
        int r = 0;      // Becomes 1 if any 1 bits are shifted off
        while (n >= MIN_MERGE) {
            //&1之后,n为奇数则为1,偶数为0
            r |= (n & 1);
            //右移,相当于除2
            n >>= 1;
        }
        return n + r;
    }

2.do-while

2.1取得最小升序片段长度(如果是降序则反转),这个方法前面写到过

// Identify next run
int runLen = countRunAndMakeAscending(a, lo, hi, c);

2.2如果该长度小于最小分片长度,则用二分查找插入变成满足最小分片长度的升序片段

// If run is short, extend to min(minRun, nRemaining)
if (runLen < minRun) {
    int force = nRemaining <= minRun ? nRemaining : minRun;
    binarySort(a, lo, lo + force, lo + runLen, c);
    runLen = force;
}

 2.3将该序列的起始位置和长度入栈

private void pushRun(int runBase, int runLen) {
        this.runBase[stackSize] = runBase;
        this.runLen[stackSize] = runLen;
        stackSize++;
    }

2.4合并以有有序片段

private void mergeCollapse() {
        while (stackSize > 1) {
            int n = stackSize - 2;
            //第一个片段长度小于后两个相加
            if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
                //如果小于后面第二个长度
                if (runLen[n - 1] < runLen[n + 1])
                    //则将合并位置减一
                    n--;
                mergeAt(n);
            } else if (runLen[n] <= runLen[n + 1]) {
                mergeAt(n);
            } else {
                break; // Invariant is established
            }
        }
    }

合并操作,先查出来两个片段边界元素在另外片段的位置

private void mergeAt(int i) {
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;
        //数据初始化
        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * 记录合并后的序列的长度
         */
        runLen[i] = len1 + len2;
        if (i == stackSize - 3) {
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        }
        stackSize--;

        /*
         * 查找到run2的第一个元素排序在run1的位置
         */
        int k = gallopRight(a[base2], a, base1, len1, 0, c);
        assert k >= 0;
        base1 += k;
        len1 -= k;
        if (len1 == 0)
            return;

        /*
         * 查找到run1最后一个元素排序在run2的位置
         */
        len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
        assert len2 >= 0;
        if (len2 == 0)
            return;
        
        //合并操作
        // Merge remaining runs, using tmp array with min(len1, len2) elements
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    }

 找到两个位置之后,则只需归并中间的字段

 java sort排序源码分析(TimSort排序)_第1张图片

合并方法代码

private void mergeLo(int base1, int len1, int base2, int len2) {
        assert len1 > 0 && len2 > 0 && base1 + len1 == base2;

        // Copy first run into temp array
        T[] a = this.a; // For performance
        T[] tmp = ensureCapacity(len1);
        int cursor1 = tmpBase; // Indexes into tmp array
        int cursor2 = base2;   // Indexes int a
        int dest = base1;      // Indexes int a
        System.arraycopy(a, base1, tmp, cursor1, len1);

        // Move first element of second run and deal with degenerate cases
        a[dest++] = a[cursor2++];
        if (--len2 == 0) {
            System.arraycopy(tmp, cursor1, a, dest, len1);
            return;
        }
        if (len1 == 1) {
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; // Last elt of run 1 to end of merge
            return;
        }

        Comparator c = this.c;  // Use local variable for performance
        int minGallop = this.minGallop;    //  "    "       "     "      "
    outer:
        while (true) {
            int count1 = 0; // Number of times in a row that first run won
            int count2 = 0; // Number of times in a row that second run won

            /*
             * Do the straightforward thing until (if ever) one run starts
             * winning consistently.
             */
            do {
                assert len1 > 1 && len2 > 0;
                if (c.compare(a[cursor2], tmp[cursor1]) < 0) {
                    a[dest++] = a[cursor2++];
                    count2++;
                    count1 = 0;
                    if (--len2 == 0)
                        break outer;
                } else {
                    a[dest++] = tmp[cursor1++];
                    count1++;
                    count2 = 0;
                    if (--len1 == 1)
                        break outer;
                }
            } while ((count1 | count2) < minGallop);

            /*
             * One run is winning so consistently that galloping may be a
             * huge win. So try that, and continue galloping until (if ever)
             * neither run appears to be winning consistently anymore.
             */
            do {
                assert len1 > 1 && len2 > 0;
                count1 = gallopRight(a[cursor2], tmp, cursor1, len1, 0, c);
                if (count1 != 0) {
                    System.arraycopy(tmp, cursor1, a, dest, count1);
                    dest += count1;
                    cursor1 += count1;
                    len1 -= count1;
                    if (len1 <= 1) // len1 == 1 || len1 == 0
                        break outer;
                }
                a[dest++] = a[cursor2++];
                if (--len2 == 0)
                    break outer;

                count2 = gallopLeft(tmp[cursor1], a, cursor2, len2, 0, c);
                if (count2 != 0) {
                    System.arraycopy(a, cursor2, a, dest, count2);
                    dest += count2;
                    cursor2 += count2;
                    len2 -= count2;
                    if (len2 == 0)
                        break outer;
                }
                a[dest++] = tmp[cursor1++];
                if (--len1 == 1)
                    break outer;
                minGallop--;
            } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
            if (minGallop < 0)
                minGallop = 0;
            minGallop += 2;  // Penalize for leaving gallop mode
        }  // End of "outer" loop
        this.minGallop = minGallop < 1 ? 1 : minGallop;  // Write back to field

        if (len1 == 1) {
            assert len2 > 0;
            System.arraycopy(a, cursor2, a, dest, len2);
            a[dest + len2] = tmp[cursor1]; //  Last elt of run 1 to end of merge
        } else if (len1 == 0) {
            throw new IllegalArgumentException(
                "Comparison method violates its general contract!");
        } else {
            assert len2 == 0;
            assert len1 > 1;
            System.arraycopy(tmp, cursor1, a, dest, len1);
        }
    }

这段代码合并代码步骤

2.4.1分配临时片段,用于合并

2.4.2计数count整段合并

注:这里当len1=0抛出异常:Comparison method violates its general contract!,这是在整段合并时,识别到run1有片段应该合并到run2起始位置;但是在合并之前有过判断run1中小于run2第一个元素的片段已经不在合并范围内了,那么合并的run1不可能有片段还在run2的起始值之前(可以看合并的图示更好理解)。所以大家在重写compare方法时需要考虑周全。

以上是对java中timsort排序的一些浅显的解读。

你可能感兴趣的:(java)