Comparator 使用详解及部分代码解析

最近使用 Comparator 比较多,觉得它简单好用、很方便,也挺有意思的,在此记录一下。

1.介绍

Comparator 里面的方法可以分成几大类。下面的测试代码也是按照这几大类来组织的,只测试了其中部分方法,其余方法用法类似,就不再一一演示。

  1. compare - 正常比较
  2. naturalOrder - 自然比较,根据实体类定义的Comparable
  3. nullsFirst、nullsLast - 将null值放在第一个或者最后一个
  4. comparing、comparingLong、comparingInt、comparingDouble - 常用比较方法,可以指定参数类型
  5. reversed、reverseOrder - 反转
  6. thenComparing、thenComparingDouble、thenComparingInt、thenComparingLong - 多次比较
其中 comparing 有一个重载方法(共两个版本:一个参数、两个参数),thenComparing 有两个重载方法(共三个版本)。

2.测试

package com.study.testCompare;

import org.junit.Test;

import java.math.BigDecimal;
import java.util.*;

public class ComparatorTest {

    @Test
    public void testCompare() {
        List personList = new ArrayList<>();
        personList.add(new Person("a", new BigDecimal(12), 170));
        personList.add(new Person("b", new BigDecimal(24), 175, new Student(27)));
        personList.add(new Person("c", new BigDecimal(12), 177));
        personList.add(new Person("a", new BigDecimal(12), 177));
        personList.add(new Person("b", new BigDecimal(54), 174, new Student(19)));

        // naturalOrder
        System.out.println("naturalOrder : ");
        personList.sort(Comparator.naturalOrder());
        personList.forEach(System.out::println);

        // comparing 1.0
        Optional optional = personList.stream().max(Comparator.comparing(Person::getAge));
        System.out.println("comparing 1.0 : get max age " + optional.get().toString() + "\n");

        // comparing 2.1
        optional = personList.stream().max(Comparator.comparing(Person::getName, Comparator.reverseOrder()));
        System.out.println("comparing 2.1 : get min name " + optional.get().toString() + "\n");

        // comparing 2.2
        optional = personList.stream().max(Comparator.comparing(Person::getName, String::compareTo));
        System.out.println("comparing 2.2 : get max name " + optional.get().toString() + "\n");

        // comparing 2.3
        optional = personList.stream().max(Comparator.comparing(Person::getStudent, (o1, o2) -> new Student().compare(o1, o2)));
        System.out.println("comparing 2.3 : get max student.age " + optional.get().toString() + "\n");


        // thenComparing 1.0
        System.out.println("thenComparing 1.0 : ");
        personList.sort(Comparator.comparing(Person::getAge).thenComparing(Person::getHeight));
        personList.forEach(System.out::println);


        // thenComparing 2.0
        System.out.println("thenComparing 2.0 : ");
        personList.sort(Comparator.comparing(Person::getAge).thenComparing(Person::getHeight).thenComparing(Person::getName));
        personList.forEach(System.out::println);


        // 升序
        System.out.println("升序 : ");
        personList.sort(Comparator.comparingInt(Person::getHeight));
        personList.forEach(System.out::println);


        // 降序
        System.out.println("降序 : ");
        personList.sort(Comparator.comparingInt(Person::getHeight).reversed());
        personList.forEach(System.out::println);

        // nullsLast
        System.out.println("nullsLast : ");
        personList.sort(Comparator.nullsLast(Comparator.comparing(Person::getName)));
        personList.forEach(System.out::println);

        // nullsLast
        System.out.println("nullsLast : ");
        personList.sort(Comparator.nullsLast(Comparator.comparing(Person::getName)));
        personList.forEach(System.out::println);
    }

}

package com.study.testCompare;

import lombok.Data;

import java.math.BigDecimal;
import java.util.Comparator;

@Data
public class Person implements Comparable<Person> {
    private String name;
    private BigDecimal age;
    private Integer height;
    private Student student;

    /**
     * Creates a person without an explicit student record; a {@code Student} with age 0
     * is stored as a placeholder instead of null.
     */
    public Person(String name, BigDecimal age, Integer height) {
        this.name = name;
        this.age = age;
        this.height = height;
        this.student = new Student(0);
    }

    /**
     * Creates a person with all fields supplied by the caller.
     */
    public Person(String name, BigDecimal age, Integer height, Student student) {
        this.name = name;
        this.age = age;
        this.height = height;
        this.student = student;
    }

    /**
     * Natural ordering: age ascending; when ages are numerically equal, height descending.
     *
     * @param other the person to compare against
     * @return negative, zero or positive as this person sorts before, equal to, or after {@code other}
     */
    @Override
    public int compareTo(Person other) {
        // BigDecimal.equals is scale-sensitive (12 vs 12.0 are "not equal"); compareTo compares numerically.
        int byAge = this.age.compareTo(other.age);
        if (byAge != 0) {
            return byAge;
        }
        // Arguments swapped for descending height; Integer.compare avoids subtraction overflow.
        return Integer.compare(other.height, this.height);
    }
}

@Data
class Student implements Comparator<Student> {

    private int age;

    public Student() {
    }

    public Student(int age) {
        this.age = age;
    }

    /**
     * Orders students by age ascending.
     *
     * <p>Returns 0 for equal ages. The Comparator contract requires this: {@code compare(x, x)}
     * must be 0, and {@code compare(x, y)} must be the negation of {@code compare(y, x)}.
     * Forcing -1 on ties breaks sorting (TimSort may throw IllegalArgumentException).
     */
    @Override
    public int compare(Student o1, Student o2) {
        return Integer.compare(o1.age, o2.age);
    }
}

3.源码


package java.util;

import java.io.Serializable;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.function.ToLongFunction;
import java.util.function.ToDoubleFunction;
import java.util.Comparators;

/**
 * 注解过长,不做翻译
 */
@FunctionalInterface
public interface Comparator {
    /**
     * 最常用的的方法
     * o1 = o2 : return 0;
     * o1 > o2 : return 1;
     * o1 < o2 : return -1;
     */
    int compare(T o1, T o2);

    /**
     * equals 方法
     */
    boolean equals(Object obj);

    /**
     * 排序反转
     */
    default Comparator reversed() {
        return Collections.reverseOrder(this);
    }

    /**
     * 连续 比较,比如先比较一个人的年龄,在比较一个人的身高,在比较一个人的体重
     * @since 1.8
     */
    default Comparator thenComparing(Comparator other) {
        Objects.requireNonNull(other);
        return (Comparator & Serializable) (c1, c2) -> {
            int res = compare(c1, c2);
            return (res != 0) ? res : other.compare(c1, c2);
        };
    }

    /**
     * 连续 比较,入参不同,自定义compara
     * @since 1.8
     */
    default  Comparator thenComparing(
            Function keyExtractor,
            Comparator keyComparator)
    {
        return thenComparing(comparing(keyExtractor, keyComparator));
    }

    /**
     * 连续 比较,入参不同
     * @since 1.8
     */
    default > Comparator thenComparing(
            Function keyExtractor)
    {
        return thenComparing(comparing(keyExtractor));
    }

    /**
     * 连续 比较,指定Int类型
     * @since 1.8
     */
    default Comparator thenComparingInt(ToIntFunction keyExtractor) {
        return thenComparing(comparingInt(keyExtractor));
    }

    /**
     * 连续 比较,指定Long类型
     * @since 1.8
     */
    default Comparator thenComparingLong(ToLongFunction keyExtractor) {
        return thenComparing(comparingLong(keyExtractor));
    }

    /**
     * 连续 比较,指定Double类型
     * @since 1.8
     */
    default Comparator thenComparingDouble(ToDoubleFunction keyExtractor) {
        return thenComparing(comparingDouble(keyExtractor));
    }

    /**
     * 反转排序
     * @since 1.8
     */
    public static > Comparator reverseOrder() {
        return Collections.reverseOrder();
    }

    /**
     * 自然排序(实体类需要实现 Comparable)
     * @since 1.8
     */
    @SuppressWarnings("unchecked")
    public static > Comparator naturalOrder() {
        return (Comparator) Comparators.NaturalOrderComparator.INSTANCE;
    }

    /**
     * 如果有空数据,放在第一个
     * @since 1.8
     */
    public static  Comparator nullsFirst(Comparator comparator) {
        return new Comparators.NullComparator<>(true, comparator);
    }

    /**
     *  如果有空数据,放在最后一个
     * @since 1.8
     */
    public static  Comparator nullsLast(Comparator comparator) {
        return new Comparators.NullComparator<>(false, comparator);
    }

    /**
     * 正常排序(正序,需要实现Comparator),入参两个参数。第二个为Comparator
     * @since 1.8
     */
    public static  Comparator comparing(
            Function keyExtractor,
            Comparator keyComparator)
    {
        Objects.requireNonNull(keyExtractor);
        Objects.requireNonNull(keyComparator);
        return (Comparator & Serializable)
            (c1, c2) -> keyComparator.compare(keyExtractor.apply(c1),
                                              keyExtractor.apply(c2));
    }

    /**
     * 正常排序(正序,需要实现Comparator)。一个参数
     * @since 1.8
     */
    public static > Comparator comparing(
            Function keyExtractor)
    {
        Objects.requireNonNull(keyExtractor);
        return (Comparator & Serializable)
            (c1, c2) -> keyExtractor.apply(c1).compareTo(keyExtractor.apply(c2));
    }

    /**
     * 正常排序 —— int
     * @since 1.8
     */
    public static  Comparator comparingInt(ToIntFunction keyExtractor) {
        Objects.requireNonNull(keyExtractor);
        return (Comparator & Serializable)
            (c1, c2) -> Integer.compare(keyExtractor.applyAsInt(c1), keyExtractor.applyAsInt(c2));
    }

    /**
     * 正常排序 —— long
     * @since 1.8
     */
    public static  Comparator comparingLong(ToLongFunction keyExtractor) {
        Objects.requireNonNull(keyExtractor);
        return (Comparator & Serializable)
            (c1, c2) -> Long.compare(keyExtractor.applyAsLong(c1), keyExtractor.applyAsLong(c2));
    }

    /**
     * * 正常排序 —— double
     * @since 1.8
     */
    public static Comparator comparingDouble(ToDoubleFunction keyExtractor) {
        Objects.requireNonNull(keyExtractor);
        return (Comparator & Serializable)
            (c1, c2) -> Double.compare(keyExtractor.applyAsDouble(c1), keyExtractor.applyAsDouble(c2));
    }
}

4.解析部分方法

源码:

    /**
     * 正常排序(正序,需要实现Comparator)。一个参数
     * @since 1.8
     */
    public static > Comparator comparing(
            Function keyExtractor)
    {
        Objects.requireNonNull(keyExtractor);
        return (Comparator & Serializable)
            (c1, c2) -> keyExtractor.apply(c1).compareTo(keyExtractor.apply(c2));
    }

使用:

        // comparing 1.0: the key (Person::getAge, a BigDecimal) is compared via BigDecimal.compareTo
        Optional<Person> optional = personList.stream().max(Comparator.comparing(Person::getAge));
        System.out.println("comparing 1.0 : max age " + optional.get().toString() + "\n");

通过 debug 可知,在最终的源码中,c1、c2 是 Person 对象,比较时调用的是 BigDecimal 的 compareTo(因为 age 字段的类型是 BigDecimal)。
那么可以知道 keyExtractor.apply(c1) = c1.getAge(),keyExtractor.apply(c2) = c2.getAge()。而且源码中明确指出:Function<? super T, ? extends U> keyExtractor
其中:

  • T为传入类型
  • U为返回类型
    所以我们使用方法引用(Person::getAge)后,传入的 Person 就是 T,getAge 的返回值类型就是 U。所以 keyExtractor.apply 返回的就是 BigDecimal 类型的 age,最终调用 BigDecimal 的 compareTo 实现比较。

你可能感兴趣的:(java)