rust的并发以及kv server网络处理和网络安全部分

理解并发和并行
Golang 的创始人之一,对此有很精辟很直观的解释:并发是一种同时处理很多事情的能力,并行是一种同时执行很多事情的手段。
我们把要做的事情放在多个线程中,或者多个异步任务中处理,这是并发的能力。在多核多 CPU 的机器上同时运行这些线程或者异步任务,是并行的手段。可以说,并发是为并行赋能。当我们具备了并发的能力,并行就是水到渠成的事情。

在处理并发的过程中,难点并不在于如何创建多个线程来分配工作,在于如何在这些并发的任务中进行同步。我们来看并发状态下几种常见的工作模式:自由竞争模式、map/reduce 模式、DAG 模式:
rust的并发以及kv server网络处理和网络安全部分_第1张图片
map/reduce 模式,把工作打散,按照相同的处理完成后,再按照一定的顺序将结果组织起来;DAG 模式,把工作切成不相交的、有依赖关系的子任务,然后按依赖关系并发执行。(这里可以联系C++并发编程里面,并发是怎么处理的)
在这些并发模式背后,都有哪些并发原语可以为我们所用呢,这两讲会重点讲解和深入五个概念 Atomic、Mutex、Condvar、Channel 和 Actor model。今天先讲前两个 Atomic 和 Mutex。

Atomic 是所有并发原语的基础,它为并发任务的同步奠定了坚实的基础。背后是CAS原理:
最基础的保证是:可以通过一条指令读取某个内存地址,判断其值是否等于某个前置值,如果相等,将其修改为新的值。这就是 Compare-and-swap 操作,简称CAS。它是操作系统的几乎所有并发原语的基石,使得我们能实现一个可以正常工作的锁。compare_exchange 是 Rust 提供的 CAS 操作,它会被编译成 CPU 的对应 CAS 指令。
但对于和编译器 /CPU 自动优化相关的 3 和 4,我们还需要一些额外处理。这就是这个函数里额外的两个和 Ordering 有关的奇怪参数。这个也可以联系C++中的语句排序。


pub enum Ordering {
    Relaxed,
    Release,
    Acquire,
    AcqRel,
    SeqCst,
}

我个人用的最多的是做各种 lock-free 的数据结构。比如,需要一个全局的 ID 生成器。当然可以使用 UUID 这样的模块来生成唯一的 ID,但如果我们同时需要这个 ID 是有序的,那么 AtomicUsize 就是最好的选择。

Mutex(互斥锁和自旋锁)还可以联系互斥锁和条件变量实现同步的机制
Atomic 虽然可以处理自由竞争模式下加锁的需求**,但毕竟用起来不那么方便,我们需要更高层的并发原语**,来保证软件系统控制多个线程对同一个共享资源的访问,使得每个线程在访问共享资源的时候,可以独占或者说互斥访问(mutual exclusive access)。

SpinLock,顾名思义,就是线程通过 CPU 空转(spin,就像前面的 while loop)忙等(busy wait),来等待某个临界区可用的一种锁。然而,这种通过 SpinLock 做互斥的实现方式有使用场景的限制:如果受保护的临界区太大,那么整体的性能会急剧下降, CPU 忙等,浪费资源还不干实事,不适合作为一种通用的处理方法。
互斥锁:而使用 Mutex lock,线程在等待锁的时候会被调度出去,等锁可用时再被调度回来。
听上去 SpinLock 似乎效率很低,其实不是,这要具体看锁的临界区大小。如果临界区要执行的代码很少,那么和 Mutex lock 带来的上下文切换(context switch)相比,SpinLock 是值得的。在 Linux Kernel 中,很多时候我们只能使用 SpinLock。

atomic / Mutex 解决了自由竞争模式下并发任务的同步问题,也能够很好地解决 map/reduce 模式下的同步问题,因为此时同步只发生在 map 和 reduce 两个阶段。
然而,它们没有解决一个更高层次的问题,也就是 DAG 模式:如果这种访问需要按照一定顺序进行或者前后有依赖关系,该怎么做?
这个问题的典型场景是生产者 - 消费者模式:生产者生产出来内容后,需要有机制通知消费者可以消费。比如 socket 上有数据了,通知处理线程来处理数据,处理完成之后,再通知 socket 收发的线程发送数据。

Condvar这应该和C++的条件变量差不多。注意对比
所以,操作系统还提供了 Condvar。Condvar 有两种状态:等待(wait):线程在队列中等待,直到满足某个条件。通知(notify):当 condvar 的条件满足时,当前线程通知其他等待的线程可以被唤醒。通知可以是单个通知,也可以是多个通知,甚至广播(通知所有人)。在实践中,Condvar 往往和 Mutex 一起使用:Mutex 用于保证条件在读写时互斥,Condvar 用于控制线程的等待和唤醒。我们

Channel
但是用 Mutex 和 Condvar 来处理复杂的 DAG 并发模式会比较吃力。所以,Rust 还提供了各种各样的 Channel 用于处理并发任务之间的通讯。Channel 把锁封装在了队列写入和读取的小块区域内,然后把读者和写者完全分离,使得读者读取数据和写者写入数据,对开发者而言,除了潜在的上下文切换外,完全和锁无关,就像访问一个本地队列一样。
相对于 Mutex,Channel 的抽象程度最高,接口最为直观,使用起来的心理负担也没那么大。使用 Mutex 时,你需要很小心地避免死锁,控制临界区的大小,防止一切可能发生的意外。

Channel 在具体实现的时候,根据不同的使用场景,会选择不同的工具。Rust 提供了以下四种 Channel:
oneshot:这可能是最简单的 Channel,写者就只发一次数据,而读者也只读一次。这种一次性的、多个线程间的同步可以用 oneshot channel 完成。由于 oneshot 特殊的用途,实现的时候可以直接用 atomic swap 来完成。

bounded:bounded channel 有一个队列,但队列有上限。一旦队列被写满了,写者也需要被挂起等待。当阻塞发生后,读者一旦读取数据,channel 内部就会使用 Condvar 的 notify_one 通知写者,唤醒某个写者使其能够继续写入。

unbounded:queue 没有上限,如果写满了,就自动扩容。我们知道,Rust 的很多数据结构如 Vec 、VecDeque 都是自动扩容的。unbounded 和 bounded 相比,除了不阻塞写者,其它实现都很类似。

所有这些 channel 类型,同步和异步的实现思路大同小异,主要的区别在于挂起 / 唤醒的对象。在同步的世界里,挂起 / 唤醒的对象是线程;而异步的世界里,是粒度很小的 task。

阶段实操(4):构建一个简单的KV server-网络处理
(关于protobuf解析也可以联系C++那个项目)
之前一直在使用一个神秘的 async-prost 库,我们神奇地完成了 TCP frame 的封包和解包。主要的思路就是在序列化数据的时候,添加一个头部来提供 frame 的长度,反序列化的时候,先读出头部,获得长度,再读取相应的数据。
今天我们的挑战就是,在上一次完成的 KV server 的基础上,来试着不依赖 async-prost自己处理封包和解包的逻辑。如果你掌握了这个能力,配合 protobuf,就可以设计出任何可以承载实际业务的协议了

protobuf 帮我们解决了协议消息如何定义的问题,然而一个消息和另一个消息之间如何区分,是个伤脑筋的事情。我们需要定义合适的分隔符。分隔符 + 消息数据,就是一个 Frame
(很多基于 TCP 的协议会使用 \r\n 做分隔符,比如 FTP;也有使用消息长度做分隔符的,比如 gRPC;还有混用两者的,比如 Redis 的 RESP;更复杂的如 HTTP,header 之间使用 \r\n 分隔,header / body 之间使用 \r\n\r\n,header 中会提供 body 的长度等等。“\r\n” 这样的分隔符,适合协议报文是 ASCII 数据;而通过长度进行分隔,适合协议报文是二进制数据。我们的 KV Server 承载的 protobuf 是二进制,所以就在 payload 之前放一个长度,来作为 frame 的分隔。)
tokio 有个 tokio-util 库,已经帮我们处理了和 frame 相关的封包解包的主要需求,包括 LinesDelimited(处理 \r\n 分隔符)和 LengthDelimited(处理长度分隔符)
let mut stream = Framed::new(stream, LengthDelimitedCodec::new());

(为什么要自己设计?因为现实的需求是多变的,不仅仅只有分割符确定长度这一点,比如还可以自定义是否需要压缩?是否需要其他特殊处理?库代码使用有限,因为它的接口提供的功能是固定的
为了更贴近实际,我们把 4 字节长度的最高位拿出来作为是否压缩的信号,如果设置了,代表后续的 payload 是 gzip 压缩过的 protobuf,否则直接是 protobuf:
rust的并发以及kv server网络处理和网络安全部分_第2张图片
按照惯例,还是先来定义处理这个逻辑的 trait:

pub trait FrameCoder
where
    Self: Message + Sized + Default,
{
    /// 把一个 Message encode 成一个 frame
    fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError>;
    /// 把一个完整的 frame decode 成一个 Message
    fn decode_frame(buf: &mut BytesMut) -> Result<Self, KvError>;
}

实现trait


use std::io::{Read, Write};

use crate::{CommandRequest, CommandResponse, KvError};
use bytes::{Buf, BufMut, BytesMut};
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use prost::Message;
use tokio::io::{AsyncRead, AsyncReadExt};
use tracing::debug;

/// 长度整个占用 4 个字节
pub const LEN_LEN: usize = 4;
/// 长度占 31 bit,所以最大的 frame 是 2G
const MAX_FRAME: usize = 2 * 1024 * 1024 * 1024;
/// 如果 payload 超过了 1436 字节,就做压缩
const COMPRESSION_LIMIT: usize = 1436;
/// 代表压缩的 bit(整个长度 4 字节的最高位)
const COMPRESSION_BIT: usize = 1 << 31;

/// 处理 Frame 的 encode/decode
pub trait FrameCoder
where
    Self: Message + Sized + Default,
{
    /// 把一个 Message encode 成一个 frame
    fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError> {
        let size = self.encoded_len();

        if size >= MAX_FRAME {
            return Err(KvError::FrameError);
        }

        // 我们先写入长度,如果需要压缩,再重写压缩后的长度
        buf.put_u32(size as _);

        if size > COMPRESSION_LIMIT {
            let mut buf1 = Vec::with_capacity(size);
            self.encode(&mut buf1)?;

            // BytesMut 支持逻辑上的 split(之后还能 unsplit)
            // 所以我们先把长度这 4 字节拿走,清除
            let payload = buf.split_off(LEN_LEN);
            buf.clear();

            // 处理 gzip 压缩,具体可以参考 flate2 文档
            let mut encoder = GzEncoder::new(payload.writer(), Compression::default());
            encoder.write_all(&buf1[..])?;

            // 压缩完成后,从 gzip encoder 中把 BytesMut 再拿回来
            let payload = encoder.finish()?.into_inner();
            debug!("Encode a frame: size {}({})", size, payload.len());

            // 写入压缩后的长度
            buf.put_u32((payload.len() | COMPRESSION_BIT) as _);

            // 把 BytesMut 再合并回来
            buf.unsplit(payload);

            Ok(())
        } else {
            self.encode(buf)?;
            Ok(())
        }
    }

    /// 把一个完整的 frame decode 成一个 Message
    fn decode_frame(buf: &mut BytesMut) -> Result<Self, KvError> {
        // 先取 4 字节,从中拿出长度和 compression bit
        let header = buf.get_u32() as usize;
        let (len, compressed) = decode_header(header);
        debug!("Got a frame: msg len {}, compressed {}", len, compressed);

        if compressed {
            // 解压缩
            let mut decoder = GzDecoder::new(&buf[..len]);
            let mut buf1 = Vec::with_capacity(len * 2);
            decoder.read_to_end(&mut buf1)?;
            buf.advance(len);

            // decode 成相应的消息
            Ok(Self::decode(&buf1[..buf1.len()])?)
        } else {
            let msg = Self::decode(&buf[..len])?;
            buf.advance(len);
            Ok(msg)
        }
    }
}

impl FrameCoder for CommandRequest {}
impl FrameCoder for CommandResponse {}

fn decode_header(header: usize) -> (usize, bool) {
    let len = header & !COMPRESSION_BIT;
    let compressed = header & COMPRESSION_BIT == COMPRESSION_BIT;
    (len, compressed)
}

如果你有些疑惑为什么 COMPRESSION_LIMIT 设成 1436?
这是因为以太网的 MTU 是 1500,除去 IP 头 20 字节、TCP 头 20 字节,还剩 1460;一般 TCP 包会包含一些 Option(比如 timestamp),IP 包也可能包含,所以我们预留 20 字节;再减去 4 字节的长度,就是 1436,不用分片的最大消息长度。如果大于这个,很可能会导致分片,我们就干脆压缩一下。

目前,这个代码没有触及任何和 socket IO 相关的内容,只是纯逻辑,接下来我们要将它和我们用于处理服务器客户端的 TcpStream 联系起来。中间还有一些处理让stream可以处理frame,先不说了。
主要是让stream读取完整的frame,涉及到一些库函数,所以不详细说了。
stream.read_exact(&mut buf[LEN_LEN…]).await?;

接下来要构思一下,服务端和客户端该如何封装。
在服务器端,用process进行封装


#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let addr = "127.0.0.1:9527";
    let service: Service = ServiceInner::new(MemTable::new()).into();
    let listener = TcpListener::bind(addr).await?;
    info!("Start listening on {}", addr);
    loop {
        let (stream, addr) = listener.accept().await?;
        info!("Client {:?} connected", addr);
        let stream = ProstServerStream::new(stream, service.clone());
        tokio::spawn(async move { stream.process().await });
    }
}

这个 process() 方法,实际上就是对 examples/server.rs 中 tokio::spawn 里的 while loop 的封装:


while let Some(Ok(cmd)) = stream.next().await {
    info!("Got a new command: {:?}", cmd);
    let res = svc.execute(cmd);
    stream.send(res).await.unwrap();
}

对客户端,我们也希望可以直接 execute() 一个命令,就能得到结果:


#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    let addr = "127.0.0.1:9527";
    // 连接服务器
    let stream = TcpStream::connect(addr).await?;

    let mut client = ProstClientStream::new(stream);

    // 生成一个 HSET 命令
    let cmd = CommandRequest::new_hset("table1", "hello", "world".to_string().into());

    // 发送 HSET 命令
    let data = client.execute(cmd).await?;
    info!("Got response {:?}", data);

    Ok(())
}

这个 execute(),实际上就是对 examples/client.rs 中发送和接收代码的封装:


client.send(cmd).await?;
if let Some(Ok(data)) = client.next().await {
    info!("Got response {:?}", data);
}

好,先看服务器处理一个 TcpStream 的数据结构,它需要包含 TcpStream,还有我们之前创建的用于处理客户端命令的 Service。所以,让服务器处理 TcpStream 的结构包含这两部分:


pub struct ProstServerStream<S> {
    inner: S,
    service: Service,
}

而客户端处理 TcpStream 的结构就只需要包含 TcpStream:


pub struct ProstClientStream<S> {
    inner: S,
}

这里,依旧使用了泛型参数 S。**未来,如果要支持 WebSocket,或者在 TCP 之上支持 TLS,它都可以让我们无需改变这一层的代码。**这里也体现了泛型参数的好处,和前面store trait一样的。

接下来就是具体实现process和execute


mod frame;
use bytes::BytesMut;
pub use frame::{read_frame, FrameCoder};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::info;

use crate::{CommandRequest, CommandResponse, KvError, Service};

/// 处理服务器端的某个 accept 下来的 socket 的读写
pub struct ProstServerStream<S> {
    inner: S,
    service: Service,
}

/// 处理客户端 socket 的读写
pub struct ProstClientStream<S> {
    inner: S,
}

impl<S> ProstServerStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    pub fn new(stream: S, service: Service) -> Self {
        Self {
            inner: stream,
            service,
        }
    }

    pub async fn process(mut self) -> Result<(), KvError> {
        while let Ok(cmd) = self.recv().await {
            info!("Got a new command: {:?}", cmd);
            let res = self.service.execute(cmd);
            self.send(res).await?;
        }
        // info!("Client {:?} disconnected", self.addr);
        Ok(())
    }

    async fn send(&mut self, msg: CommandResponse) -> Result<(), KvError> {
        let mut buf = BytesMut::new();
        msg.encode_frame(&mut buf)?;
        let encoded = buf.freeze();
        self.inner.write_all(&encoded[..]).await?;
        Ok(())
    }

    async fn recv(&mut self) -> Result<CommandRequest, KvError> {
        let mut buf = BytesMut::new();
        let stream = &mut self.inner;
        read_frame(stream, &mut buf).await?;
        CommandRequest::decode_frame(&mut buf)
    }
}

impl<S> ProstClientStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    pub fn new(stream: S) -> Self {
        Self { inner: stream }
    }

    pub async fn execute(&mut self, cmd: CommandRequest) -> Result<CommandResponse, KvError> {
        self.send(cmd).await?;
        Ok(self.recv().await?)
    }

    async fn send(&mut self, msg: CommandRequest) -> Result<(), KvError> {
        let mut buf = BytesMut::new();
        msg.encode_frame(&mut buf)?;
        let encoded = buf.freeze();
        self.inner.write_all(&encoded[..]).await?;
        Ok(())
    }

    async fn recv(&mut self) -> Result<CommandResponse, KvError> {
        let mut buf = BytesMut::new();
        let stream = &mut self.inner;
        read_frame(stream, &mut buf).await?;
        CommandResponse::decode_frame(&mut buf)
    }
}

写完之后,发现服务端和客户端代码更简洁了,用process代替服务端处理流程,execute代替客户端执执行命令流程。而且还用到了自定义的frame处理方法。这就是这一节的改进。自定义的stream中采用了泛型参数S,这可以让以后在添加新的协议类型的时候不需要修改代码。

阶段实操(5):构建一个简单的KV server-网络安全
那么,当我们的应用架构在 TCP 上时,如何使用 TLS 来保证客户端和服务器间的安全性呢?
想要使用 TLS,我们首先需要 x509 证书。TLS 需要 x509 证书让客户端验证服务器是否是一个受信的服务器,甚至服务器验证客户端,确认对方是一个受信的客户端。
为了测试方便,我们要有能力生成自己的 CA 证书、服务端证书,甚至客户端证书。证书生成的细节今天就不详细介绍了,我之前做了一个叫 certify 的库,可以用来生成各种证书。我们可以在 Cargo.toml 里加入这个库:

[dev-dependencies]

certify = “0.3”

然后在根目录下创建 fixtures 目录存放证书,再创建 examples/gen_cert.rs 文件,添入如下代码:


use anyhow::Result;
use certify::{generate_ca, generate_cert, load_ca, CertType, CA};
use tokio::fs;

struct CertPem {
    cert_type: CertType,
    cert: String,
    key: String,
}

#[tokio::main]
async fn main() -> Result<()> {
    let pem = create_ca()?;
    gen_files(&pem).await?;
    let ca = load_ca(&pem.cert, &pem.key)?;
    let pem = create_cert(&ca, &["kvserver.acme.inc"], "Acme KV server", false)?;
    gen_files(&pem).await?;
    let pem = create_cert(&ca, &[], "awesome-device-id", true)?;
    gen_files(&pem).await?;
    Ok(())
}

fn create_ca() -> Result<CertPem> {
    let (cert, key) = generate_ca(
        &["acme.inc"],
        "CN",
        "Acme Inc.",
        "Acme CA",
        None,
        Some(10 * 365),
    )?;
    Ok(CertPem {
        cert_type: CertType::CA,
        cert,
        key,
    })
}

fn create_cert(ca: &CA, domains: &[&str], cn: &str, is_client: bool) -> Result<CertPem> {
    let (days, cert_type) = if is_client {
        (Some(365), CertType::Client)
    } else {
        (Some(5 * 365), CertType::Server)
    };
    let (cert, key) = generate_cert(ca, domains, "CN", "Acme Inc.", cn, None, is_client, days)?;

    Ok(CertPem {
        cert_type,
        cert,
        key,
    })
}

async fn gen_files(pem: &CertPem) -> Result<()> {
    let name = match pem.cert_type {
        CertType::Client => "client",
        CertType::Server => "server",
        CertType::CA => "ca",
    };
    fs::write(format!("fixtures/{}.cert", name), pem.cert.as_bytes()).await?;
    fs::write(format!("fixtures/{}.key", name), pem.key.as_bytes()).await?;
    Ok(())
}

这个代码很简单,它先生成了一个 CA 证书,然后再生成服务器和客户端证书,全部存入刚创建的 fixtures 目录下。你需要 cargo run --examples gen_cert 运行一下这个命令,待会我们会在测试中用到这些证书和密钥。

关于TLS的具体细节不细说了。
对于 KV server 来说,使用 TLS 之后,整个协议的数据封装如下图所示:
rust的并发以及kv server网络处理和网络安全部分_第3张图片
估计很多人一听 TLS 或者 SSL,就头皮发麻,因为之前跟 openssl 打交道有过很多不好的经历。openssl 的代码库太庞杂,API 不友好,编译链接都很费劲。不过,在 Rust 下使用 TLS 的体验还是很不错的,Rust 对 openssl 有很不错的封装,也有不依赖 openssl 用 Rust 撰写的 rustls。tokio 进一步提供了符合 tokio 生态圈的 tls 支持,有 openssl 版本和 rustls 版本可选。
我们今天就用 tokio-rustls 来撰写 TLS 的支持。相信你在实现过程中可以看到,在应用程序中加入 TLS 协议来保护网络层,是多么轻松的一件事情。
先在 Cargo.toml 中添加 tokio-rustls:
然后创建 src/network/tls.rs,撰写如下代码(记得在 src/network/mod.rs 中引入这个文件哦):


use std::io::Cursor;
use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::rustls::{internal::pemfile, Certificate, ClientConfig, ServerConfig};
use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, PrivateKey, RootCertStore};
use tokio_rustls::webpki::DNSNameRef;
use tokio_rustls::TlsConnector;
use tokio_rustls::{
    client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream, TlsAcceptor,
};

use crate::KvError;

/// KV Server 自己的 ALPN (Application-Layer Protocol Negotiation)
const ALPN_KV: &str = "kv";

/// 存放 TLS ServerConfig 并提供方法 accept 把底层的协议转换成 TLS
#[derive(Clone)]
pub struct TlsServerAcceptor {
    inner: Arc<ServerConfig>,
}

/// 存放 TLS Client 并提供方法 connect 把底层的协议转换成 TLS
#[derive(Clone)]
pub struct TlsClientConnector {
    pub config: Arc<ClientConfig>,
    pub domain: Arc<String>,
}

impl TlsClientConnector {
    /// 加载 client cert / CA cert,生成 ClientConfig
    pub fn new(
        domain: impl Into<String>,
        identity: Option<(&str, &str)>,
        server_ca: Option<&str>,
    ) -> Result<Self, KvError> {
        let mut config = ClientConfig::new();

        // 如果有客户端证书,加载之
        if let Some((cert, key)) = identity {
            let certs = load_certs(cert)?;
            let key = load_key(key)?;
            config.set_single_client_cert(certs, key)?;
        }

        // 加载本地信任的根证书链
        config.root_store = match rustls_native_certs::load_native_certs() {
            Ok(store) | Err((Some(store), _)) => store,
            Err((None, error)) => return Err(error.into()),
        };

        // 如果有签署服务器的 CA 证书,则加载它,这样服务器证书不在根证书链
        // 但是这个 CA 证书能验证它,也可以
        if let Some(cert) = server_ca {
            let mut buf = Cursor::new(cert);
            config.root_store.add_pem_file(&mut buf).unwrap();
        }

        Ok(Self {
            config: Arc::new(config),
            domain: Arc::new(domain.into()),
        })
    }

    /// 触发 TLS 协议,把底层的 stream 转换成 TLS stream
    pub async fn connect<S>(&self, stream: S) -> Result<ClientTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())
            .map_err(|_| KvError::Internal("Invalid DNS name".into()))?;

        let stream = TlsConnector::from(self.config.clone())
            .connect(dns, stream)
            .await?;

        Ok(stream)
    }
}

impl TlsServerAcceptor {
    /// 加载 server cert / CA cert,生成 ServerConfig
    pub fn new(cert: &str, key: &str, client_ca: Option<&str>) -> Result<Self, KvError> {
        let certs = load_certs(cert)?;
        let key = load_key(key)?;

        let mut config = match client_ca {
            None => ServerConfig::new(NoClientAuth::new()),
            Some(cert) => {
                // 如果客户端证书是某个 CA 证书签发的,则把这个 CA 证书加载到信任链中
                let mut cert = Cursor::new(cert);
                let mut client_root_cert_store = RootCertStore::empty();
                client_root_cert_store
                    .add_pem_file(&mut cert)
                    .map_err(|_| KvError::CertifcateParseError("CA", "cert"))?;

                let client_auth = AllowAnyAuthenticatedClient::new(client_root_cert_store);
                ServerConfig::new(client_auth)
            }
        };

        // 加载服务器证书
        config
            .set_single_cert(certs, key)
            .map_err(|_| KvError::CertifcateParseError("server", "cert"))?;
        config.set_protocols(&[Vec::from(&ALPN_KV[..])]);

        Ok(Self {
            inner: Arc::new(config),
        })
    }

    /// 触发 TLS 协议,把底层的 stream 转换成 TLS stream
    pub async fn accept<S>(&self, stream: S) -> Result<ServerTlsStream<S>, KvError>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send,
    {
        let acceptor = TlsAcceptor::from(self.inner.clone());
        Ok(acceptor.accept(stream).await?)
    }
}

fn load_certs(cert: &str) -> Result<Vec<Certificate>, KvError> {
    let mut cert = Cursor::new(cert);
    pemfile::certs(&mut cert).map_err(|_| KvError::CertifcateParseError("server", "cert"))
}

fn load_key(key: &str) -> Result<PrivateKey, KvError> {
    let mut cursor = Cursor::new(key);

    // 先尝试用 PKCS8 加载私钥
    if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) {
        if !keys.is_empty() {
            return Ok(keys.remove(0));
        }
    }

    // 再尝试加载 RSA key
    cursor.set_position(0);
    if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) {
        if !keys.is_empty() {
            return Ok(keys.remove(0));
        }
    }

    // 不支持的私钥类型
    Err(KvError::CertifcateParseError("private", "key"))
}

虽然它有 100 多行,但主要的工作其实就是根据提供的证书,来生成 tokio-tls 需要的 ServerConfig / ClientConfig。处理完 config 后,这段代码的核心逻辑其实就是客户端的 connect() 方法和服务器的 accept() 方法,它们都接受一个满足 AsyncRead + AsyncWrite + Unpin + Send 的 stream。类似上一讲,我们不希望 TLS 代码只能接受 TcpStream,所以这里提供了一个泛型参数 S:
在使用 TlsConnector 或者 TlsAcceptor 处理完 connect/accept 后,我们得到了一个 TlsStream,它也满足 AsyncRead + AsyncWrite + Unpin + Send,后续的操作就可以在其上完成了。

由于我们一路以来良好的接口设计,尤其是 ProstClientStream / ProstServerStream 都接受泛型参数,使得 TLS 的代码可以无缝嵌入。比如客户端:


// 新加的代码
let connector = TlsClientConnector::new("kvserver.acme.inc", None, Some(ca_cert))?;

let stream = TcpStream::connect(addr).await?;

// 新加的代码
let stream = connector.connect(stream).await?;

let mut client = ProstClientStream::new(stream);

仅仅需要把传给 ProstClientStream 的 stream,从 TcpStream 换成生成的 TlsStream,就无缝支持了 TLS。

完整的服务器端


use anyhow::Result;
use kv3::{MemTable, ProstServerStream, Service, ServiceInner, TlsServerAcceptor};
use tokio::net::TcpListener;
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();
    let addr = "127.0.0.1:9527";

    // 以后从配置文件取
    let server_cert = include_str!("../fixtures/server.cert");
    let server_key = include_str!("../fixtures/server.key");

    let acceptor = TlsServerAcceptor::new(server_cert, server_key, None)?;
    let service: Service = ServiceInner::new(MemTable::new()).into();
    let listener = TcpListener::bind(addr).await?;
    info!("Start listening on {}", addr);
    loop {
        let tls = acceptor.clone();
        let (stream, addr) = listener.accept().await?;
        info!("Client {:?} connected", addr);
        let stream = tls.accept(stream).await?;
        let stream = ProstServerStream::new(stream, service.clone());
        tokio::spawn(async move { stream.process().await });
    }
}

客户端


use anyhow::Result;
use kv3::{CommandRequest, ProstClientStream, TlsClientConnector};
use tokio::net::TcpStream;
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    // 以后用配置替换
    let ca_cert = include_str!("../fixtures/ca.cert");

    let addr = "127.0.0.1:9527";
    // 连接服务器
    let connector = TlsClientConnector::new("kvserver.acme.inc", None, Some(ca_cert))?;
    let stream = TcpStream::connect(addr).await?;
    let stream = connector.connect(stream).await?;

    let mut client = ProstClientStream::new(stream);

    // 生成一个 HSET 命令
    let cmd = CommandRequest::new_hset("table1", "hello", "world".to_string().into());

    // 发送 HSET 命令
    let data = client.execute(cmd).await?;
    info!("Got response {:?}", data);

    Ok(())
}

和上一讲的代码项目相比,更新后的客户端和服务器代码,各自仅仅多了一行,就把 TcpStream 封装成了 TlsStream。这就是使用 trait 做面向接口编程的巨大威力,系统的各个组件可以来自不同的 crates,但只要其接口一致(或者我们创建 adapter 使其接口一致),就可以无缝插入。

你可能感兴趣的:(rust语言学习,rust,网络,web安全)