CyclicBarrier的应用

简介

    我们在多线程应用的一些设计中会碰到一些问题。比如说利用多个线程去分别计算某个问题的部分结果,然后再将结果存储在某个地方。等所有这些线程都结束之后,我们再将这些线程产生的结果合并起来并得到问题的解。这种方法基于这么一个前提,就是所有这些线程可以并行的执行,他们之间不会有互相的干扰。另外,线程产生的结果不会有冲突。比如说,我们可以给线程一个编号限制,某个线程产生的结果放到某个编号的元素里,这样只要他们的编号不重复,就不会有保存结果的冲突。这样,整个执行的过程就可以拆分成多个线程执行,提高了并行度,使得解决问题的效率更高。

 

应用场景分析

    前面我们提到的这种场景确实挺好。而且通过分别的编号也可以解决数据访问冲突的问题。但是,这边还有一个问题就是。我怎么知道所有并发执行的线程都执行结束了呢?因为我需要知道他们所有的都执行结束了我才可以去他们产生的结果里统计最终的运算结果。这个时候,我们就需要用到一个类似于栅栏的机制了。在Java里有这么一个实现,就是CyclicBarrier。该怎么理解这个东西呢?我们可以用一个田径比赛来做类比。

    假设我们有一个跑步的田径比赛。每个线程就相当于是一个参赛选手。这些参赛选手每个人都占据一个跑道,他们只能沿着自己的跑道往前跑,不能窜到别人的跑道上去。在比赛开始后,如果我们想得到所有比赛选手的成绩,很显然,我们需要比赛结束了。怎么才能让比赛结束了呢?最起码是要保证跑在最后的一个选手也到达了终点,只有这个时候我们才能得到所有选手的成绩。这里,CyclicBarrier就好像是那个记分员,每到达一个选手,他会记录一下成绩。但是只有在最后一个选手到达后,他才能把所有的成绩送去做总的统计。

    ok,有了前面这部分分析,我们可以发现。CyclicBarrier就相当于这么一个阻断机制,在前面达到这个点的线程会等在那里,一直到最后一个线程到达后,他才会在这个点让那些线程继续做自己的事情。好像在这个点的时候,所有的线程又站在了同一个起跑线。

 

示例

    现在,根据前面的讨论。我们来举一个实际的例子。假设我们有一个矩阵M*N的矩阵。那里放了若干了数字。假定给定一个数字,我们要统计里面所有等于这个数字的元素个数。那么,除了传统的那个顺序走过每一个元素的办法以外,我们还可以考虑到一个多线程的办法。

    既然线程是M*N的,我们可以将线程按照行分成若干块。比如说有5块,那么从0到M/5 - 1这一行这部分我们分配一个线程来统计。M/5到2*M/5 - 1这部分分配第二个线程来统计,依次类推。我们可以分配5个线程,每个统计其中的一部分。

    统计完了之后呢?我们需要保存结果。既然不希望结果产生冲突。我们可以将结果写到一个数组里。比如说第0行的统计结果就写到数组的索引0,第1行的统计就写到索引1。等所有的线程都跑完之后,我们再启动一个线程来将结果统计出来。

    现在我们就一步步按照这边讨论的来,首先我们来定义一个矩阵,并通过随机的方式生成里面的数字:

import java.util.Random;

public class MatrixMock {
	private int[][] data;
	
	public MatrixMock(int size, int length, int number) {
		int counter = 0;
		data = new int[size][length];
		Random random = new Random();
		for(int i = 0; i < size; i++) {
			for(int j = 0; j < length; j++) {
				data[i][j] = random.nextInt(10);
				if(data[i][j] == number) {
					counter++;
				}
			}
		}
		System.out.printf("Mock: There are %d occurrences of number %d " +
				"in generated data.\n", counter, number);
	}
	
	public int[] getRow(int row) {
		if(row >= 0 && row < data.length) {
			return data[row];
		} else {
			return null;
		}
	}
}

    这里,我们在MatrixMock的构造函数里随机生成0到10之间的数字(不包括10)填充到矩阵中。

    然后,我们再定义一个存放统计结果的类:

public class Results {
	private int[] data;
	
	public Results(int size) {
		data = new int[size];
	}
	
	public void setData(int position, int value) {
		data[position] = value;
	}
	
	public int[] getData() {
		return data;
	}
}

    每个执行的线程通过调用setData可以将自己统计的结果写到数组里面。

    现在,我们再定义一个执行的线程。它根据我们提供的参数作为构造函数来扫描指定的行范围进行统计:

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class Searcher implements Runnable {
	private int firstRow;
	private int lastRow;
	private MatrixMock mock;
	private Results results;
	private int number;
	private final CyclicBarrier barrier;
	
	public Searcher(int firstRow, int lastRow, MatrixMock mock, 
			Results results, int number, CyclicBarrier barrier) {
		this.firstRow = firstRow;
		this.lastRow = lastRow;
		this.mock = mock;
		this.results = results;
		this.number = number;
		this.barrier = barrier;
	}
	
	@Override
	public void run() {
		int counter;
		System.out.printf("%s: Processing lines from %d to %d.\n",
				Thread.currentThread().getName(), firstRow, lastRow);
		for(int i = firstRow; i < lastRow; i++) {
			int[] row = mock.getRow(i);
			counter = 0;
			for(int j = 0; j < row.length; j++) {
				if(row[j] == number) {
					counter++;
				}
			}
			results.setData(i, counter);
		}
		System.out.printf("%s: Lines processed.\n",
				Thread.currentThread().getName());
		
		try {
			barrier.await();
		} catch(InterruptedException e) {
			e.printStackTrace();
		} catch(BrokenBarrierException e) {
			e.printStackTrace();
		}
	}
}

    Searcher的代码看起来比较长,实际上并不复杂。在run方法中,它根据构造函数指定的行起始和终止范围去扫描矩阵。得到和期望结果相等的元素则将统计数字加一。再将对应的这一行的统计结果写到results中。

    这里有一个比较重要的地方就是执行完了我们前面提的这几个步骤之后,这里调用了一个barrier.await()。这个await()方法的作用就是让调用这个方法的线程进入等待状态,直到所有线程都调用了这个方法。所以,这个问题最核心的地方就在这里,有了barrier.await()这个收费站,只能等到所有线程到齐才能交够保护费了:)

    前面这些处理完之后,我们就该来统计结果了。我们可以通过启动一个线程的方式来做。这个统计线程的代码就相对很简单:

public class Grouper implements Runnable {
	private Results results;
	
	public Grouper(Results results) {
		this.results = results;
	}
	
	@Override
	public void run() {
		int finalResult = 0;
		System.out.printf("Grouper: Processing results...\n");
		int[] data = results.getData();
		for(int number : data) {
			finalResult += number;
		}
		System.out.printf("Grouper: Total result: %d.\n", finalResult);
	}
}

    就是一个遍历所有结果数组,将结果相加。

    现在,所有的东西都已经具备了,就差把他们都整合起来的东风:

import java.util.concurrent.CyclicBarrier;

public class Main {

	public static void main(String[] args) {
		final int ROWS = 10000;
		final int NUMBERS = 1000;
		final int SEARCH = 5;
		final int PARTICIPANTS = 5;
		final int LINE_PARTICIPANT = 2000;
		
		MatrixMock mock = new MatrixMock(ROWS, NUMBERS, SEARCH);
		Results results = new Results(ROWS);
		Grouper grouper = new Grouper(results);
		
		CyclicBarrier barrier = new CyclicBarrier(PARTICIPANTS, grouper);
		
		Searcher[] searchers = new Searcher[PARTICIPANTS];
		for(int i = 0; i < PARTICIPANTS; i++) {
			searchers[i] = new Searcher(i * LINE_PARTICIPANT, (i * LINE_PARTICIPANT) + LINE_PARTICIPANT,
					mock, results, 5, barrier);
			Thread thread = new Thread(searchers[i]);
			thread.start();
		}
		System.out.printf("Main: The main thread has finished.\n");
	}
}

     这里定义了一个10000 × 1000的矩阵,并将它划分成5个区域。在for循环的地方启动了5个线程。每个在一个指定的区域开始工作。有一个需要特别注意的地方是创建CyclicBarrier对象的构造函数:CyclicBarrier barrier = new CyclicBarrier(PARTICIPANTS, grouper);

    我们将统计线程作为一个参数传入它的构造函数。在所有等待线程都跑到终点之后,这个grouper线程就会被启动起来。

    如果运行这个程序,我们会发现有类似如下的显示结果:

Mock: There are 999916 occurrences of number 5 in generated data.
Thread-0: Processing lines from 0 to 2000.
Thread-4: Processing lines from 8000 to 10000.
Main: The main thread has finished.
Thread-3: Processing lines from 6000 to 8000.
Thread-2: Processing lines from 4000 to 6000.
Thread-1: Processing lines from 2000 to 4000.
Thread-4: Lines processed.
Thread-3: Lines processed.
Thread-0: Lines processed.
Thread-2: Lines processed.
Thread-1: Lines processed.
Grouper: Processing results...
Grouper: Total result: 999916.

 

总结

    CyclicBarrier是一个比较有意思的线程阻断机制。它可以让指定的一组线程都停在某个点上。对于一些线程并行执行和结果统计的方式,它是一个可选项。当然,后面结合一些高级的手法,还有更好的方式。    

参考材料

Java 7 concurrency cookbook

你可能感兴趣的:(java,多线程,thread)