问题描述

一个开发人员写了一段明显有问题的排序代码,大致如下:

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;

public class Test {

    public static void main(String[] args) throws InterruptedException {
        //测试数据: List里放Map,按Map里的name字段排序
        HashMap a = new HashMap();
        a.put("name", "二");
        HashMap b = new HashMap();
        b.put("name", "一");
        HashMap c = new HashMap();
        c.put("name", "一");
        HashMap d = new HashMap();
        d.put("name", "四");
        HashMap e = new HashMap();
        e.put("name", "二");
        HashMap f = new HashMap();
        f.put("name", "三");
        ArrayList> list = new ArrayList<>();
        list.add(a);
        list.add(b);
        list.add(c);
        list.add(d);
        list.add(e);
        list.add(f);

        //排序:明显有问题,因为只返回-1和0,也就是比较的时候永远是小于等于
        Collections.sort(list, new Comparator>() {
            @Override
            public int compare(HashMap o1, HashMap o2) {
                String n1 = o1.get("name");
                String n2 = o2.get("name");
                if (n1.equals("一")) {
                    return -1;
                }
                if (n1.equals("二") && !n2.equals("一")) {
                    return -1;
                }
                if (n1.equals("三") && !"一二".contains(n2)) {
                    return -1;
                }
                if (n1.equals("四") && !"一二三".contains(n2)) {
                    return -1;
                }
                return 0;
            }
        });

        for(HashMap x : list) {
            System.out.print(x.get("name"));
        }

    }
}

按理这个排序是有问题的,但是不管怎么改变测试数据,排序结果都是对的(测试数据量较小),上面代码的输出结果如下,用的jdk是1.7:

一一二二三四

但是,生产上是有问题的。

分析

Collections.sort,最终调用了Arrays.sort,在1.7中,Arrays.sort做了修改。

    public static  void sort(T[] a, Comparator c) {
        if (c == null) {
            sort(a);
        } else {
            if (LegacyMergeSort.userRequested)
                legacyMergeSort(a, c);
            else
                TimSort.sort(a, 0, a.length, c, null, 0, 0);
        }
    }

如果配置了java.util.Arrays.useLegacyMergeSort这个参数,那么就走老的LegacyMergeSort,否则就走新的TimSort。

我们在代码里加上下面一句话,输出结果就是乱序的,这符合预期。

System.setProperty("java.util.Arrays.useLegacyMergeSort", "true");

检查了一下生产上JVM的参数,果然加了这个参数。

但是为什么走TimSort的结果是对的呢?继续分析TimSort的代码,发现有一个特殊情况的处理:

        // If array is small, do a "mini-TimSort" with no merges
        if (nRemaining < MIN_MERGE) { //MIN_MERGE是32
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        }

也就是在数组小于32的时候,进入这个里面,然后没有归并。那我们先来测试一下大于32的情况。

public class Test { 
    public static void main(String[] args) throws InterruptedException {    
        ArrayList> list = new ArrayList<>();
        String[] xx = {"一","二","三","四"};
        for(int i = 0; i < 35; i++) {
            HashMap x = new HashMap();
            x.put("name", xx[(i+17)%4]);
            list.add(x);
        }
        Collections.sort(list, new Comparator>() {
            @Override
            public int compare(HashMap o1, HashMap o2) {
                String n1 = o1.get("name");
                String n2 = o2.get("name");
                if (n1.equals("一")) {
                    return -1;
                }
                if (n1.equals("二") && !n2.equals("一")) {
                    return -1;
                }
                if (n1.equals("三") && !"一二".contains(n2)) {
                    return -1;
                }
                if (n1.equals("四") && !"一二三".contains(n2)) {
                    return -1;
                }
                return 0;
            }
        });

        for(HashMap x : list) {
            System.out.print(x.get("name"));
        }
    }
}

这次果然翻车了。

一一一一二二二二二三三三三三四四四四一一一一二二二二三三三三四四四四四

我们通过代码来看一下为什么小于32的时候排序成功了。

首先,我们的比较函数,只有在真正小于或者等于情况下返回了-1,其余情况返回了0,包括大于的情况也返回了0。

比如

两个值 结果
一一 -1
一二 -1
三二 0
四四 -1
三一 0

为了简化,下面用阿拉伯数字代替

以211423为例,

        if (nRemaining < MIN_MERGE) {
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        }

第一步,是找到严格递增或者递减的最大长度,如果是升序,就不处理,降序的话,就reverse。

211423经过处理后变成了112 423,最大递减长度为3(因为1和1相比的结果为-1,所以也被当作严格递减),然后211被reverse成112

private static  int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                Comparator c) {
    assert lo < hi;
    int runHi = lo + 1;
    if (runHi == hi)
        return 1;
    // Find end of run, and reverse range if descending
    if (c.compare(a[runHi++], a[lo]) < 0) { // Descending
        while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
            runHi++;
        reverseRange(a, lo, runHi);
    } else {                              // Ascending
        while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
            runHi++;
    }
    return runHi - lo;
}

接下来,从第四个位置开始,找到它的位置,移动数据,让每一个数字找到合适的位置,具体的代码如下:

    private static  void binarySort(T[] a, int lo, int hi, int start,
                                       Comparator c) {
        assert lo <= start && start <= hi;
        if (start == lo)
            start++;
        for ( ; start < hi; start++) {
            T pivot = a[start];

            // Set left (and right) to the index where a[start] (pivot) belongs
            int left = lo;
            int right = start;
            assert left <= right;
            /*
             * Invariants:
             *   pivot >= all in [lo, left).
             *   pivot <  all in [right, start).
             */
            while (left < right) {
                int mid = (left + right) >>> 1;
                if (c.compare(pivot, a[mid]) < 0)
                    right = mid;
                else
                    left = mid + 1;
            }
            assert left == right;

            int n = start - left;  // The number of elements to move
            // Switch is just an optimization for arraycopy in default case
            switch (n) {
                case 2:  a[left + 2] = a[left + 1];
                case 1:  a[left + 1] = a[left];
                         break;
                default: System.arraycopy(a, left, a, left + 1, n);
            }
            a[left] = pivot;
        }
    }

对于112423的移动过程如下:

第一次:112 4 23, 在左边找到合适4的位置,结果为1124 23

第二次:1124 2 3, 在左边找到2合适的位置,结果11224 3

第三次:11224 3,在左边找到3合适的位置,结果为112234,结束

在整个函数中,我们发现了一个问题,那就是只用到了c.compare(pivot, a[mid]) < 0,而大于0和等于0的情况没有用到,而我们的比较函数正好是返回小于0的时候是正确的,所以并不会影响这个函数的执行结果。也就是说,只要真正小于的时候返回了-1,不小于的时候返回了0或者1,对这个函数是没有影响的,正因为如此这个函数是个稳定排序。

但是在countRunAndMakeAscending这个函数里用到了>=0。我们看一下这种情况,也就是数组的开头是递增的时候,会用到>=0

private static  int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                Comparator c) {
    assert lo < hi;
    int runHi = lo + 1;
    if (runHi == hi)
        return 1;
    // Find end of run, and reverse range if descending
    if (c.compare(a[runHi++], a[lo]) < 0) { // Descending
        while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
            runHi++;
        reverseRange(a, lo, runHi);
    } else {                              // Ascending
        while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
            runHi++;
    }
    return runHi - lo;
}

假设输入的是1234123,前边2和1相比结果是0,3和2也是0,4和3也是0,1和4是-1,所以最大递增序列是1234,同时不用reverse,传给下一个函数的输入为1234 123,结果三次插入,结果也是对的。

总结

综上分析可以得出结论,就是因为在jdk 1.7中,如果数组小于32个元素,加入对于小于的比较都是-1, 其他的都是0,那么结果是正确的,这是因为算法本身的特性。但是大于32时,就不对了,会看到分段排好序了,这是因为归并的时候比较结果都是0,导致没有做归并。

其实sort的Comparator是有坑的,必须把所有情况都考虑周到,而且要满足以下特性:

1 ) 自反性: x , y 的比较结果和 y , x 的比较结果相反。
2 ) 传递性: x > y , y > z ,则 x > z 。
3 ) 对称性: x = y ,则 x , z 比较结果和 y , z 比较结果相同。

上面的Comparator如果要写的对,应该这么写,把所有情况列出来,当然也可以通过一些条件简化,但是简化的后果就是上面的结果,需要充分测试。

        Collections.sort(list, new Comparator>() {
            @Override
            public int compare(HashMap o1, HashMap o2) {
                String n1 = o1.get("name");
                String n2 = o2.get("name");
                if (n1.equals("一") && n2.equals("一")) {
                    return 0;
                }
                if (n1.equals("一") && n2.equals("二")) {
                    return -1;
                }
                if (n1.equals("一") && n2.equals("三")) {
                    return -1;
                }
                if (n1.equals("一") && n2.equals("四")) {
                    return -1;
                }
                if (n1.equals("二") && n2.equals("一")) {
                    return 1;
                }
                ......

            }
        });