使用 Java NIO 实现 Simple Redis 服务端 客户端

最近在学习NIO,想找个demo来练练手,然后发现用java nio来简单实现redis应该挺有趣的。

Java NIO 概览

首先,java nio有3个重要的类:

  • ByteBuffer: 用于读写数据,实际上是byte数组的一个封装。它的使用方式还是比较有趣的,需要注意的一点是在写模式转换为读模式时需要flip()。
  • Channel: 用于维护服务端与客户端的通信通道。
  • Selector: 多路复用器,管理被注册的通道集合信息和就绪状态。还有SelectionKey用于维护Selector与Channel之间的关系,即Channel对哪些事件感兴趣,Selector会帮忙通知它。

然后是使用方法——java nio网络编程的主要使用流程如下:

  1. 创建Selector
  2. 创建SocketChannel
  3. 将SockerChannel注册到Selector上,设置SelectionKey
  4. 绑定端口或者开启连接
  5. 之后就不断从Selector中抽取准备好了的SelectionKey,然后对对应的Channel作处理

还有一点补充,SelectionKey可以包含一个任意类型的attachment,可以用来辅助Channel的处理。

Redis序列化协议

关于Redis序列化协议,可以参看这里
我只实现了数组、单行字符串、定长字符串。

这里我把单行字符串、定长字符串的解码称为单步解码,
即整个解码过程由若干单步解码组成。

拆包粘包

Redis的通信是基于TCP协议的,关于TCP网络编程,比较麻烦的一点就是处理拆包和粘包的问题。

JDK NIO并没有为我们直接提供处理拆包粘包的一些机制,所以需要我们自己处理。

这里处理粘包的思路很直接,即利用应用层的协议,因为Redis序列化协议已经很完善,遵循它的协议来实现处理Channel的代码,就能分割开不同请求的数据段。

而处理拆包会比较麻烦,它首先会分为两种情况:

  1. 目前读取的数据段不够,不足以进行单步解码操作。
  2. 目前读取的数据段刚好能进行单步解码操作,但总的请求并没有完全读取,相应的解码并没有完全完成。

对于第一种情况,基于协议规定的分割符或者已知的限定长度,不断等待读取事件准备好,处理读取事件,直到满足单步解码操作的要求。(这里可以保存之前部分解码的数据和状态,也可以不要,因为数据量较小,性能影响不大)。
而对于第二种情况,依然可以使用类似第一种情况的方式,不断等待数据完全准备好,若还没有准备好,则放弃之前解码完成的数据,等下次数据到来时再重新解码。不过由于数据已经解码了一部分,重新解码比较耗时,所以需要保存下之前解码的数据和状态,等下次数据到来时再继续解码。

这里我的方案是:在Decoder对象中保存目前的解码状态,保存的状态较为细节,第一种情况中的部分解码的数据和状态也保存了。然后让SelectionKey带上这个Decoder(通过attachment),当每次SelectionKey准备好时,取出这个Decocer进行继续处理。

代码部分

首先是核心部分——Redis序列化协议解码器:

1. 状态部分:

public class RespDecoder {

    private ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
    private SocketChannel channel;

    // 是否只拿到了CR,LF还没拿到
    private boolean isOnlyGetCR;

    // 字符串缓冲区
    private StringBuilder sb = new StringBuilder();

    // 目前的还没读取的定长字符串长度
    private int stringLength;

    // 目前读取到的字符串数组
    private List<String> wordList = new ArrayList<String>();

    // 目前读取到的mark
    private byte mark;

    // 目前还需读取到数组中的字符串数量
    private int size;

2. 主要解码逻辑
解码主体逻辑decode0()使用递归实现,比较方便。
decode()方法处理异常,设置解码完成与否的标识,返回给调用者。

    public boolean decode() throws Exception {
        boolean isComplete = true;
        try {
            decode0();
        } catch (ReadEmptyException e) {
            // 捕获到异常,则返回false,提示调用者目前还未解码完成。
            isComplete = false;
        }
        return isComplete;
    }

    private void decode0() throws Exception {
        // mark默认为0,若不为0,则为中间状态
        mark = (mark != 0 ? mark : readOneByte());
        if (mark == '*') {
            // size默认为0,若不为0,则为中间状态
            size = (size != 0 ? size : readInteger());

            while ((size--) > 0) {
                decode0();   // 递归解析
            }
        } else if (mark == '$') {
            // stringLength默认为0,若不为0,则之前读取了一部分的字符串
            stringLength = (stringLength != 0 ? stringLength : readInteger() + 2);     // +2 : 加上CRLF
            readFixString();
            // 没有抛异常就保存字符串
            saveString();
        } else if (mark == '+') {
            readString();
            // 没有抛异常就保存字符串
            saveString();
        }
    }

3. 读取数据到buffer
若目前Channel没有数据可读,则抛出ReadEmptyException异常,上层调用代码捕获到异常后可以进行相应处理

     /**
     * 若byteBuffer为空,读取数据到byteBuffer
     * @throws Exception
     */
    private void readToBuffer() throws Exception{
        if (!byteBuffer.hasRemaining()) {
            byteBuffer.clear();
            int count = channel.read(byteBuffer);
            if (count < 0) {
                throw new RuntimeException("connection error");
            } else if (count == 0) {
                throw new ReadEmptyException("read empty");
            }
            byteBuffer.flip();
        }
    }

4. 读取定长字符串
读取字符串等单步解码结构差不多,这里以读取定长字符串为例。这里我使用了递归的方法,来递归读取剩余的数据。若readToBuffer()抛出异常,则不进行处理,直接抛出该异常给上层。

     /**
     * 读一个定长字符串
     * @return
     */
    private void readFixString() throws Exception{

        readToBuffer();

        byte[] bytes;
        if (byteBuffer.remaining() < stringLength) {
            int currentSize = byteBuffer.remaining();
            bytes = new byte[currentSize];
            byteBuffer.get(bytes);
            String str = new String(bytes, "UTF-8");
            sb.append(str);
            stringLength -= currentSize;
            readFixString();     // 递归读取剩余的数据
        } else {
            bytes = new byte[stringLength];
            byteBuffer.get(bytes);
            String str = new String(bytes, "UTF-8");
            sb.append(str.replaceAll("\r|\n", ""));     // 去除CRLF
            stringLength = 0;
        }

    }

然后是服务器代码,常规的NIO操作。这里用ConcurrentHashMap来实现Redis缓存业务,实际上用HashMap就够了,因为这个服务器代码只有一个线程。

public class RedisServer {

    private Map<String, String> redisMap = new ConcurrentHashMap<String, String>();

    public void start() throws Exception{

        ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
        ServerSocket serverSocket = serverSocketChannel.socket();
        Selector selector = Selector.open();
        serverSocketChannel.configureBlocking(false);
        serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
        // 绑定6379
        serverSocket.bind(new InetSocketAddress(6379));
        System.out.println("Listen to " + 6379);

        while (true) {
            int n = selector.select();
            if (n == 0) {
                continue;
            }
            Iterator it = selector.selectedKeys().iterator();

            while (it.hasNext()) {
                SelectionKey key = (SelectionKey) it.next();
                if (key.isAcceptable()) {
                    ServerSocketChannel server = (ServerSocketChannel) key.channel();
                    SocketChannel socketChannel = server.accept();

                    System.out.println(socketChannel.getLocalAddress() + " accepted");

                    socketChannel.configureBlocking(false);
                    socketChannel.register(selector, SelectionKey.OP_READ);


                }

                if (key.isReadable()) {
                    SocketChannel socketChannel = (SocketChannel) key.channel();
                    if (key.attachment() == null) {
                        key.attach(new RespDecoder(socketChannel));
                    }
                    RespDecoder decoder = (RespDecoder) key.attachment();

                    try {
                        boolean isComplete = decoder.decode();
                        if (isComplete) {
                            List<String> wordList = decoder.getWordList();
                            System.out.println(Arrays.toString(wordList.toArray()));
                            String message = operate(wordList);
                            // 解码结束,清空decoder的状态
                            decoder.clear();

                            // 发送到客户端
                            send(message, socketChannel);
                        }
                    } catch (Exception e) {
                        key.cancel();
                        socketChannel.socket().close();
                        socketChannel.close();
                    }



                }
                // 清除处理过的键
                it.remove();
            }
        }
    }

    private String operate(List<String> wordList) throws Exception{
        String result = null;
        if (wordList.get(0).equals("set")) {
            redisMap.put(wordList.get(1), wordList.get(2));
            result = "OK";
        } else if (wordList.get(0).equals("get")) {
            result = redisMap.get(wordList.get(1));
        }
        return result;
    }

    private void send(String message, SocketChannel channel) throws Exception{
        ByteBuffer writeBuffer = RespEncoder.encode(message);

        while (writeBuffer.hasRemaining()) {
            channel.write(writeBuffer);
        }
    }
}

对于编码器,直接按Redis序列化协议去编码就好了,应该没什么问题。
客户端代码与服务器类似,就不列出来了。
有兴趣可以点击后文给出的源码地址查看。

演示部分

1. 用Redis自带的客户端去连接我们的Simple Redis服务器:

开启我们的Simple Redis服务器

使用Redis自带的命令行客户端

2. 用我们的Simple Redis客户端去连接Redis服务器:

开启Redis服务器

使用我们的Simple Redis命令行客户端

用Redis自带的命令行客户端验证

3. 用我们的Simple Redis客户端去连接我们的Simple Redis服务器:

开启我们的Simple Redis服务器

使用我们的Simple Redis命令行客户端

最后

源码地址

github
源码中还有一个Netty的版本。

再最后

本人刚接触NIO不久,若大家发现有问题请不吝赐教,谢谢。

你可能感兴趣的:(nio,java,redis)