C++ 实现websocket 简单的服务器

前言
打算写一个WebSocket服务器来练练手,它是基于tcp实现的,与生俱来的优点较之http是全双工的,即服务端可主动向客户端推送数据,亦可请求响应的模式来进行数据传输

WebSocket讲解
网上有很多对WebSocket的格式进行了充分的讲解,我们搬来用用。
参考自 https://segmentfault.com/a/1190000012948613 感谢大神!!!


握手
首先是客户端和服务器建立连接,即握手操作。

GET / HTTP/1.1
Host: localhost:8080
Origin: http://127.0.0.1:3000
Connection: Upgrade
Upgrade: websocket
Sec-WebSocket-Version: 13
Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==


这是客户端的请求
- Connection: Upgrade:表示要升级协议
- Upgrade: websocket:表示要升级到 websocket 协议。
- Sec-WebSocket-Version: 13:表示 websocket 的版本。如果服务端不支持该版本,需要返回一个 Sec-WebSocket-Versionheader ,里面包含服务端支持的版本号。
- Sec-WebSocket-Key:与后面服务端响应首部的 Sec-WebSocket-Accept 是配套的,提供基本的防护,比如恶意的连接,或者无意的连接

HTTP/1.1 101 Switching Protocols
Connection:Upgrade
Upgrade: websocket
Sec-WebSocket-Accept: Oy4NRAQ13jhfONC7bP8dTKb4PTU=


这是服务器响应。
- Sec-WebSocket-Accept
伪代码如下:
>toBase64(sha1(Sec-WebSocket-Key + 258EAFA5-E914-47DA-95CA-C5AB0DC85B11))

接下来就是数据的传输(客户端与服务器)
C++ 实现websocket 简单的服务器_第1张图片

数据帧的格式
> fin占1位,表示此次传输数据包是否传输完,因为websocket里有分片传输,1传输完,0未传输完
> rsv1,2,3占3位,是做扩展用
> opcode占4位,是数据帧的操作类型包括close,ping,pong,binary,text等
> mask占1位,1表示数据要经过掩码运算,0表示不需要掩码操作,服务器不需要掩码,客户端需要
>payload len占7位,表示数据的长度,分为[0, 126),126,127
~ 第一种情况数据长度就是等于该值
~ 第二种情况用其后的2个字节表示数据的长度
~ 第三种情况用其后的8个字节表示数据的长度
>making-key占4字节,如果mask设置1,使用这个key与数据进行掩码操作来获得正式的数据
>之后便是数据部分,分为扩展数据和应用数据,我们在这里只讨论应用数据
>

代码
代码结构分为:
* Server: 管理客户端,接收客户端数据,同时向客户端写入数据
* ClientItem, 接受到的客户端的实例(可理解为socket)
* WebSokcetMsg 数据包的解析和封装
* WebSocketController 业务处理类,通过Server派发给任务,同时自己可通过回调写客户端写数据

1
我们首先看WebSocketMsg:数据包的封装和解析

#ifndef __WEBSOCKETMSG_H__
#define __WEBSOCKETMSG_H__

#include 
#include 

class WebSocketMsg
{
public:
    WebSocketMsg();
    ~WebSocketMsg();

    struct WebSocketPkt
    {
        enum MsgType {
            MsgType_Handshake,
            MsgType_FrameData,
        };

        enum OpcodeType {
            OpcodeType_Continue  = 0x0,
            OpcodeType_Text      = 0x1,
            OpcodeType_Binary    = 0x2,
            OpcodeType_Close     = 0x8,
            OpcodeType_Ping      = 0x9,
            OpcodeType_Pong      = 0xA,
        };

        void resetFrameData();

        MsgType                            msg_type_;
        std::map header_map_;

        uint8_t                            fin_;            // 1bit
        uint8_t                            opcode_;         // 4bit
        uint8_t                            mask_;           // 1bit
        uint8_t                            masking_key_[4]; // 0 or 4 bytes
        uint64_t                           payload_length_; // 1 or 2 or 8 bytes
        std::string                        data_;
    };

    /**
    * return: -1:error, 0:continue, >0:done
    */
    int fromFrameDataPkt(int nread, const char *buf);
    bool fromHandshakePkt(int nread, const char *buf);

    WebSocketMsg::WebSocketPkt::OpcodeType requestOpcode() {
        return (WebSocketMsg::WebSocketPkt::OpcodeType)request_.opcode_;
    }

    std::string requestData() {
        return request_.data_;
    }

    void resetRequestData() {
        request_.resetFrameData();
    }

    std::string toHandshakePkt();
    std::string toFrameDataPkt(const std::string &data, 
                               WebSocketMsg::WebSocketPkt::OpcodeType type = 
                                    WebSocketMsg::WebSocketPkt::OpcodeType_Text);

private:
    WebSocketPkt request_;
    WebSocketPkt response_;
};

#endif // __WEBSOCKETMSG_H__

定义请求和响应的WebSocketPkt 对象来做处理,请求来了使用request_处理,要是组合响应信息通过response_,只是这样设想,可看具体实现

#include "WebSocketMsg.h"

#include "base/string_util.h"
#include "base/sha1.h"
#include "base/base64.h"

const char s_Key[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
const char s_ReqWSKey[] = "Sec-WebSocket-Key";

WebSocketMsg::WebSocketMsg()
{
    request_.resetFrameData();
    response_.resetFrameData();
}

WebSocketMsg::~WebSocketMsg()
{
}

bool WebSocketMsg::fromHandshakePkt(int nread, const char *buf)
{
    request_.msg_type_ = WebSocketPkt::MsgType_Handshake;

    std::string str(buf, nread);
    std::vector tokens;

    int size = Tokenize(str, "\r\n", &tokens);
    for (int i = 0; i < size; ++i) {
        const std::string &line = tokens.at(i);
        std::vector lineEle;

        int eleSize = Tokenize(line, ": ", &lineEle);
        if (eleSize == 2) {
            request_.header_map_[lineEle.at(0)] = lineEle.at(1);
        }
    }

    return true;
}

int WebSocketMsg::fromFrameDataPkt(int nread, const char *buf)
{
    request_.msg_type_ = WebSocketPkt::MsgType_FrameData;
    unsigned bytePos = 0;

    request_.fin_ = (buf[bytePos] >> 7);
    request_.opcode_ = buf[bytePos] & 0x0F;
    bytePos++;

    request_.mask_ = (buf[bytePos] >> 7);
    request_.payload_length_ = buf[bytePos] & 0x7F;
    bytePos++;

    if (request_.payload_length_ == 126) {
        uint16_t length = 0;
        memcpy(&length, buf + bytePos, 2);

        request_.payload_length_ = ntohs(length);
        bytePos += 2;
    }
    else if (request_.payload_length_ == 127) {
        long long length = 0;
        memcpy(&length, buf + bytePos, 8);

        request_.payload_length_ = ntohll(length);
        bytePos += 8;
    }

    if (request_.mask_ != 0) {
        for (int i = 0; i < 4; ++i) {
            request_.masking_key_[i] = buf[bytePos++];
        }
    }

    std::string s = "";
    if (request_.mask_ == 0) {
        s = std::string(buf + bytePos, (unsigned)request_.payload_length_);
    }
    else {
        for (unsigned i = 0; i < request_.payload_length_; ++i) {
            unsigned j = i % 4U;
            char c = buf[bytePos + i] ^ request_.masking_key_[j];
            s.push_back(c);
        }
    }
    request_.data_.append(s);
    bytePos += (unsigned)request_.payload_length_;

    return request_.fin_;
}

std::string WebSocketMsg::toHandshakePkt()
{
    auto &headerMap = request_.header_map_;
    if (headerMap.count(s_ReqWSKey) == 0) {
        return "";
    }

    std::string pkt;
    pkt.append("HTTP/1.1 101 Switching Protocols\r\n");
    pkt.append("Connection: upgrade\r\n");
    pkt.append("Upgrade: websocket\r\n");
    pkt.append("Sec-WebSocket-Accept: ");

    std::string SecWSKey = headerMap.at(s_ReqWSKey) + s_Key;
    bool rc = base::Base64Encode(base::SHA1HashString(SecWSKey), &SecWSKey);

    if (rc) {
        pkt.append(SecWSKey + "\r\n\r\n");
    }
    else {
        pkt = "";
    }
    return pkt;
}

std::string WebSocketMsg::toFrameDataPkt(const std::string &data,
                                         WebSocketMsg::WebSocketPkt::OpcodeType type)
{
    unsigned size = data.size();
    std::string frame;

    uint8_t c = (1 << 7);
    c = c | type;
    frame.push_back((char)c);

    uint8_t paylength = 0;
    if (size < 126U) {
        paylength = size;
        frame.push_back((char)paylength);
    }
    else if (size >= 126U && size <= 0xFFFFU) {
        paylength = 126;
        frame.push_back(paylength);

        uint16_t l = htons(size);
        char buf[2] = {0};
        memcpy(buf, &l, 2);
        frame.append(buf);
    }
    else {
        paylength = 127;
        frame.push_back(paylength);

        uint64_t l = htonll((int64_t)size);
        char buf[8] = {0};
        memcpy(buf, &l, 8);
        frame.append(buf);
    }

    frame.append(data);
    return frame;
}

void WebSocketMsg::WebSocketPkt::resetFrameData()
{
    fin_ = 0;
    opcode_ = 0;
    mask_ = 0;         
    memset(masking_key_, 0, sizeof(masking_key_));
    payload_length_ = 0;
    data_ = "";
}

这是具体实现,fromHandshakePkt函数 即解析握手包的收据,fromFrameDataPkt 函数即解析数据帧的数据,其中使用如下来获得真实数据

for (unsigned i = 0; i < request_.payload_length_; ++i) {
    unsigned j = i % 4U;
    char c = buf[bytePos + i] ^ request_.masking_key_[j];
    s.push_back(c);
}

toHandshakePkt和toFrameDataPkt来封装响应的握手包及向客户端传输的数据包
 

2
然后我们来看Server看服务器是如何实现,我们使用libuv库来实现,libuv是事件驱动i/o模型,nodejs便是使用的它,适合i/o密集型的服务器

#ifndef __SERVER_H__
#define __SERVER_H__

#include 
#include 

#include "uv.h"

#include "ClientItem.h"
#include "WebSocketController.h"

typedef struct {
    uv_write_t req_;
    uv_buf_t   buf_;
} write_response_t;

class Server
{
public:
    Server(int port = 8080);
    ~Server();

    int run();

    void doWork(ClientItem *item);
    void writeFrameData(int64_t id, const std::string &data);

    inline void pushClient(const ClientItem &ci);
    inline ClientItem* client(int64_t id);
    void closeClient(int64_t id);
    void removeClient(int64_t id);

    inline const std::map& clientMap() const;

    inline int64_t increaseId();

private:
    static void on_new_connection(uv_stream_t* server, int status);
    static void alloc_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf);
    static void read_msg(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);

    static void write_msg_ret(uv_write_t *response, int status);
    static void timer_out(uv_timer_t* handle);

private:
    uv_loop_t                           *loop_;
    uv_tcp_t                             server_;
    int                                  port_;
    sockaddr_in                          addr_;
                                        
    std::map        client_map_;
    int64_t                              increase_id_;

    uv_timer_t                           repeat_timer_;
    std::unique_ptr ws_controller_;
};

#endif // __SERVER_H__

Server类的声明简要介绍:
 write_response_t是向客户端写数据的结构体
 libuv使用了好多C的回调函数,我们把它定义成static类型,有新客户端连接,读数据,定时器等
 我们把写数据writeFrameData定义成成员函数,方便我们操作;
 定时器的作用是我们隔段时间去检查客户端是否还活着,不活着就关闭。
 *client系列函数是针对客户端实例的管理

首先我们先看入口函数:run

int Server::run()
{
    uv_tcp_init(loop_, &server_);
    uv_ip4_addr("0.0.0.0", port_, &addr_);
    uv_tcp_bind(&server_, (const sockaddr *)&addr_, 0);

    server_.data = this;

    int rc = uv_listen((uv_stream_t *)&server_, DEFAULT_BACKLOG, &Server::on_new_connection);
    if (rc) {
        std::cout << "listen error:" << uv_strerror(rc) << std::endl;
        return -1;
    }

    repeat_timer_.data = this;
    uv_timer_init(loop_, &repeat_timer_);
    uv_timer_start(&repeat_timer_, timer_out, 1000, 1000);

    return uv_run(loop_, UV_RUN_DEFAULT);
}

监听port_端口进行初始化,设置回调函数on_new_connection有新的连接到来的函数,启动定时器,进入循环。


然后看“连接到来”的函数:on_new_connection

void Server::on_new_connection(uv_stream_t *server, int status)
{
    if (status < 0) {
        std::cout << "new connection error" << uv_strerror(status) << std::endl;
        return ;
    }

    uv_tcp_t *client = (uv_tcp_t *)malloc(sizeof(uv_tcp_t));
    uv_tcp_init(server->loop, client);

    if (uv_accept(server, (uv_stream_t *)client) == 0) {
        std::cout << "new connection" << std::endl;

        Server *s = (Server *)(server->data);
        ClientItem item(s->increaseId(), client, (int64_t)base::Time::Now().ToDoubleT());
        item.user_data_ = s;
        s->pushClient(item);

        client->data = s->client(item.id_);
        uv_read_start((uv_stream_t *)client, alloc_buffer, read_msg);
    }
    else {
        uv_close((uv_handle_t *)client, NULL);
        free(client);
    }
}

连接到来时初始化客户端实例,为客户都安实例生成id,放置在map里来管理,然后开始读取数据:

void Server::read_msg(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf)
{
    if (nread <= 0) {
        std::cout << "read error" << uv_err_name(nread) << std::endl;
        free(buf->base);
        return ;
    }

    ClientItem *item = (ClientItem *)(client->data);
    Server *s = (Server *)item->user_data_;

    if (item == nullptr) {
        return ;
    }
    item->pong_time_ = (int64_t) base::Time::Now().ToDoubleT();

    WebSocketMsg msg;
    if (!item->is_connected_) {
        bool rc = msg.fromHandshakePkt(nread, buf->base);
        if (!rc) {
            s->closeClient(item->id_);
        }
        else {
            item->is_connected_ = true;
            item->msg_ = msg;

            std::string pkt = msg.toHandshakePkt();
            int64_t id = item->id_;

            if (pkt.empty()) {
                s->closeClient(id);
            }
            else {
                s->writeFrameData(id, pkt);
            }
        }
    }
    else {
        int rc = item->msg_.fromFrameDataPkt(nread, buf->base);
        if (rc > 0) {
            s->doWork(item);
        }
    }

    free(buf->base);
}

读取到数据后,首先看是不是第一次连接,如果是第一次连接则解析握手包,回传响应,如果不是第一次读数据,解析数据帧,然后执行。

接着我们来看下dowork做的事情:
 

void Server::doWork(ClientItem *item)
{
    int opcode = item->msg_.requestOpcode();
    switch (opcode) {

    case WebSocketMsg::WebSocketPkt::OpcodeType_Close:
        closeClient(item->id_);
        return;

    case WebSocketMsg::WebSocketPkt::OpcodeType_Ping:
    {
        WebSocketMsg msg;
        std::string data = msg.toFrameDataPkt("", WebSocketMsg::WebSocketPkt::OpcodeType_Pong);
        writeFrameData(item->id_, data);
    }
        return;

    case WebSocketMsg::WebSocketPkt::OpcodeType_Binary:
    case WebSocketMsg::WebSocketPkt::OpcodeType_Text:
        if (item != nullptr) {
            ws_controller_->doWork(item);
        }
        return;

    default:
        return ;
    }
}

如果是数据类型,则派发给controller处理,其余的在此处处理。
我们会给controller传递写数据的回调函数,当controller有数据写入时可直接调用
writeFrameData:
 

void Server::writeFrameData(int64_t id, const std::string &data)
{
    ClientItem *item = client(id);
    if (item == nullptr) {
        return ;
    }

    uv_handle_t *handle = (uv_handle_t *)item->client_;
    if (handle == nullptr) {
        closeClient(id);
        ws_controller_->setItem(nullptr);
    }

    write_response_t *response = (write_response_t*)malloc(sizeof(write_response_t));

    int size = data.size();
    alloc_buffer(handle, size, &response->buf_);

    for (int i = 0; i < size; ++i) {
        response->buf_.base[i] = data.at(i);
    }

    uv_write((uv_write_t *)response, (uv_stream_t *)handle,
             &response->buf_, 1, write_msg_ret);
}

接下来我们看下定时器里做的工作:

void Server::timer_out(uv_timer_t *handle)
{
    Server *s = (Server *)handle->data;
    const std::map &clientMap = s->clientMap();

    if (clientMap.size() == 0) {
        return ;
    }

    auto iter = clientMap.cbegin();
    for (;iter != clientMap.cend(); iter++) {
        const ClientItem &item = iter->second;
        if (item.isDead()) {
            s->closeClient(item.id_);
            continue ;
        }

        if (item.isToPing()) {
            WebSocketMsg msg;
            std::string frame = msg.toFrameDataPkt("", WebSocketMsg::WebSocketPkt::OpcodeType_Ping);

            s->writeFrameData(item.id_, frame);
        }
    }
}

遍历保存客户端实例的map来查看客户端是已经不再存活还是需要我们去ping他,维持心跳。

3
至此我们把大致的代码结构讲解完了,我们最后看些ClientItem的结构

#ifndef __CLIENTITEM_H__
#define __CLIENTITEM_H__

#include 

#include "uv.h"
#include "base/time.h"

#include "WebSocketMsg.h"

struct ClientItem
{
    ClientItem(int64_t id, uv_tcp_t *client, 
        int64_t pong_time = 0, int64_t ping_time = 0) :
        id_(id),
        client_(client),
        pong_time_(pong_time),
        ping_time_(ping_time),
        is_connected_(false),
        user_data_(nullptr) {}

    ClientItem() {
        reset();
    }

    bool isDead() const {
        if (ping_time_ == 0 || pong_time_ == 0) {
            return false;
        }

        if (ping_time_ - pong_time_ > 10) {
            return true;
        }

        return false;
    }

    bool isToPing() const {
        using namespace base;
        base::Time t = base::Time::Now();
        int64_t nowSeconds = (int64_t)t.ToDoubleT();

        if (nowSeconds - pong_time_ > 5) {
            return true;
        }

        return false;
    }

    void close();
    void reset();

    int64_t                              id_;
    uv_tcp_t                            *client_;
    int64_t                              pong_time_;  // s
    int64_t                              ping_time_;  // s
    bool                                 is_connected_;
    void                                *user_data_;
    WebSocketMsg                         msg_;
};

#endif // __CLIENTITEM_H__

好了,由于业务逻辑那里可以自己定制,我就不再赘述了,篇幅有点长,感谢有时间读完。吐槽csdn的编辑器,让我搞了七八遍
最后附代码下载连接
https://download.csdn.net/download/leapmotion/10835888

 

 

 

你可能感兴趣的:(网络,websocket)