第二章 - 管理多个线程 - Executors

阅读更多

Executor框架相比于传统的并发系统基础实现具有很多的优势。传统做法是实现一个Runnable接口的类,然后使用该类的对象来直接创建Thread实例。

这种做法有一些问题,特别是当你启动太多线程的时候,你可能降低了整个系统的性能。

 

Executor 框架里的基础组件

  • Executor 接口:这是 executor 框架里的基本组件。它只定义了一个允许程序员给executor发送Runnable对象的方法。
  • ExecutorService 接口:这个接口继承了 Executor 接口,并包含了额外一些方法来增加框架的功能,其中包括:
  1.           运行一些具有返回值的任务:Runnable 接口提供的 run() 方法没有返回值,但是使用 executors,你可以运行一些有返回值的任务
  2.           使用一个方法来执行一系列的任务
  3.           完成 executor 的执行并等待其销毁
  • ThreadPoolExecutor 类:这个类实现了 Executor 和 ExecutorService 两个接口。另外,它提供了一些其它方法来获取 executor 的状态(工作线程数,已经被执行的任务数等等),配置 executor(最小和最大工作线程数)。
  • Executors 类:该类提供了一些实用方法用以创建 Executor 对象以及其它相关的类。

 本章提供了几个例子来解释Executor的使用。例子代码有点长,不过相对还是蛮简单的清晰的。

 

实例1 - K 邻近算法(K-nearest neighbors algorithm)

K邻近算法是一个用于监督分类的简单机器学习算法。算法包含以下主要部分:

  • 训练数据集:该数据集由一系列样本组成,每个样本具由一个或多个属性,还有一个特殊属性来记录每个样本的标签
  • 一个距离矩阵:这个矩阵用来计算需要分类的样本和训练数据集里样本的距离
  • 测试数据集:该数据集用来衡量算法的行为。

当要对一个样本进行归类时,算法计算该样本和训练数据集里所有样本的距离。然后再取距离最小的的 k 个样本,这 k 个样本中,哪个标签数最多,那么这个标签就赋给要归类的那个样本。根据第一章得出的经验,我们从算法的串行版本开始,然后从串行版本演变到并行版本。

 

K邻近算法 - 串行版本

 

public class KnnClassifier {
    private List dataSet;
    private int k;

    public KnnClassifier(List dataSet, int k) {
        this.dataSet = dataSet;
        this.k = k;
    }

    public String classify(Sample example) {
        Distance[] distances = new Distance[dataSet.size()];
        int index = 0;

        // 计算新样本和训练数据集中各样本之间的距离
        for (Sample localExample : dataSet) {
            distances[index] = new Distance();
            distances[index].setIndex(index);
            distances[index].setDistance
                    (EuclideanDistanceCalculator.calculate(localExample,
                            example));
            index++;
        }

        // 对计算得到的距离排序以便获取K个最近距离的样本
        Arrays.sort(distances);

        Map results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, (a, b) -> a + b);
        }

        // 返回最近k个样本总数最多的那个标签
        return Collections.max(results.entrySet(),
                Map.Entry.comparingByValue()).getKey();
    }
}
 

 

 

// 该类用来计算两个样本的距离
public class EuclideanDistanceCalculator {
    public static double calculate (Sample example1, Sample
            example2) {
        double ret=0.0d;
        double[] data1=example1.getExample();
        double[] data2=example2.getExample();
        if (data1.length!=data2.length) {
            throw new IllegalArgumentException ("Vector doesn't have the
                    same length");
        }
        for (int i=0; i   
  

 

K邻近算法 - 细颗粒度的并发版本

如果你分析以上的算法的并行版本,你会发现有两点你可以用并行来实现:

  • 距离的计算:计算输入的样本和训练数据集中各样本的距离的循环中,每一个循环都是独立的,他们之间并不互相依赖。
  • 距离的排序:Java 8 提供了 Arrays 类的 parallelSort() 方法来并发实现数据排序

在细颗粒度并发版本中,我们为每一个计算输入样本和训练数据集中样本的距离创建一个任务。由此可见,所谓的细颗粒度就是我们创建了很多的任务。

 

 

 

public class KnnClassifierParallelIndividual {
    private List dataSet;
    private int k;
    private ThreadPoolExecutor executor;
    private int numThreads;
    private boolean parallelSort;

    public KnnClassifierParallelIndividual(List
                                                   dataSet, int k, 
                                                   int factor, 
                                                   boolean parallelSort) {
        this.dataSet = dataSet;
        this.k = k;

        // 动态获取运行此程序的处理器或核的数量来决定线程池中线程的数量
        numThreads = factor * (Runtime.getRuntime().availableProcessors());
        executor = (ThreadPoolExecutor)Executors.newFixedThreadPool(numThreads);
        this.parallelSort = parallelSort;
    }

    /**
     * 因为我们为每个距离计算创建了一个任务,因此主线程需要等待所有任务完成后才能继续,
     * 我们使用 CountDownLatch 这个类来同步所有任务的完成,
     * 我们用任务总数也就是数据集中样本的总数来初始化 CountDownLatch,
     * 每个任务完成后调用 countDown() 方法
     */
    public String classify(Sample example) throws Exception {
        Distance[] distances = new Distance[dataSet.size()];
        CountDownLatch endController = new CountDownLatch(dataSet.size());
        int index = 0;
        for (Sample localExample : dataSet) {
            IndividualDistanceTask task = new
                    IndividualDistanceTask(distances, index, localExample,
                    example, endController);
            executor.execute(task);
            index++;
        }
        endController.await();

        if (parallelSort) {
            Arrays.parallelSort(distances);
        } else {
            Arrays.sort(distances);
        }

        Map results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, (a, b) -> a + b);
        }

        // 返回最近k个样本总数最多的那个标签
        return Collections.max(results.entrySet(),
                Map.Entry.comparingByValue()).getKey();
    }

    public void destroy() {
        executor.shutdown();
    }
}
 

 

 

public class IndividualDistanceTask implements Runnable {
    private Distance[] distances;
    private int index;
    private Sample localExample;
    private Sample example;
    private CountDownLatch endController;

    public IndividualDistanceTask(Distance[] distances, int index,
                                  Sample localExample,
                                  Sample example, CountDownLatch endController) {
        this.distances = distances;
        this.index = index;
        this.localExample = localExample;
        this.example = example;
        this.endController = endController;
    }

    public void run() {
        distances[index] = new Distance();
        distances[index].setIndex(index);
        distances[index].setDistance
                (EuclideanDistanceCalculator.calculate(localExample,
                        example));

        // 任务完成,调用CountDownLatch的countDown()
        endController.countDown();
    }
}
 

 

K 邻近算法 - 粗颗粒度的并发算法版本

细颗粒度版本的问题是创建了太多的任务,粗颗粒度版本中,我们让每一个任务处理数据集的一个子集,这样避免创建太多的任务。

public class KnnClassifierParallelIndividual {
    private List dataSet;
    private int k;
    private ThreadPoolExecutor executor;
    private int numThreads;
    private boolean parallelSort;

    public KnnClassifierParallelIndividual(List
                                                   dataSet, int k, int factor, boolean parallelSort) {
        this.dataSet = dataSet;
        this.k = k;

        // 动态获取运行此程序的处理器或核的数量来决定线程池中线程的数量
        numThreads = factor * (Runtime.getRuntime().availableProcessors());
        executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
        this.parallelSort = parallelSort;
    }

    /**
     * 因为我们为每个距离计算创建了一个任务,因此主线程需要等待所有任务完成后才能继续,
     * 我们使用 CountDownLatch 这个类来同步所有任务的完成,
     * 我们用任务总数也就是数据集中样本的总数来初始化 CountDownLatch,每个任务完成后调用 countDown() 方法
     */
    public String classify(Sample example) throws Exception {
        Distance distances[] = new Distance[dataSet.size()];
        CountDownLatch endController = new CountDownLatch(numThreads);
        int length = dataSet.size() / numThreads;
        int startIndex = 0, endIndex = length;

        for (int i = 0; i < numThreads; i++) {
            GroupDistanceTask task = new GroupDistanceTask(distances,
                    startIndex, endIndex, dataSet, example, endController);
            startIndex = endIndex;
            if (i < numThreads - 2) {
                endIndex = endIndex + length;
            } else {
                endIndex = dataSet.size();
            }
            executor.execute(task);
        }

        endController.await();

        if (parallelSort) {
            Arrays.parallelSort(distances);
        } else {
            Arrays.sort(distances);
        }

        Map results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, (a, b) -> a + b);
        }

        // 返回最近k个样本总数最多的那个标签
        return Collections.max(results.entrySet(),
                Map.Entry.comparingByValue()).getKey();
    }

    public void destroy() {
        executor.shutdown();
    }
}

 

public class GroupDistanceTask implements Runnable {
    private Distance[] distances;
    private int startIndex, endIndex;
    private Sample example;
    private List dataSet;
    private CountDownLatch endController;

    public GroupDistanceTask(Distance[] distances, int startIndex,
                             int endIndex, List dataSet, Sample
                                     example, CountDownLatch endController) {
        this.distances = distances;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
        this.example = example;
        this.dataSet = dataSet;
        this.endController = endController;
    }

    public void run() {
        for (int index = startIndex; index < endIndex; index++) {
            Sample localExample = dataSet.get(index);
            distances[index] = new Distance();
            distances[index].setIndex(index);
            distances[index].setDistance(EuclideanDistanceCalculator
                    .calculate(localExample, example));
        }
        endController.countDown();
    }
}

 

实例2 - 并行在客户端/服务器端架构中的应用

在这个例子中:

  • 客户端和服务器端通过sockets通信
  • 客户端向服务器端发送查询以及其它请求,请求以字符串的形式发送,服务器端接收到请求后处理请求并返回处理结果
  • 服务器端接受以下不同请求:
  1. Query:该类型请求格式为 q; codCountry; codIndicator; year
  2. Report: 该类型请求格式为 r; codIndicator
  3. Stop: 该请求格式为 z
  4. 其它请求类型,服务器返回错误信息

跟上述例子一样,我们先从串行版本入手然后过渡到并行版本。

 

客户端/服务器端 - 串行版本

程序包括以下三个主要部分:

  • DAO部分,负责访问数据并且获得客户查询的结果 (我们在代码中忽略DAO的代码,因为它在本例中不是重点)
  • 指令(command)部分,每一种请求类型对应相应的指令(command)类
  • 服务器部分,它接收查询,调用相应的指令并返回查询结果给客户

以下是串行版本指令类部分的代码

//*****************串行版本指令类部分*****************//

// 指令的抽象类
public abstract class Command {
    protected String[] command;
    public Command (String [] command) {
        this.command=command;
    }
    public abstract String execute ();
}

// 对应Query请求的指令类
public class QueryCommand extends Command {
    public QueryCommand(String [] command) {
        super(command);
    }

    public String execute() {
        WDIDAO dao=WDIDAO.getDAO();
        if (command.length==3) {
            return dao.query(command[1], command[2]);
        } else if (command.length==4) {
            try {
                return dao.query(command[1], command[2],
                        Short.parseShort(command[3]));
            } catch (Exception e) {
                return "ERROR;Bad Command";
            }
        } else {
            return "ERROR;Bad Command";
        }
    }
}

//对应Report请求的指令类
public class ReportCommand extends Command {
    public ReportCommand(String [] command) {
        super(command);
    }

    public String execute() {
        WDIDAO dao=WDIDAO.getDAO();
        return dao.report(command[1]);
    }
}

//对应Stop请求的指令类
public class StopCommand extends Command {
    public StopCommand(String [] command) {
        super(command);
    }

    public String execute() {
        return "Server stopped";
    }
}

//此类处理一些服务器不支持的请求
public class ErrorCommand extends Command {
    public ErrorCommand(String [] command) {
        super(command);
    }

    public String execute() {
        return "Unknown command: "+command[0];
    }
}

 

以下是服务器部分的代码

 

public class SerialServer {
    public static void main(String[] args) throws IOException {
        boolean stopServer = false;
        System.out.println("Initialization completed.");
        try (ServerSocket serverSocket = new ServerSocket(Constants.SERIAL_PORT)) {
            // 不断循环,直到stopServer被设置为false
            do {
                try (Socket clientSocket = serverSocket.accept();
                     PrintWriter out = new PrintWriter
                             (clientSocket.getOutputStream(), true);
                     BufferedReader in = new BufferedReader(new
                             InputStreamReader(clientSocket.getInputStream()));) {
                    String line = in.readLine();
                    Command command;
                    String[] commandData = line.split(";");
                    System.out.println("Command: " + commandData[0]);
                    switch (commandData[0]) {
                        case "q":
                            System.out.println("Query");
                            command = new QueryCommand(commandData);
                            break;
                        case "r":
                            System.out.println("Report");
                            command = new ReportCommand(commandData);
                            break;
                        case "z":
                            System.out.println("Stop");
                            command = new StopCommand(commandData);
                            stopServer = true;
                            break;
                        default:
                            System.out.println("Error");
                            command = new ErrorCommand(commandData);
                    }
                    String response = command.execute();
                    System.out.println(response);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            } while (!stopServer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

 

 

客户端/服务器端 - 并行版本

众所周知以上串行版本存在着严重的性能问题,服务器一次只能处理一个请求,其余的请求需要等待。并行版本中,我们将改为主线程接收请求,然后为每个请求创建一个任务,并交由线程池中的线程执行。

 

以下是并行版本指令类部分的代码,大部分代码和并行版本一样,除了Stop指令类。类名我们改为以"Concurrent"开始

 

//*****************串行版本指令类部分*****************//

// 指令的抽象类
public abstract class Command {
    protected String[] command;
    public Command (String [] command) {
        this.command=command;
    }
    public abstract String execute ();
}

// 对应Query请求的指令类
public class ConcurrentQueryCommand extends Command {
    public ConcurrentQueryCommand(String [] command) {
        super(command);
    }

    public String execute() {
        WDIDAO dao=WDIDAO.getDAO();
        if (command.length==3) {
            return dao.query(command[1], command[2]);
        } else if (command.length==4) {
            try {
                return dao.query(command[1], command[2],
                        Short.parseShort(command[3]));
            } catch (Exception e) {
                return "ERROR;Bad Command";
            }
        } else {
            return "ERROR;Bad Command";
        }
    }
}

//对应Report请求的指令类
public class ConcurrentReportCommand extends Command {
    public ConcurrentReportCommand(String [] command) {
        super(command);
    }

    public String execute() {
        WDIDAO dao=WDIDAO.getDAO();
        return dao.report(command[1]);
    }
}

//对应Stop请求的指令类
public class ConcurrentStopCommand extends Command {
    public ConcurrentStopCommand(String [] command) {
        super(command);
    }

    public String execute() {
        ConcurrentServer.shutdown();
        return "Server stopped";
    }
}

//此类处理一些服务器不支持的请求
public class ConcurrentErrorCommand extends Command {
    public ConcurrentErrorCommand(String [] command) {
        super(command);
    }

    public String execute() {
        return "Unknown command: "+command[0];
    }
}

//并行版本中新增了服务器状态查询指令
public class ConcurrentStatusCommand extends Command {
    public ConcurrentStatusCommand (String[] command) {
        super(command);
    }
    
    @Override
    public String execute() {
        StringBuilder sb=new StringBuilder();
        ThreadPoolExecutor executor=ConcurrentServer.getExecutor();
        sb.append("Server Status;");
        sb.append("Actived Threads: ");
        sb.append(String.valueOf(executor.getActiveCount()));
        sb.append(";");
        sb.append("Maximum Pool Size: ");
        sb.append(String.valueOf(executor.getMaximumPoolSize()));
        sb.append(";");
        sb.append("Core Pool Size: ");
        sb.append(String.valueOf(executor.getCorePoolSize()));
        sb.append(";");
        sb.append("Pool Size: ");
        sb.append(String.valueOf(executor.getPoolSize()));
        sb.append(";");
        sb.append("Largest Pool Size: ");
        sb.append(String.valueOf(executor.getLargestPoolSize()));
        sb.append(";");
        sb.append("Completed Task Count: ");
        sb.append(String.valueOf(executor.getCompletedTaskCount()));
        sb.append(";");
        sb.append("Task Count: ");
        sb.append(String.valueOf(executor.getTaskCount()));
        sb.append(";");
        sb.append("Queue Size: ");
        sb.append(String.valueOf(executor.getQueue().size()));
        sb.append(";");
        return sb.toString();
    }
}

 

 

 以下是服务器部分代码和实现Runnable接口的RequestTask类

public class ConcurrentServer {
    private static ThreadPoolExecutor executor;
    private static ServerSocket serverSocket;
    private static volatile boolean stopped = false;

    public static void main(String[] args) throws InterruptedException {
        serverSocket = null;
        executor = (ThreadPoolExecutor) Executors.newFixedThreadPool
                (Runtime.getRuntime().availableProcessors());
        System.out.println("Initialization completed.");
        serverSocket = new ServerSocket(Constants.CONCURRENT_PORT);

        do {
            try {
                Socket clientSocket = serverSocket.accept();
                RequestTask task = new RequestTask(clientSocket);
                executor.execute(task);
            } catch (IOException e) {
                e.printStackTrace();
            }
        } while (!stopped);

        executor.awaitTermination(1, TimeUnit.DAYS);
        System.out.println("Shutting down cache");
        System.out.println("Cache ok");
        System.out.println("Main server thread ended");
    }

    public static void shutdown() {
        stopped = true;
        System.out.println("Shutting down the server...");
        System.out.println("Shutting down executor");
        executor.shutdown();
        System.out.println("Executor ok");
        System.out.println("Closing socket");
        try {
            serverSocket.close();
            System.out.println("Socket ok");
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("Shutting down logger");
        System.out.println("Logger ok");
    }

    public static ThreadPoolExecutor getExecutor() {
        return executor;
    }
}

public class RequestTask implements Runnable {
    private Socket clientSocket;
    public RequestTask(Socket clientSocket) {
        this.clientSocket = clientSocket;
    }

    public void run() {
        try (PrintWriter out = new
                PrintWriter(clientSocket.getOutputStream(),
                true);
             BufferedReader in = new BufferedReader(new
                     InputStreamReader(
                     clientSocket.getInputStream()));) {
            String line = in.readLine();

            Command command;
            String[] commandData = line.split(";");
            System.out.println("Command: " + commandData[0]);
            switch (commandData[0]) {
                case "q":
                    System.err.println("Query");
                    command = new ConcurrentQueryCommand(commandData);
                    break;
                case "r":
                    System.err.println("Report");
                    command = new ConcurrentReportCommand(commandData);
                    break;
                case "s":
                    System.err.println("Status");
                    command = new ConcurrentStatusCommand(commandData);
                    break;
                case "z":
                    System.err.println("Stop");
                    command = new ConcurrentStopCommand(commandData);
                    break;
                default:
                    System.err.println("Error");
                    command = new ConcurrentErrorCommand(commandData);
                    break;
            }
            ret = command.execute();

            System.out.println(ret);
            out.println(ret);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                clientSocket.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}

 

其它有用的方法

Executors 类提供了另外一些方法创建 ThreadPoolExecutor 对象。这些方法包括:

  • newCachedThreadPool():这个方法创建了一个 ThreadPoolExecutor 对象, 这种线程池能够重复利用那些空闲的工作线程,但是如果有需要,会创建新的工作线程。因此它没有最大的工作线程数。
  • newSingleThreadExecutor():这个方法创建了一个只有一个工作线程的 ThreadPoolExecutor 对象。发送给executor的任务保存在队列中,一一被那个工作线程执行。
  • CountDownLatch类提供了以下额外的方法: await(long timeout, TimeUnit unit) 当前线程等待直到内部计数器降为0;如果等待时间超过了参数 timeout,该方法返回 false。getCount() 该方法返回内部计数器的值。

Java中支持两种并行数据结构:

  • 阻塞数据结构 (Blocking data structures):此类数据结构如果不能满足请求的操作,将阻塞访问的线程直到请求能被满足(例如取数据请求但数据结构中无数据)
  • 非阻塞数据结构 (Non-blocking data structures):与阻塞数据结构不同,该数据结构无法满足请求时并不阻塞访问线程

有些数据结构实现了两种行为,有些数据结构则只实现一种行为。通常,阻塞数据结构同时也实现具有非阻塞行为的方法,但是非阻塞线程没有实现阻塞行为的方法。

 

阻塞操作的方法有:

  • put(), putFirst(), putLast():插入数据到数据结构,如果数据结构已满,则阻塞访问线程直到数据结构有可用空间时
  • take(), takeFirst(), takeLast():返回并删除数据结构中的数据。如果数据结构为空,则阻塞访问线程直到数据结构有一个元素

非阻塞操作的方法有:

  • add(), addFirst(), addLast():插入数据到数据结构,如果数据结构已满,则抛出 IllegalStateException 异常
  • remove(), removeFirst(), removeLast():返回并删除数据结构中的数据。如果数据结构为空,则抛出 IllegalStateException 异常
  • element(), getFirst(), getLast():返回但不删除数据。如果数据结构为空,则抛出 IllegalStateException 异常
  • offer(), offerFirst(), offerLast():插入数据到数据结构,如果数据结构已满,则返回 false
  • poll(), pollFirst(), pollLast():返回并删除数据结构中的数据。如果数据结构为空,则返回 null
  • peek(), peekFirst(), peekLast():返回但不删除数据。如果数据结构为空,则返回 null

你可能感兴趣的:(java,多线程)