目录
1、reduce / collect
2、sum / min / max / count / average / summaryStatistics
3、AbstractTask
4、ReduceTask
5、AbstractShortCircuitTask
6、FindTask
7、MatchTask
8、Spliterator
本篇博客继续上一篇《Java8 Stream API 之 IntPipeline(二) 源码解析》 讲解reduce / collect等其他方法的实现。
这两个方法可用于实现获取元素的最小值,最大值,总和和平均值等类似场景,参考下面一节,这两个方法的实现如下:
/**
 * Reduces the int elements of this stream with {@code op}, starting from {@code identity}.
 * The work is delegated to the terminal operation built by
 * {@code ReduceOps.makeInt(identity, op)}.
 */
@Override
public final int reduce(int identity, IntBinaryOperator op) {
// identity is the starting value of the reduction; the overload that takes no identity
// uses the first element as the starting value instead
return evaluate(ReduceOps.makeInt(identity, op));
}
/**
 * Reduces the int elements of this stream with {@code op} and no identity value.
 * The work is delegated to the terminal operation built by {@code ReduceOps.makeInt(op)},
 * whose sink yields an empty {@code OptionalInt} when no element was accepted.
 */
@Override
public final OptionalInt reduce(IntBinaryOperator op) {
return evaluate(ReduceOps.makeInt(op));
}
/**
 * Performs a mutable reduction of the int elements into a container of type {@code R}.
 *
 * @param supplier    creates a fresh result container
 * @param accumulator folds one int element into a container
 * @param combiner    merges the second container into the first (used when partial
 *                    results of a parallel evaluation are combined)
 * @param <R>         type of the result container
 * @return the container holding the reduction result
 * @throws NullPointerException if {@code combiner} is null
 */
@Override
public final <R> R collect(Supplier<R> supplier,
                           ObjIntConsumer<R> accumulator,
                           BiConsumer<R, R> combiner) {
    Objects.requireNonNull(combiner);
    // ReduceOps.makeInt expects a BinaryOperator, so adapt the BiConsumer: merge the right
    // container into the left one and hand the left one back as the combined result.
    BinaryOperator<R> operator = (left, right) -> {
        combiner.accept(left, right);
        return left;
    };
    return evaluate(ReduceOps.makeInt(supplier, accumulator, operator));
}
/**
 * Builds a terminal operation that reduces int values with {@code operator}, starting
 * from {@code identity}.
 *
 * @param identity the starting value of every sink
 * @param operator the function folding the running state with the next element
 * @return a terminal operation producing the reduced value
 * @throws NullPointerException if {@code operator} is null
 */
public static TerminalOp<Integer, Integer>
makeInt(int identity, IntBinaryOperator operator) {
    Objects.requireNonNull(operator);
    class ReducingSink
            implements AccumulatingSink<Integer, Integer, ReducingSink>, Sink.OfInt {
        private int state;

        @Override
        public void begin(long size) {
            // identity is the enclosing method's parameter
            state = identity;
        }

        @Override
        public void accept(int t) {
            // Feed the current state and the element to the operator; the result becomes the new state.
            state = operator.applyAsInt(state, t);
        }

        @Override
        public Integer get() {
            return state;
        }

        // Used when partial results are merged: the other sink's state is folded in
        // exactly like one more element.
        @Override
        public void combine(ReducingSink other) {
            accept(other.state);
        }
    }
    return new ReduceOp<Integer, Integer, ReducingSink>(StreamShape.INT_VALUE) {
        @Override
        public ReducingSink makeSink() {
            return new ReducingSink();
        }
    };
}
private static abstract class ReduceOp>
implements TerminalOp {
private final StreamShape inputShape;
ReduceOp(StreamShape shape) {
inputShape = shape;
}
public abstract S makeSink();
@Override
public StreamShape inputShape() {
return inputShape;
}
@Override
public R evaluateSequential(PipelineHelper helper,
Spliterator spliterator) {
//串行处理时调用
return helper.wrapAndCopyInto(makeSink(), spliterator).get();
}
@Override
public R evaluateParallel(PipelineHelper helper,
Spliterator spliterator) {
//并行处理时调用
return new ReduceTask<>(this, helper, spliterator).invoke().get();
}
}
private interface AccumulatingSink>
extends TerminalSink {
public void combine(K other);
}
/**
 * Builds a terminal operation that reduces int values with {@code operator} and no
 * identity: the first accepted element becomes the initial state.
 *
 * @param operator the function folding the running state with the next element
 * @return a terminal operation producing an {@code OptionalInt}, empty when no element
 *         was accepted
 * @throws NullPointerException if {@code operator} is null
 */
public static TerminalOp<Integer, OptionalInt>
makeInt(IntBinaryOperator operator) {
    Objects.requireNonNull(operator);
    class ReducingSink
            implements AccumulatingSink<Integer, OptionalInt, ReducingSink>, Sink.OfInt {
        private boolean empty;
        private int state;

        @Override
        public void begin(long size) {
            empty = true;
            state = 0;
        }

        @Override
        public void accept(int t) {
            if (empty) {
                // First element seen: it seeds the state.
                empty = false;
                state = t;
            }
            else {
                state = operator.applyAsInt(state, t);
            }
        }

        @Override
        public OptionalInt get() {
            return empty ? OptionalInt.empty() : OptionalInt.of(state);
        }

        @Override
        public void combine(ReducingSink other) {
            // Only merge when the other sink actually saw an element.
            if (!other.empty)
                accept(other.state);
        }
    }
    return new ReduceOp<Integer, OptionalInt, ReducingSink>(StreamShape.INT_VALUE) {
        @Override
        public ReducingSink makeSink() {
            return new ReducingSink();
        }
    };
}
/**
 * Builds a terminal operation that performs a mutable reduction of int values.
 *
 * @param supplier    creates the result container for each sink
 * @param accumulator folds one int element into the container
 * @param combiner    merges the containers of two sinks (used when partial results of
 *                    a parallel evaluation are combined)
 * @param <R>         type of the result container
 * @return a terminal operation producing the container
 * @throws NullPointerException if any argument is null
 */
public static <R> TerminalOp<Integer, R>
makeInt(Supplier<R> supplier,
        ObjIntConsumer<R> accumulator,
        BinaryOperator<R> combiner) {
    Objects.requireNonNull(supplier);
    Objects.requireNonNull(accumulator);
    Objects.requireNonNull(combiner);
    class ReducingSink extends Box<R>
            implements AccumulatingSink<Integer, R, ReducingSink>, Sink.OfInt {
        @Override
        public void begin(long size) {
            state = supplier.get();
        }

        @Override
        public void accept(int t) {
            // The state reference is not reassigned here; the accumulator is handed the
            // container and the element.
            accumulator.accept(state, t);
        }

        @Override
        public void combine(ReducingSink other) {
            // Merge the two containers; the combiner's return value becomes the new state.
            state = combiner.apply(state, other.state);
        }
    }
    return new ReduceOp<Integer, R, ReducingSink>(StreamShape.INT_VALUE) {
        @Override
        public ReducingSink makeSink() {
            return new ReducingSink();
        }
    };
}
private static abstract class Box {
U state;
Box() {} // Avoid creation of special accessor
public U get() {
return state;
}
}
/**
 * Adds up all elements of the stream.
 *
 * @return the sum, or 0 when the reduction sees no element (0 is the identity)
 */
@Override
public final int sum() {
    // Reduce with 0 as the identity and plain int addition as the operator.
    return reduce(0, (running, next) -> running + next);
}
/**
 * Finds the smallest element by reducing with {@code Math.min} and no identity.
 *
 * @return the result of the identity-less reduction
 */
@Override
public final OptionalInt min() {
    return reduce((smallest, candidate) -> Math.min(smallest, candidate));
}
/**
 * Finds the largest element by reducing with {@code Math.max} and no identity.
 *
 * @return the result of the identity-less reduction
 */
@Override
public final OptionalInt max() {
    return reduce((largest, candidate) -> Math.max(largest, candidate));
}
/**
 * Counts the elements by mapping each one to {@code 1L} and adding the ones up.
 *
 * @return the number of elements
 */
@Override
public final long count() {
    // LongStream.sum() is documented as equivalent to reduce(0, Long::sum).
    return mapToLong(ignored -> 1L).reduce(0L, Long::sum);
}
/**
 * Computes the arithmetic mean of the elements.
 *
 * @return the mean, or an empty {@code OptionalDouble} when no element was counted
 */
@Override
public final OptionalDouble average() {
    // The container is a two-slot long array: slot 0 is the element count, slot 1 the running sum.
    long[] countAndSum = collect(
            () -> new long[2],
            (acc, value) -> {
                acc[0] += 1;       // one more element
                acc[1] += value;   // add it to the total
            },
            (left, right) -> {
                // Merge two partial results by adding them slot by slot.
                left[0] += right[0];
                left[1] += right[1];
            });
    // collect hands back the container; a positive count means at least one element was seen.
    if (countAndSum[0] > 0) {
        return OptionalDouble.of((double) countAndSum[1] / countAndSum[0]);
    }
    return OptionalDouble.empty();
}
/**
 * Collects count, sum, min and max of the elements into an {@code IntSummaryStatistics}.
 *
 * @return the populated statistics object
 */
@Override
public final IntSummaryStatistics summaryStatistics() {
    // The container is an IntSummaryStatistics: each element goes through accept,
    // and partial results are merged with combine.
    return collect(
            () -> new IntSummaryStatistics(),
            (stats, value) -> stats.accept(value),
            (left, right) -> left.combine(right));
}
public class IntSummaryStatistics implements IntConsumer {
private long count;
private long sum;
private int min = Integer.MAX_VALUE;
private int max = Integer.MIN_VALUE;
public IntSummaryStatistics() { }
@Override
public void accept(int value) {
++count; //元素个数
sum += value; //总和
min = Math.min(min, value); //最小值
max = Math.max(max, value); //最大值
}
public void combine(IntSummaryStatistics other) {
//同其他线程的处理结果累加
count += other.count;
sum += other.sum;
min = Math.min(min, other.min);
max = Math.max(max, other.max);
}
public final long getCount() {
return count;
}
public final long getSum() {
return sum;
}
public final int getMin() {
return min;
}
public final int getMax() {
return max;
}
public final double getAverage() {
return getCount() > 0 ? (double) getSum() / getCount() : 0.0d;
}
@Override
public String toString() {
return String.format(
"%s{count=%d, sum=%d, min=%d, average=%f, max=%d}",
this.getClass().getSimpleName(),
getCount(),
getSum(),
getMin(),
getAverage(),
getMax());
}
}
AbstractTask是并行流处理的基类,其类继承关系如下:
我们以reduce / collect方法对应的并行流处理实现类 ReduceTask为例来说明其实现。AbstractTask定义的属性如下:
// The pipeline helper of the stream this task works for
protected final PipelineHelper helper;
/**
 * The spliterator covering the portion of the source handled by this task
 */
protected Spliterator spliterator;
/**
 * Target number of elements per leaf task; estimated as the total size divided by
 * four times the common-pool parallelism (see suggestTargetSize / LEAF_TARGET)
 */
protected long targetSize; // may be lazily initialized
/**
 * The left child task; if non-null, the right child is non-null as well
 */
protected K leftChild;
/**
 * The right child task; if non-null, the left child is non-null as well
 */
protected K rightChild;
/**
 * The result computed by this task
 */
private R localResult;
其构造方法如下:
/**
 * Creates the root task (its parent completer is null).
 *
 * @param helper      the pipeline helper of the stream being evaluated
 * @param spliterator the spliterator covering the whole source
 */
protected AbstractTask(PipelineHelper helper,
Spliterator spliterator) {
super(null);
this.helper = helper;
this.spliterator = spliterator;
this.targetSize = 0L; // 0 means "not computed yet"; getTargetSize fills it in lazily on first use
}
/**
 * Creates a child task; helper and targetSize are copied from the parent.
 *
 * @param parent      the parent task
 * @param spliterator the spliterator covering this child's portion of the source
 */
protected AbstractTask(K parent,
Spliterator spliterator) {
super(parent);
this.spliterator = spliterator;
this.helper = parent.helper;
this.targetSize = parent.targetSize;
}
AbstractTask只有两个抽象方法需要子类实现,如下:
makeChild用来创建子任务,doLeaf用来计算叶子节点(即无法继续切分的子任务)的执行结果。AbstractTask的核心方法是完成实际并行处理的compute方法,如下:
/**
 * Core of the parallel evaluation: repeatedly splits the spliterator, forks one half as
 * a child task and keeps working on the other half in the current thread, until the
 * remaining portion is small enough (or cannot be split); that portion is then
 * evaluated as a leaf via doLeaf().
 */
public void compute() {
Spliterator rs = spliterator, ls; // right, left spliterators
long sizeEstimate = rs.estimateSize();
long sizeThreshold = getTargetSize(sizeEstimate);
boolean forkRight = false;
@SuppressWarnings("unchecked") K task = (K) this;
// Keep splitting while the estimated size exceeds the threshold.
// A non-null trySplit() result means the split succeeded: ls covers one part,
// rs keeps the remaining part.
while (sizeEstimate > sizeThreshold && (ls = rs.trySplit()) != null) {
K leftChild, rightChild, taskToFork;
// Create the left and right child tasks
task.leftChild = leftChild = task.makeChild(ls);
task.rightChild = rightChild = task.makeChild(rs);
// The current thread continues with one child and only the other one is forked,
// hence a pending count of 1
task.setPendingCount(1);
if (forkRight) {
forkRight = false;
// The current thread goes on splitting the left child
rs = ls;
task = leftChild;
// The right child is forked
taskToFork = rightChild;
}
else {
forkRight = true;
// The current thread goes on splitting the right child
task = rightChild;
// The left child is forked
taskToFork = leftChild;
}
taskToFork.fork();
sizeEstimate = rs.estimateSize();
}
// No further splitting: evaluate this leaf and store its result
task.setLocalResult(task.doLeaf());
// Signal completion of this leaf. NOTE(review): tryComplete() is inherited; presumably
// (CountedCompleter contract) onCompletion runs on a parent once its children are done,
// which in this class nulls out spliterator and both children.
task.tryComplete();
}
/**
 * Returns the target leaf size, computing and caching it on first use.
 *
 * @param sizeEstimate estimated number of elements, used only when the target size has
 *                     not been computed yet
 * @return the cached target size, or the freshly suggested one
 */
protected final long getTargetSize(long sizeEstimate) {
    long cached = targetSize;
    if (cached != 0) {
        // Already initialized: reuse it.
        return cached;
    }
    long suggested = suggestTargetSize(sizeEstimate);
    targetSize = suggested;
    return suggested;
}
/**
 * Suggests how many elements a leaf task should handle.
 *
 * @param sizeEstimate estimated total number of elements
 * @return {@code sizeEstimate / LEAF_TARGET}, but never less than 1
 */
public static long suggestTargetSize(long sizeEstimate) {
    long perLeaf = sizeEstimate / LEAF_TARGET;
    // Clamp to at least one element per leaf.
    return Math.max(perLeaf, 1L);
}
// Divisor used by suggestTargetSize: four times the common pool's parallelism (<< 2 multiplies by 4)
static final int LEAF_TARGET = ForkJoinPool.getCommonPoolParallelism() << 2;
/**
 * Stores the result computed by this task.
 *
 * @param localResult the result to store
 */
protected void setLocalResult(R localResult) {
this.localResult = localResult;
}
@Override
public void onCompletion(CountedCompleter> caller) {
spliterator = null;
leftChild = rightChild = null;
}
ReduceTask继承自AbstractTask,用于实现并行的reduce或者collect方法,其调用如下:
其实现如下:
@SuppressWarnings("serial")
private static final class ReduceTask>
extends AbstractTask> {
private final ReduceOp op;
//helper和spliterator都是调用reduce或者collect方法的流及其Spliterator实现
//由evaluateParallel方法使用
ReduceTask(ReduceOp op,
PipelineHelper helper,
Spliterator spliterator) {
super(helper, spliterator);
this.op = op;
}
//创建子任务makeChild方法使用
ReduceTask(ReduceTask parent,
Spliterator spliterator) {
super(parent, spliterator);
this.op = parent.op;
}
@Override
protected ReduceTask makeChild(Spliterator spliterator) {
//创建一个新的子任务
return new ReduceTask<>(this, spliterator);
}
@Override
protected S doLeaf() {
//任务无法进一步切分了,则需要执行该任务,wrapAndCopyInto方法会遍历spliterator中包含的函数,传递给Sink
return helper.wrapAndCopyInto(op.makeSink(), spliterator);
}
@Override
//左右子任务节点都执行完成后会回调此方法
public void onCompletion(CountedCompleter> caller) {
if (!isLeaf()) {
//获取左子任务节点的执行结果
S leftResult = leftChild.getLocalResult();
//同右子任务节点的执行结果做合并
leftResult.combine(rightChild.getLocalResult());
//设置当前父任务的执行结果
setLocalResult(leftResult);
}
//调用父类方法,将spliterator, left and right child置为null
super.onCompletion(caller);
}
}
/**
 * Tells whether this task is a leaf. Only leftChild is checked.
 *
 * @return true when this task has no left child
 */
protected boolean isLeaf() {
return leftChild == null;
}
AbstractShortCircuitTask继承自AbstractTask,表示一个在满足特定条件后会终止流元素遍历的并行任务,如findAny方法,找到任意一个元素就终止遍历并返回该元素;又如anyMatch方法,找到一个满足条件的元素就终止遍历并返回true。其实现如下:
@SuppressWarnings("serial")
abstract class AbstractShortCircuitTask>
extends AbstractTask {
/**
* 执行的结果
*/
protected final AtomicReference sharedResult;
/**
* 遍历任务是否被取消了
*/
protected volatile boolean canceled;
/**
* 创建根节点任务
*/
protected AbstractShortCircuitTask(PipelineHelper helper,
Spliterator spliterator) {
super(helper, spliterator);
sharedResult = new AtomicReference<>(null);
}
/**
* 创建子节点任务
*/
protected AbstractShortCircuitTask(K parent,
Spliterator spliterator) {
super(parent, spliterator);
sharedResult = parent.sharedResult;
}
/**
* 返回默认值
*/
protected abstract R getEmptyResult();
/**
* 执行并行任务的核心逻辑,重写了父类的逻辑
*/
@Override
public void compute() {
Spliterator rs = spliterator, ls;
long sizeEstimate = rs.estimateSize();
//计算单个子任务的元素个数
long sizeThreshold = getTargetSize(sizeEstimate);
boolean forkRight = false;
@SuppressWarnings("unchecked") K task = (K) this;
AtomicReference sr = sharedResult;
R result;
while ((result = sr.get()) == null) {
//如果未获取满足条件的结果
if (task.taskCanceled()) { //如果任务已取消,则返回默认值
result = task.getEmptyResult();
break;
}
//如果当前任务的元素个数较少或者无法继续切分了,则调用doLeaf执行当前任务并终止循环
if (sizeEstimate <= sizeThreshold || (ls = rs.trySplit()) == null) {
result = task.doLeaf();
break;
}
K leftChild, rightChild, taskToFork;
//创建左右子任务节点
task.leftChild = leftChild = task.makeChild(ls);
task.rightChild = rightChild = task.makeChild(rs);
task.setPendingCount(1);
if (forkRight) {
forkRight = false;
//继续切分左子任务,将右子任务提交到线程池处理
rs = ls;
task = leftChild;
taskToFork = rightChild;
}
else {
//继续切分右子任务,将左子任务提交到线程池处理
forkRight = true;
task = rightChild;
taskToFork = leftChild;
}
taskToFork.fork();
sizeEstimate = rs.estimateSize();
}
//保存任务的执行结果
task.setLocalResult(result);
//将当前任务标记为已执行完成
task.tryComplete();
}
/**
* 设置执行结果
*/
protected void shortCircuit(R result) {
if (result != null)
sharedResult.compareAndSet(null, result);
}
/**
* 保存执行结果
*/
@Override
protected void setLocalResult(R localResult) {
if (isRoot()) {
//如果是根节点,则设置sharedResult
if (localResult != null)
sharedResult.compareAndSet(null, localResult);
}
else
//非根节点,调用父类的setLocalResult
super.setLocalResult(localResult);
}
/**
* 获取执行结果
*/
@Override
public R getRawResult() {
return getLocalResult();
}
@Override
public R getLocalResult() {
if (isRoot()) {
//根节点时,获取sharedResult中的执行结果
R answer = sharedResult.get();
return (answer == null) ? getEmptyResult() : answer;
}
else
return super.getLocalResult();
}
/**
* 将当前任务标记为已取消
*/
protected void cancel() {
canceled = true;
}
/**
* 判断当前任务是否已取消,如果没有则向上遍历判断父节点任务是否被取消了
*/
protected boolean taskCanceled() {
boolean cancel = canceled;
if (!cancel) {
//cancel为false时则遍历父节点,看有没有父节点的canceled是否为true
for (K parent = getParent(); !cancel && parent != null; parent = parent.getParent())
cancel = parent.canceled;
}
return cancel;
}
protected void cancelLaterNodes() {
// Go up the tree, cancel right siblings of this node and all parents
for (@SuppressWarnings("unchecked") K parent = getParent(), node = (K) this;
parent != null;
node = parent, parent = parent.getParent()) {
//向上遍历直到根节点,将所有的右子节点设置为已取消
if (parent.leftChild == node) {
K rightSibling = parent.rightChild;
if (!rightSibling.canceled)
rightSibling.cancel(); //如果未取消,则取消掉
}
}
}
}
/**
 * Tells whether this task is the root of the task tree.
 *
 * @return true when this task has no parent
 */
protected boolean isRoot() {
return getParent() == null;
}
/**
 * Returns the parent task, i.e. this task's completer cast to the task type.
 *
 * @return the value of getCompleter() as K (null for the root, see isRoot)
 */
protected K getParent() {
return (K) getCompleter();
}
FindTask用来实现findFirst / findAny的并行处理,其调用如下:
其实现如下:
@SuppressWarnings("serial")
private static final class FindTask
extends AbstractShortCircuitTask> {
private final FindOp op;
//evaluateParallel方法使用
FindTask(FindOp op,
PipelineHelper helper,
Spliterator spliterator) {
super(helper, spliterator);
this.op = op;
}
//下面的创建子任务makeChild方法使用
FindTask(FindTask parent, Spliterator spliterator) {
super(parent, spliterator);
this.op = parent.op;
}
@Override
protected FindTask makeChild(Spliterator spliterator) {
return new FindTask<>(this, spliterator);
}
@Override
protected O getEmptyResult() { //返回默认值
return op.emptyValue;
}
//只有mustFindFirst为true时调用此方法
private void foundResult(O answer) {
if (isLeftmostNode())
//如果是最左边的子任务节点,则设置sharedResult
shortCircuit(answer);
else
//当前节点不是最左边的子任务节点,则取消后续的子任务节点
cancelLaterNodes();
}
@Override
protected O doLeaf() {
//执行子任务并获取结果
O result = helper.wrapAndCopyInto(op.sinkSupplier.get(), spliterator).get();
if (!op.mustFindFirst) { //如果mustFindFirst为false,即为findAny方法
if (result != null)
shortCircuit(result); //尝试设置sharedResult,如果成功则会终止子任务的切割
return null;
}
else {
//如果mustFindFirst为true,即为findFirst方法
if (result != null) {
foundResult(result);
return result;
}
else
return null;
}
}
@Override
//子任务都执行完成时,回调此方法
public void onCompletion(CountedCompleter> caller) {
if (op.mustFindFirst) {
for (FindTask child = leftChild, p = null; child != p;
p = child, child = rightChild) {
//child先是左子任务节点,然后是右子任务节点,然后因为p等于child终止遍历
O result = child.getLocalResult();
if (result != null && op.presentPredicate.test(result)) {
//如果子任务节点的执行结果满足要求,则保存执行结果并终止遍历
setLocalResult(result);
foundResult(result);
break;
}
}
}
//调用父类方法,将leftChild,rightChild等置为null
super.onCompletion(caller);
}
}
/**
 * Tells whether this task lies on the leftmost path of the task tree, i.e. whether
 * every node from this task up to the root is the left child of its parent.
 *
 * @return false as soon as some node on the way up is not its parent's left child
 */
protected boolean isLeftmostNode() {
    @SuppressWarnings("unchecked")
    K current = (K) this;
    for (K ancestor = current.getParent(); ancestor != null; ancestor = ancestor.getParent()) {
        if (ancestor.leftChild != current) {
            // current hangs off the right side of its parent
            return false;
        }
        current = ancestor;
    }
    return true;
}
MatchTask用来实现anyMatch / allMatch / noneMatch的并行处理,其实现如下:
private static final class MatchTask
extends AbstractShortCircuitTask> {
private final MatchOp op;
/**
* evaluateParallel方法使用,创建根节点
*/
MatchTask(MatchOp op, PipelineHelper helper,
Spliterator spliterator) {
super(helper, spliterator);
this.op = op;
}
/**
* makeChild方法使用,创建子任务节点
*/
MatchTask(MatchTask parent, Spliterator spliterator) {
super(parent, spliterator);
this.op = parent.op;
}
@Override
protected MatchTask makeChild(Spliterator spliterator) {
return new MatchTask<>(this, spliterator);
}
@Override
protected Boolean doLeaf() {
//执行子任务
boolean b = helper.wrapAndCopyInto(op.sinkSupplier.get(), spliterator).getAndClearState();
if (b == op.matchKind.shortCircuitResult)
//如果执行结果满足条件,则设置sharedResult,终止子任务切割
shortCircuit(b);
return null;
}
@Override
protected Boolean getEmptyResult() {
return !op.matchKind.shortCircuitResult;
}
}
Spliterator是Java8引入的接口,其定义的方法的调用场景和用途在之前的源码分析中已经陆续的说明了,此处做一个总结,并以ArrayList中Spliterator接口的实现类ArrayListSpliterator进一步说明实现该接口的相关细节。所有必须实现的方法说明如下:
可使用的表示流元素的特点的常量如下:
// An encounter order is defined for the elements, e.g. the elements of a List
public static final int ORDERED = 0x00000010;
// The elements are distinct (no duplicates)
public static final int DISTINCT = 0x00000001;
// The elements follow a defined sort order, e.g. the elements of a SortedSet
public static final int SORTED = 0x00000004;
// The number of elements is finite and known: per the Spliterator Javadoc, estimateSize()
// is an exact count before traversal or splitting
public static final int SIZED = 0x00000040;
// The elements are all non-null
public static final int NONNULL = 0x00000100;
// The element source cannot be structurally modified (per the Spliterator Javadoc), so
// such changes cannot occur during traversal
public static final int IMMUTABLE = 0x00000400;
// The element source may be safely modified concurrently by multiple threads
public static final int CONCURRENT = 0x00001000;
// Per the Spliterator Javadoc: all spliterators resulting from trySplit() are themselves
// SIZED and SUBSIZED
public static final int SUBSIZED = 0x00004000;
ArrayListSpliterator的实现如下:
/**
 * Creates a spliterator over this list, starting at index 0 with an unset fence (-1).
 */
@Override
public Spliterator spliterator() {
// expectedModCount is passed as 0 here; the spliterator replaces it with the list's
// modCount when the fence is initialized on first use
return new ArrayListSpliterator<>(this, 0, -1, 0);
}
/** Index-based split-by-two, lazily initialized Spliterator */
static final class ArrayListSpliterator implements Spliterator {
private final ArrayList list; //关联的list
private int index; //下一个遍历的元素的索引
private int fence; //允许遍历的最大索引,-1表示无限制,即遍历所有的数组元素
private int expectedModCount; //保存初始化时的ModCount,如果遍历时此属性与list的modCount不一致则抛出异常
/** Create new spliterator covering the given range */
ArrayListSpliterator(ArrayList list, int origin, int fence,
int expectedModCount) {
this.list = list; // OK if null unless traversed
this.index = origin;
this.fence = fence;
this.expectedModCount = expectedModCount;
}
private int getFence() { // initialize fence to size on first use
int hi; // (a specialized variant appears in method forEach)
ArrayList lst;
if ((hi = fence) < 0) { //fence为-1时
if ((lst = list) == null)
hi = fence = 0;
else {
//不为null时初始化expectedModCount和fence
expectedModCount = lst.modCount;
hi = fence = lst.size;
}
}
return hi;
}
public ArrayListSpliterator trySplit() {
//获取中间值
int hi = getFence(), lo = index, mid = (lo + hi) >>> 1;
return (lo >= mid) ? null : // divide range in half unless too small
//返回一个新的实例,遍历的元素范围是从index到mid之间,当前实例的index被修改成mid,元素范围就是mid到hi
new ArrayListSpliterator(list, lo, index = mid,
expectedModCount);
}
public boolean tryAdvance(Consumer super E> action) {
if (action == null)
throw new NullPointerException();
int hi = getFence(), i = index;
if (i < hi) {
//修改index
index = i + 1;
//获取原index对应的元素并调用action
@SuppressWarnings("unchecked") E e = (E)list.elementData[i];
action.accept(e);
if (list.modCount != expectedModCount) //list发生修改了抛出异常
throw new ConcurrentModificationException();
return true;
}
return false;
}
//改写了默认的forEachRemaining实现
public void forEachRemaining(Consumer super E> action) {
int i, hi, mc; // hoist accesses and checks from loop
ArrayList lst; Object[] a;
if (action == null)
throw new NullPointerException();
if ((lst = list) != null && (a = lst.elementData) != null) {
if ((hi = fence) < 0) { //fence为-1时
mc = lst.modCount;
hi = lst.size;
}
else
//fence大于0时
mc = expectedModCount;
if ((i = index) >= 0 && (index = hi) <= a.length) {
//遍历index到fence之间的元素
for (; i < hi; ++i) {
@SuppressWarnings("unchecked") E e = (E) a[i];
action.accept(e);
}
//如果没有发生修改则返回,否则抛出异常
if (lst.modCount == mc)
return;
}
}
throw new ConcurrentModificationException();
}
public long estimateSize() {
return (long) (getFence() - index);
}
public int characteristics() {
return Spliterator.ORDERED | Spliterator.SIZED | Spliterator.SUBSIZED;
}
}