并发编程:自定义并发类:6、自定义的fork/join线程类(拓展ForkJoinWorkerThread)

目录

ForkJoinWorkerThread

简单说明

一、主程序

二、fork/join线程工厂类

三、自定义fork/join线程类

四、分治的任务类

五、执行结果


ForkJoinWorkerThread

该类拓展自Thread类,为其增加了新方法,用于子类拓展:

  • onStart()方法:在线程对象构造完成之后、开始处理任何任务之前被调用,可用于初始化线程的内部状态。
  • onTermination(Throwable exception)方法:在工作线程终止时被调用,用于清理资源;若线程因异常而终止,则通过参数传入该异常,否则参数为null。

ForkJoinPool类通过ForkJoinWorkerThreadFactory接口的实现类来创建它自己所使用的工作线程。

简单说明

  1. 我们要创建自定义的ForkJoin线程,就要拓展ForkJoinWorkerThread类(即继承自它)。
  2. 由于线程池中的工作线程是通过线程工厂创建的,所以还要实现ForkJoinWorkerThreadFactory接口,让它返回自定义的ForkJoin线程对象。

一、主程序

package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

/**
 * Section 8.7 demo: runs a divide-and-conquer sum on a {@link ForkJoinPool}
 * whose worker threads are produced by a custom thread factory.
 *
 * <p>Created 2020-10-03 17:33:20.
 *
 * @author jangle
 */
public class M {

	/**
	 * Builds the pool, submits one summing task over an array of ones, waits
	 * for the pool to terminate and prints the result.
	 *
	 * @param args unused
	 * @throws Exception if waiting for the pool or fetching the result fails
	 */
	public static void main(String[] args) throws Exception {
		// The pool asks this factory for every worker thread it needs.
		var workerFactory = new MyWorkerThreadFactory();
		var pool = new ForkJoinPool(4, workerFactory, null, false);

		// Every element is 1, so the expected sum equals the array length.
		int[] ones = new int[100000];
		java.util.Arrays.fill(ones, 1);

		// Root task covering the whole array.
		var task = new MyRecursiveTask(ones, 0, ones.length);
		pool.execute(task);
		task.join();

		// Shut the pool down and wait for it to terminate before reporting.
		pool.shutdown();
		pool.awaitTermination(1, TimeUnit.DAYS);

		System.out.println("Main: resutl:" + task.get());
		System.out.println("Main:结束");
	}

}

二、fork/join线程工厂类

package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory;
import java.util.concurrent.ForkJoinWorkerThread;

/**
 * Thread factory that supplies {@link MyWorkerThread} instances to a
 * {@link ForkJoinPool}.
 *
 * <p>Created 2020-10-03 17:43:11.
 *
 * @author jangle
 */
public class MyWorkerThreadFactory implements ForkJoinWorkerThreadFactory {

	/**
	 * Creates a new custom worker thread bound to the given pool.
	 *
	 * @param pool the pool the new thread will work in
	 * @return a freshly constructed {@link MyWorkerThread}
	 */
	@Override
	public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
		final MyWorkerThread worker = new MyWorkerThread(pool);
		return worker;
	}

}

三、自定义fork/join线程类

package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;

/**
 * Custom fork/join worker thread that counts the tasks it has executed and
 * logs its start and termination.
 *
 * <p>Created 2020-10-03 17:34:00.
 *
 * @author jangle
 */
public class MyWorkerThread extends ForkJoinWorkerThread {

	/**
	 * Per-thread task counter. The type argument is required: with a raw
	 * {@code ThreadLocal}, {@code get()} returns {@code Object} and
	 * {@code get() + 1} does not compile. The initial value of 0 guarantees
	 * that {@link #addTask()} never unboxes {@code null}.
	 */
	private static final ThreadLocal<Integer> taskCounter = ThreadLocal.withInitial(() -> 0);

	/**
	 * Creates a worker thread operating in the given pool.
	 *
	 * @param pool the pool this thread works in
	 */
	protected MyWorkerThread(ForkJoinPool pool) {
		super(pool);
	}

	/**
	 * Logs the thread id and resets this thread's task counter to zero.
	 */
	@Override
	protected void onStart() {
		super.onStart();
		System.out.println("MyWorkerThread: onStart  getId():" + getId());
		taskCounter.set(0);
	}

	/**
	 * Logs the thread id together with the number of tasks counted on the
	 * calling thread, then delegates to the superclass.
	 *
	 * @param exception the exception passed in by the caller, forwarded
	 *                  unchanged to the superclass
	 */
	@Override
	protected void onTermination(Throwable exception) {
		System.out.println("MyWorkerThread: onTermination " + getId() + ":" + taskCounter.get());
		super.onTermination(exception);
	}

	/**
	 * Increments the task counter.
	 *
	 * <p>The counter is a {@link ThreadLocal}, so the value incremented is the
	 * one belonging to the <em>calling</em> thread; call this from the worker
	 * thread itself (e.g. on {@code Thread.currentThread()}).
	 */
	public void addTask() {
		taskCounter.set(taskCounter.get() + 1);
	}

}

四、分治的任务类

package xyz.jangle.thread.test.n8_7.forkjointhreadfactory;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

/**
 * Divide-and-conquer task that sums the elements of an {@code int} array in
 * the half-open index range {@code [start, end)}.
 *
 * <p>The type argument on {@link RecursiveTask} is required: with the raw
 * type, {@code get()}/{@code join()} return {@code Object} and the partial
 * results cannot be added without casts.
 *
 * <p>Created 2020-10-03 17:53:45.
 *
 * @author jangle
 */
public class MyRecursiveTask extends RecursiveTask<Integer> {

	private static final long serialVersionUID = 1L;

	/** Ranges of at most this many elements are summed directly instead of being split. */
	private static final int THRESHOLD = 100;

	private final int[] array;

	private final int start;

	private final int end;

	/**
	 * Creates a task summing {@code array[start] .. array[end - 1]}.
	 *
	 * @param array the values to sum
	 * @param start first index, inclusive
	 * @param end   last index, exclusive
	 */
	public MyRecursiveTask(int[] array, int start, int end) {
		this.array = array;
		this.start = start;
		this.end = end;
	}

	/**
	 * Sums the range directly when it is small enough, otherwise splits it in
	 * half, runs both halves and adds their results. Each invocation also
	 * bumps the executing worker's task counter and sleeps for 10 ms.
	 *
	 * @return the sum of the elements in {@code [start, end)}
	 */
	@Override
	protected Integer compute() {
		// Count the task only when running on the custom worker thread, so the
		// task does not fail with ClassCastException in any other pool.
		Thread current = Thread.currentThread();
		if (current instanceof MyWorkerThread) {
			((MyWorkerThread) current).addTask();
		}

		int ret;
		if (end - start <= THRESHOLD) {
			// Small range: compute the sum directly.
			int add = 0;
			for (int i = start; i < end; i++) {
				add += array[i];
			}
			ret = add;
		} else {
			// Large range: split at the midpoint (written so start + end cannot overflow).
			int mid = start + (end - start) / 2;
			var task1 = new MyRecursiveTask(array, start, mid);
			var task2 = new MyRecursiveTask(array, mid, end);
			invokeAll(task1, task2);
			ret = addResults(task1, task2);
		}

		try {
			TimeUnit.MILLISECONDS.sleep(10);
		} catch (InterruptedException e) {
			// Restore the interrupt flag so callers can still observe the interruption.
			Thread.currentThread().interrupt();
		}
		return ret;
	}

	/**
	 * Adds the results of two subtasks that have already been run via
	 * {@code invokeAll}. {@code join()} is used instead of {@code get()} so a
	 * failure propagates to the caller rather than being replaced by a
	 * silently wrong partial sum of 0.
	 *
	 * @param task1 first completed subtask
	 * @param task2 second completed subtask
	 * @return the sum of both subtask results
	 */
	private Integer addResults(MyRecursiveTask task1, MyRecursiveTask task2) {
		return task1.join() + task2.join();
	}

}

五、执行结果

MyWorkerThread: onStart  getId():13
MyWorkerThread: onStart  getId():14
MyWorkerThread: onStart  getId():15
MyWorkerThread: onStart  getId():16
MyWorkerThread: onStart  getId():17
MyWorkerThread: onTermination 15:569
MyWorkerThread: onTermination 16:576
MyWorkerThread: onTermination 13:428
MyWorkerThread: onTermination 17:0
MyWorkerThread: onTermination 14:474
Main: resutl:100000
Main:结束

 

你可能感兴趣的:(并发编程,#,自定义并发类,#,Fork/Join,并发编程,java)