Rust for cpp dev - Thread Pool

In the web server project we used only a single thread, but in practice servers use multiple threads/processes to improve concurrency. In this chapter, we optimize the web server with a thread pool.

20.2 Optimizing with a Thread Pool

The benefits of using a thread pool:

  • Make full use of multi-core processors to handle requests in parallel
  • Cap the maximum number of threads, mitigating DoS (Denial of Service) attacks
  • Avoid the overhead of repeatedly creating threads

How a Thread Pool Works

A typical thread pool consists of two main parts:

  1. Several Worker threads
  2. A task queue

Each Worker thread tries to fetch a task from the task queue and execute it, blocking when the queue is empty. The task queue is therefore a multiple-consumer queue. As for producers, in the web server scenario there is only one: the thread that accepts connections.

The thread pool should expose the following interface (a usage sketch follows the list):

  • ThreadPool::new: at construction, spawn the requested number of Worker threads
  • ThreadPool::execute: push a task onto the queue
  • ThreadPool::drop: at destruction, make sure all queued tasks finish executing
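Before writing any implementation, here is a minimal usage sketch of that API, assuming the ThreadPool we build in the rest of this chapter (full code in the appendix):

// Usage sketch of the API we are about to build (see the appendix for the type).
fn demo() {
    // Spawn a pool with 4 worker threads.
    let pool = ThreadPool::new(4).unwrap();

    // Submit tasks; each closure runs on some worker thread.
    for i in 0..8 {
        pool.execute(move || println!("task {} running", i));
    }

    // When `pool` goes out of scope here, Drop blocks until all tasks finish.
}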

A Simple Thread Pool in Rust

In Rust, std::sync::mpsc provides a multiple-producer, single-consumer queue. To support multiple consumers, we have to synchronize the consumer end ourselves.

pub fn channel<T>() -> (Sender<T>, Receiver<T>)

std::sync::mpsc::channel produces a Sender and a Receiver. The ThreadPool holds the queue's Sender and pushes tasks into it in execute, while all the Workers jointly hold the queue's Receiver and keep trying to pull tasks from it.
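To make the channel mechanics concrete, here is a small, self-contained sketch (names are illustrative) of an mpsc channel carrying boxed closures: the Sender can be cloned for multiple producers, while the Receiver stays unique.

use std::sync::mpsc;
use std::thread;

fn main() {
    // A channel whose messages are boxed closures ("jobs").
    let (sender, receiver) = mpsc::channel::<Box<dyn FnOnce() + Send>>();

    // Multiple producers: Sender is Clone, so it can be moved to other threads.
    let sender2 = sender.clone();
    thread::spawn(move || {
        sender2
            .send(Box::new(|| println!("job from another producer")))
            .unwrap();
    });
    sender.send(Box::new(|| println!("job from the main thread"))).unwrap();

    // Single consumer: Receiver is not Clone; only this end can call recv.
    for _ in 0..2 {
        let job = receiver.recv().unwrap();
        job();
    }
}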

We define the "task" and the shared Receiver as follows:

type Job = Box<dyn FnOnce() + Send + 'static>;

// all workers get tasks from it
type SharedTaskReceiver = Arc<Mutex<mpsc::Receiver<Job>>>;

ThreadPool

Based on this idea, the ThreadPool implementation follows naturally:

pub struct ThreadPool {
    workers: Vec<Worker>,
    task_sender: mpsc::Sender<Job>,
}

impl ThreadPool {
    pub fn new(size: usize) -> Result<ThreadPool, &'static str> {
        if size == 0 {
            return Err("A thread pool with 0 threads is not allowed");
        }

        let (sender, receiver) = mpsc::channel();
        let shared_receiver = Arc::new(Mutex::new(receiver));

        let mut workers = vec![];
        for i in 0..size {
            workers.push(Worker::new(i, Arc::clone(&shared_receiver)));
        }

        Ok(ThreadPool {
            workers: workers,
            task_sender: sender,
        })
    }

    pub fn execute<F>(&self, func: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.task_sender.send(Box::new(func)).unwrap();
    }
}

It is worth emphasizing that we use the Arc<Mutex<mpsc::Receiver<Job>>> type precisely because the Receiver has to be "shared" by all the Workers.
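Why the wrapper? mpsc::Receiver is neither Clone nor Sync, so it cannot simply be handed to every worker. Below is a minimal sketch of the sharing pattern in isolation (thread count and message values are arbitrary): we clone the Arc, never the Receiver, and every access goes through the Mutex.

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

fn main() {
    let (sender, receiver) = mpsc::channel::<u32>();
    // Wrap the single Receiver so several threads can take turns using it.
    let shared = Arc::new(Mutex::new(receiver));

    let mut handles = vec![];
    for id in 0..2 {
        let shared = Arc::clone(&shared); // clone the Arc, not the Receiver
        handles.push(thread::spawn(move || {
            // Lock, pull one message, then release the lock.
            let msg = shared.lock().unwrap().recv().unwrap();
            println!("thread {} got {}", id, msg);
        }));
    }

    sender.send(1).unwrap();
    sender.send(2).unwrap();
    for handle in handles {
        handle.join().unwrap();
    }
}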

Worker

The Worker simply keeps trying to pull tasks from the Receiver in a loop.

struct Worker {
    id: usize,
    handle: std::thread::JoinHandle<()>, // task return type is empty '()'
}

impl Worker {
    pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
        let handle = std::thread::spawn(
            move || {
                loop {
                    // The MutexGuard here is a temporary: it is held while recv
                    // blocks and dropped at the end of this statement.
                    let task = task_receiver.lock().unwrap().recv();
                    if let Ok(f) = task {
                        println!("Worker {} got a job, executing...", id);
                        f();
                    }
                }
            });
        Worker{id: id, handle: handle}
    }
}

Note that both lock and recv are blocking calls. For example, if thread A takes the lock first and the queue is empty, it blocks inside recv without releasing the lock, and the other threads block on lock.
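The one-liner above holds the lock only while recv runs: the temporary MutexGuard is dropped at the end of the let statement, so the lock is already released by the time f() executes. A sketch of the same loop body with the guard scoped explicitly (behaviourally equivalent, just making the lock's lifetime visible):

loop {
    let task = {
        // The guard lives only inside this block.
        let guard = task_receiver.lock().unwrap();
        guard.recv()
        // guard is dropped here, releasing the lock before the job runs
    };
    if let Ok(f) = task {
        println!("Worker {} got a job, executing...", id);
        f();
    }
}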

Waiting for the Workers to Finish Their Tasks

We implement the Drop trait for ThreadPool to join all the threads, ensuring the queued tasks complete.

impl Drop for ThreadPool {
    fn drop(&mut self) {
        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            worker.handle.join().unwrap();
        }
    }
}

This naive implementation does not compile:

error[E0507]: cannot move out of `worker.handle` which is behind a mutable reference

This is because join takes ownership of the JoinHandle, while drop's parameter is &mut self, which is only a reference.

So how can we obtain ownership of something that sits behind a reference?

The answer is to change the JoinHandle owned by the Worker into an Option<std::thread::JoinHandle<()>>, so that Option::take can move the JoinHandle out. From the Worker's point of view, it still owns an Option; only the contents have become None.

This is reasonable and not a workaround: in the for loop at destruction time, a Worker that has already been joined really no longer owns a JoinHandle, while the Workers not yet joined still do, so the field genuinely is an Option. In essence, Rust's ownership rules make the program more rigorous.
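Here is a minimal, standalone sketch of the Option::take pattern outside the thread pool (the Holder type and shutdown function are made up for illustration):

use std::thread;

struct Holder {
    handle: Option<thread::JoinHandle<()>>,
}

fn shutdown(holder: &mut Holder) {
    // take() moves the JoinHandle out through &mut, leaving None behind.
    if let Some(handle) = holder.handle.take() {
        handle.join().unwrap();
    }
}

fn main() {
    let mut holder = Holder {
        handle: Some(thread::spawn(|| println!("hello from a thread"))),
    };
    shutdown(&mut holder);
    shutdown(&mut holder); // second call is a no-op: the Option is already None
}

Applied to our Worker, the struct becomes: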

struct Worker {
    id: usize,
    // task return type is empty '()'
    // after being joined in drop, a worker no longer owns a JoinHandle
    handle: Option<std::thread::JoinHandle<()>>,
}

In drop, we use take to replace the Option's contents with None and take ownership of the JoinHandle:

impl Drop for ThreadPool {
    fn drop(&mut self) {
        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            if let Some(handle) = worker.handle.take() {
                handle.join().unwrap();
            }
        }
    }
}

With these changes the program no longer exits right away; instead it hangs, which is still not what we want. The reason is that each Worker runs an endless loop that never finishes, so join waits forever.

Stopping the Workers

We will modify the program so that the ThreadPool can send a stop signal to break out of the loop. First, the queue no longer carries only Jobs; it may also carry a shutdown signal, Terminate (the SharedTaskReceiver alias and the task_sender field change from Job to Message accordingly; see the appendix):

enum Message {
    Task(Job),
    Terminate,
}

When a Worker processes a Terminate message, it needs to break out of the infinite loop:

impl Worker {
    pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
        let handle = std::thread::spawn(move || loop {
            let message = task_receiver.lock().unwrap().recv().unwrap();
            match message {
                Message::Task(job) => {
                    println!("Worker {} got a job, executing...", id);
                    job();
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });
        Worker {
            id: id,
            handle: Some(handle),
        }
    }
}

So when do we send the Terminate signals? At destruction time.

impl Drop for ThreadPool {
    fn drop(&mut self) {
        println!(
            "Sending terminate message to {} workers",
            self.workers.len()
        );
        for _ in self.workers.iter() {
            self.task_sender.send(Message::Terminate).unwrap();
        }

        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            if let Some(handle) = worker.handle.take() {
                handle.join().unwrap();
            }
        }
    }
}

Here we send n Terminate messages for the n Workers. Since a Worker that has received Terminate stops processing messages, each Worker consumes exactly one Terminate message.

Note also that we use two separate for loops, one to send Terminate and one to join the threads, in order to avoid a deadlock. Consider the simplest case with only two Workers, A and B. If a single for loop both sent Terminate and joined, the following could happen (a sketch of that problematic variant follows the list):

  • A Terminate is sent; A is still executing a job, so B receives the Terminate and exits
  • We try to join A, but A has not received a Terminate yet, so the join waits forever
  • The for loop is stuck
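For contrast, a sketch of the problematic single-loop variant (do not use this; it can deadlock exactly as described above):

// Deadlock-prone variant, for illustration only.
impl Drop for ThreadPool {
    fn drop(&mut self) {
        for worker in self.workers.iter_mut() {
            // Any idle worker may consume this Terminate,
            // not necessarily the one we are about to join.
            self.task_sender.send(Message::Terminate).unwrap();
            if let Some(handle) = worker.handle.take() {
                // If this worker is still busy and another worker took the
                // Terminate, this join blocks forever.
                handle.join().unwrap();
            }
        }
    }
}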

With that, a basic thread pool is complete. The full code is in the appendix.

Appendix

The complete code:

// lib.rs
use std::sync::{mpsc, Arc, Mutex};

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    Task(Job),
    Terminate,
}

// all workers get tasks from it
type SharedTaskReceiver = Arc<Mutex<mpsc::Receiver<Message>>>;

struct Worker {
    id: usize,
    // task return type is empty '()'
    // after being joined in drop, a worker no longer owns a JoinHandle
    handle: Option<std::thread::JoinHandle<()>>,
}

impl Worker {
    pub fn new(id: usize, task_receiver: SharedTaskReceiver) -> Worker {
        let handle = std::thread::spawn(move || loop {
            let message = task_receiver.lock().unwrap().recv().unwrap();
            match message {
                Message::Task(job) => {
                    println!("Worker {} got a job, executing...", id);
                    job();
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });
        Worker {
            id: id,
            handle: Some(handle),
        }
    }
}

pub struct ThreadPool {
    workers: Vec<Worker>,
    task_sender: mpsc::Sender<Message>,
}

impl ThreadPool {
    pub fn new(size: usize) -> Result<ThreadPool, &'static str> {
        if size == 0 {
            return Err("A thread pool with 0 threads is not allowed");
        }

        let (sender, receiver) = mpsc::channel();
        let shared_receiver = Arc::new(Mutex::new(receiver));

        let mut workers = vec![];
        for i in 0..size {
            workers.push(Worker::new(i, Arc::clone(&shared_receiver)));
        }

        Ok(ThreadPool {
            workers: workers,
            task_sender: sender,
        })
    }

    pub fn execute<F>(&self, func: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let new_job = Message::Task(Box::new(func));
        self.task_sender.send(new_job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        println!(
            "Sending terminate message to {} workers",
            self.workers.len()
        );
        for _ in self.workers.iter() {
            self.task_sender.send(Message::Terminate).unwrap();
        }

        for worker in self.workers.iter_mut() {
            println!("Shutting down worker {}", worker.id);
            if let Some(handle) = worker.handle.take() {
                handle.join().unwrap();
            }
        }
    }
}
// main.rs
use std::fs;
use std::io::prelude::{Read, Write};
use std::net::{TcpListener, TcpStream};

fn handle_client(mut stream: TcpStream) {
    let mut buffer = [0; 1024];
    stream.read(&mut buffer).unwrap();

    // GET request prefix
    let get = b"GET / HTTP/1.1\r\n";

    let mut status_line = "HTTP/1.1 200 OK";
    let mut filename = "hello.html";

    if !buffer.starts_with(get) {
        status_line = "HTTP/1.1 404 NOT FOUND";
        filename = "404.html";
    }

    let contents = fs::read_to_string(filename).unwrap();
    let response = format!(
        "{}\r\nContent-Length: {}\r\n\r\n{}",
        status_line,
        contents.len(),
        contents
    );
    stream.write_all(response.as_bytes()).unwrap();
    stream.flush().unwrap();
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
    let pool = web_server::ThreadPool::new(4).unwrap();

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => pool.execute(|| {
                handle_client(stream);
            }),
            Err(e) => println!("connection failed: {}", e),
        }
    }
}

Open 127.0.0.1:7878 and keep refreshing; you can see different threads handling the requests:

Worker 0 got a job, executing...
Worker 1 got a job, executing...
Worker 2 got a job, executing...
Worker 3 got a job, executing...
Worker 2 got a job, executing...
Worker 0 got a job, executing...
Worker 1 got a job, executing...
Worker 3 got a job, executing...
Worker 2 got a job, executing...
Worker 0 got a job, executing...
Worker 1 got a job, executing...
Worker 3 got a job, executing...

Unit Test

Below is a simple unit test that aims to verify two things:

  1. The tasks run in parallel, which can be judged from the test's running time
  2. All tasks finish, which can be checked with assert_eq!

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    #[test]
    fn slow_tasks_in_parallel() {
        let mut tasks = vec![];

        let counter = Arc::new(Mutex::new(0));
        let total = 200;
        for i in 0..total {
            let counter1 = Arc::clone(&counter);
            tasks.push(move || {
                std::thread::sleep(Duration::from_millis(i));
                let mut num = counter1.lock().unwrap();
                *num += 1;
            });
        }

        let pool = crate::ThreadPool::new(10).unwrap();

        for task in tasks {
            pool.execute(task);
        }

        // guarantee to call ThreadPool::drop before check
        std::mem::drop(pool);

        assert_eq!(*counter.lock().unwrap(), total);
    }
}

It is very much worth noting that we manually call

std::mem::drop(pool);

This guarantees that all the tasks have finished before the check. Keep in mind that a failing assert_eq! panics, and panics in multi-threaded code are very hard to debug; the error output was:

running 1 test
thread panicked while panicking. aborting.
error: test failed, to rerun pass '--lib'

Caused by:
  process didn't exit successfully: `~/Project/rust/web_server/target/debug/deps/web_server-0af50e188492a41d` (signal: 4, SIGILL: illegal instruction)

It took quite a bit of time to track down this bug.

The test finishes in 2.14 s. According to the code, the tasks sleep for a total of 0 + 1 + ... + 199 = 19900 ms = 19.9 s across 10 threads, so they really did run in parallel.
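If you prefer the test to assert on timing instead of reading the runtime off the console, one hedged option (the 5-second bound is an arbitrary safety margin, not a precise expectation) is to wrap the submit-and-drop section with std::time::Instant:

// Inside the test, around the execute/drop section:
let start = std::time::Instant::now();

for task in tasks {
    pool.execute(task);
}
std::mem::drop(pool); // blocks until all tasks are done

let elapsed = start.elapsed();
// Serially the sleeps add up to ~19.9 s; with 10 workers we expect far less.
assert!(
    elapsed < Duration::from_secs(5),
    "tasks did not run in parallel: {:?}",
    elapsed
);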
