linux 下websocket server demo例程

websocket的协议,原理参考文档:WebSocket 规范 + WebSocket 协议。


主要步骤:1、创建socket套接字进行监听客户端;
  2、握手,当与客户端建立tcp连接后,客户端会发送websocket请求,此时,服务器端需要提取客户端在websocket请求中携带的唯一握手Key(Sec-WebSocket-Key),服务端在拿到这个Key后,需要在其后拼接一个固定的GUID,然后计算SHA-1摘要(是哈希,并非加密),再转换成base64,最后通过响应头发回到客户端。这样就完成了一次握手。此种握手方式针对的是WebSocket协议版本13(chrome等浏览器使用的版本),其他版本的可能会有所不同。握手只需一次即可,握手成功后进行正常的数据通信。

  3、后面的过程跟tcp通信一样;


附上代码:

websocket.c

/*********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng  
 *                  All rights reserved.
 *
 *       Filename:  websocket.c
 *    Description:  This file 
 *                 
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:03:22 PM"
 *                 
 ********************************************************************************/
/* Header names were lost when this listing was extracted; restored from the
 * library calls the file makes (stdio, heap, string, POSIX I/O, sockets). */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "base64.h"
#include "sha1.h"
#include "intlib.h"


#define REQUEST_LEN_MAX         1024
#define DEFEULT_SERVER_PORT     8000
#define WEB_SOCKET_KEY_LEN_MAX  256
#define RESPONSE_HEADER_LEN_MAX 1024
#define PER_LINE_MAX            256


/**************************************************************************************
 * Function Name: extract_client_key 
 *   Description: 提取客户端发送的handshake key值
 *    Input Args: @buffer 客户端发送的握手数据
 *   Output Args: 输出客户端发来handshake key
 *  Return Value: server_key 客户端发来的handshake key
 *************************************************************************************/
static char *extract_client_key(const char * buffer)
{
    char    *key = NULL;
    char    *start = NULL; /* 要提取字符串的起始地址 */
    char    *flag = "Sec-WebSocket-Key: ";
    int     i = 0;
    int     buf_len = 0;

    if(NULL == buffer) {
        printf("buffer is NULL.\n");
        return NULL;
    }

    key=(char *)malloc(WEB_SOCKET_KEY_LEN_MAX);
    if (NULL == key) {
        printf("key alloc failure.\n");
        return NULL;
    }
    memset(key,0, WEB_SOCKET_KEY_LEN_MAX);


    start = strstr(buffer, flag);
    if(NULL == start) {
        printf("start is NULL.\n");
        return NULL;
    }

    start += strlen(flag);
    buf_len = strlen(buffer);
    for(i=0;i>8 & 0xFF);
        data[3] = (n & 0xFF);
        memcpy(data+4,message,n);    
        *len=n+4;
    } else {
        // 暂不处理超长内容  
        *len=0;
    }

    return data;
} /* ----- End of construct_packet_data()  ----- */

/**************************************************************************************
 * Function Name: response
 *   Description: Build a packet from @message with construct_packet_data() and
 *                write the whole packet to the client connection.
 *    Input Args: @conn_fd connection descriptor to write to (must be >= 0)
 *                @message payload handed to construct_packet_data()
 *   Output Args: none
 *  Return Value: none; failures are reported on stdout and the packet is dropped
 *************************************************************************************/
void response(int conn_fd, const char *message)
{
    char            *data = NULL;
    unsigned long   n = 0;
    unsigned long   sent = 0;
    ssize_t         rc = 0;

    /* Descriptor 0 is valid; only negative values denote a failed accept(). */
    if(conn_fd < 0) {
        printf("conn_fd is error.\n");
        return ;
    }

    if(NULL == message) {
        printf("message is NULL.\n");
        return ;
    }

    data = construct_packet_data(message, &n);
    if(NULL == data || 0 == n)
    {
        printf("data is empty!\n");
        /* construct_packet_data() can hand back a buffer together with a
         * zero length (over-long payload), so release it here as well. */
        free(data);
        return ;
    }

    /* write() may accept fewer bytes than requested; keep going until the
     * whole packet is out or the descriptor reports an error. */
    while (sent < n) {
        rc = write(conn_fd, data + sent, n - sent);
        if (rc <= 0) {
            printf("write failure.\n");
            break;
        }
        sent += (unsigned long)rc;
    }

    /* NOTE(review): assumes construct_packet_data() returns a heap buffer owned
     * by the caller, which is what the original (unreachable) free() implied. */
    free(data);
    data = NULL;
} /* ----- End of response()  ----- */


/********************************************************************************
 * Function Name: main
 *   Description: Listen on a TCP port, accept one client, answer its first
 *                request with the WebSocket handshake, then pass every later
 *                read through deal_data() and send the result with response().
 *    Input Args: @argc argument count
 *                @argv argv[1] (optional) is the decimal listening port,
 *                      default DEFEULT_SERVER_PORT
 *   Output Args: none
 *  Return Value: 0 when the client closes the connection, -1 on any error
 ********************************************************************************/
int main(int argc, char *argv[])
{
    int                 listen_fd = -1;
    int                 conn_fd = -1;
    char                buf[REQUEST_LEN_MAX];
    char                *data = NULL;
    char                str[INET_ADDRSTRLEN];
    char                *sec_websocket_key = NULL;
    char                *end = NULL;
    long                port_arg = 0;
    int                 n;
    int                 rc = -1;
    int                 connected = 0;//0:not connect.1:connected.
    int                 port = DEFEULT_SERVER_PORT;
    struct sockaddr_in  servaddr;
    struct sockaddr_in  cliaddr;
    socklen_t           cliaddr_len;

    if(argc > 1) {
        /* strtol() lets us reject non-numeric input, which atoi() cannot. */
        port_arg = strtol(argv[1], &end, 10);
        if(end == argv[1] || *end != '\0') {
            printf("Port(%s) is not a number\n", argv[1]);
            return -1;
        }
        if(port_arg <= 0 || port_arg > 0xFFFF) {
            printf("Port(%ld) is out of range(1-%d)\n", port_arg, 0xFFFF);
            return -1;
        }
        port = (int)port_arg;
    }

    listen_fd = socket(AF_INET, SOCK_STREAM, 0);
    if(listen_fd < 0) {
        perror("socket");
        return -1;
    }

    memset(&servaddr, 0, sizeof(servaddr));
    servaddr.sin_family = AF_INET;
    servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
    servaddr.sin_port = htons(port);

    if(bind(listen_fd, (struct sockaddr *)&servaddr, sizeof(servaddr)) < 0) {
        perror("bind");
        goto out;
    }

    if(listen(listen_fd, 20) < 0) {
        perror("listen");
        goto out;
    }

    printf("Listen %d\nAccepting connections ...\n",port);
    cliaddr_len = sizeof(cliaddr);
    conn_fd = accept(listen_fd, (struct sockaddr *)&cliaddr, &cliaddr_len);
    if(conn_fd < 0) {
        perror("accept");
        goto out;
    }

    /* inet_ntop() returns NULL on failure; never feed that to %s. */
    if(NULL != inet_ntop(AF_INET, &cliaddr.sin_addr, str, sizeof(str))) {
        printf("From %s at PORT %d\n", str, ntohs(cliaddr.sin_port));
    }

    while (1)
    {
        memset(buf, 0, sizeof(buf));
        /* Read one byte less than the buffer so it always stays
         * NUL-terminated for the string handling below. */
        n = (int)read(conn_fd, buf, sizeof(buf) - 1);
        if(n < 0) {
            perror("read");
            break;
        }
        if(0 == n) {
            /* Peer closed the connection; without this check the loop
             * would spin forever on zero-length reads. */
            printf("client closed connection.\n");
            rc = 0;
            break;
        }
        printf("---------------------\n");

        if(!connected) {
            /* The first request is the HTTP upgrade: answer the handshake. */
            printf("read:%d\n%s\n", n, buf);
            sec_websocket_key = calculate_accept_key(buf);
            websocket_shakehand(conn_fd, sec_websocket_key);
            free(sec_websocket_key);
            sec_websocket_key = NULL;
            connected=1;
            continue;
        }

        /* NOTE(review): ownership of the buffer returned by deal_data() is not
         * visible here; if it is heap-allocated it must be freed after use. */
        data = deal_data(buf, n);
        response(conn_fd, data);
    }

out:
    if(conn_fd >= 0) {
        close(conn_fd);
    }
    close(listen_fd);
    return rc;
} /* ----- End of main() ----- */

sha1.c

/*********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng  
 *                  All rights reserved.
 *
 *       Filename:  sha1.c
 *    Description:  This file 
 *                 
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:08:20 PM"
 *                 
 ********************************************************************************/
#include "sha1.h"

static void sha1_process_message_block(sha1_context*);
static void sha1_pad_message(sha1_context*);

/* Put @context into the SHA-1 starting state: flags cleared, length counters
 * and block cursor at zero, digest words set to the initial constants. */
static void sha1_reset(sha1_context *context)
{
    static const unsigned initial_digest[5] = {
        0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
    };
    int word;

    context->computed            = 0;
    context->corrupted           = 0;
    context->message_block_index = 0;
    context->length_high         = 0;
    context->length_low          = 0;

    for (word = 0; word < 5; word++) {
        context->message_digest[word] = initial_digest[word];
    }
}


/* Finalise the digest held in @context.
 * Returns 0 if the context is marked corrupted, otherwise 1. The padding step
 * runs only on the first successful call; later calls just return 1. */
static int sha1_result(sha1_context *context)
{
    if (context->corrupted) {
        return 0;
    }
    if (context->computed) {
        return 1;
    }

    sha1_pad_message(context);
    context->computed = 1;
    return 1;
}


/* Feed @length bytes from @message_array into the hash state.
 * Input supplied after finalisation, or to a context that is already
 * corrupted, marks the context corrupted and is ignored. A zero-length call
 * is a no-op. */
static void sha1_input(sha1_context *context,const char *message_array,unsigned length)
{
    const char *cursor;
    const char *end;

    if (length == 0) {
        return;
    }

    if (context->computed || context->corrupted) {
        context->corrupted = 1;
        return;
    }

    end = message_array + length;
    for (cursor = message_array; cursor != end && !context->corrupted; cursor++) {
        context->message_block[context->message_block_index] = (*cursor & 0xFF);
        context->message_block_index++;

        /* The message length is tracked in bits, split over two words. */
        context->length_low = (context->length_low + 8) & 0xFFFFFFFF;
        if (context->length_low == 0) {
            context->length_high = (context->length_high + 1) & 0xFFFFFFFF;
            if (context->length_high == 0) {
                /* The bit counter wrapped: the message is too long. */
                context->corrupted = 1;
            }
        }

        /* A full 64-byte block is compressed as soon as it is complete. */
        if (context->message_block_index == 64) {
            sha1_process_message_block(context);
        }
    }
}

/* Compress the 64 bytes currently held in context->message_block into
 * context->message_digest, then reset message_block_index to 0 so the block
 * buffer can be refilled. */
static void sha1_process_message_block(sha1_context *context)
{
    const unsigned K[] = {0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6 }; /* one additive constant per group of 20 rounds */
    int         t;                /* round / word index */
    unsigned    temp;             /* new value of A for the current round */
    unsigned    W[80];            /* expanded message schedule */
    unsigned    A, B, C, D, E;    /* working copies of the five digest words */

    /* Load the block as sixteen big-endian 32-bit words. */
    for(t = 0; t < 16; t++) {
        W[t] = ((unsigned) context->message_block[t * 4]) << 24;
        W[t] |= ((unsigned) context->message_block[t * 4 + 1]) << 16;
        W[t] |= ((unsigned) context->message_block[t * 4 + 2]) << 8;
        W[t] |= ((unsigned) context->message_block[t * 4 + 3]);
    }

    /* Expand the sixteen words to eighty. */
    for(t = 16; t < 80; t++)  W[t] = SHA1_CIRCULAR_SHIFT(1,W[t-3] ^ W[t-8] ^ W[t-14] ^ W[t-16]);

    A = context->message_digest[0];
    B = context->message_digest[1];
    C = context->message_digest[2];
    D = context->message_digest[3];
    E = context->message_digest[4];

    /* Rounds 0-19: select function (B ? C : D), constant K[0]. */
    for(t = 0; t < 20; t++) {
        temp =  SHA1_CIRCULAR_SHIFT(5,A) + ((B & C) | ((~B) & D)) + E + W[t] + K[0];
        temp &= 0xFFFFFFFF;
        E = D;
        D = C;
        C = SHA1_CIRCULAR_SHIFT(30,B);
        B = A;
        A = temp;
    }
    /* Rounds 20-39: parity function, constant K[1]. */
    for(t = 20; t < 40; t++) {
        temp = SHA1_CIRCULAR_SHIFT(5,A) + (B ^ C ^ D) + E + W[t] + K[1];
        temp &= 0xFFFFFFFF;
        E = D;
        D = C;
        C = SHA1_CIRCULAR_SHIFT(30,B);
        B = A;
        A = temp;
    }
    /* Rounds 40-59: majority function, constant K[2]. */
    for(t = 40; t < 60; t++) {
        temp = SHA1_CIRCULAR_SHIFT(5,A) + ((B & C) | (B & D) | (C & D)) + E + W[t] + K[2];
        temp &= 0xFFFFFFFF;
        E = D;
        D = C;
        C = SHA1_CIRCULAR_SHIFT(30,B);
        B = A;
        A = temp;
    }
    /* Rounds 60-79: parity function again, constant K[3]. */
    for(t = 60; t < 80; t++) {
        temp = SHA1_CIRCULAR_SHIFT(5,A) + (B ^ C ^ D) + E + W[t] + K[3];
        temp &= 0xFFFFFFFF;
        E = D;
        D = C;
        C = SHA1_CIRCULAR_SHIFT(30,B);
        B = A;
        A = temp;
    }
    /* Add the working values back into the digest, kept to 32 bits. */
    context->message_digest[0] = (context->message_digest[0] + A) & 0xFFFFFFFF;
    context->message_digest[1] = (context->message_digest[1] + B) & 0xFFFFFFFF;
    context->message_digest[2] = (context->message_digest[2] + C) & 0xFFFFFFFF;
    context->message_digest[3] = (context->message_digest[3] + D) & 0xFFFFFFFF;
    context->message_digest[4] = (context->message_digest[4] + E) & 0xFFFFFFFF;
    context->message_block_index = 0;
}

/* Append the final padding to the message and compress the last block(s):
 * a 0x80 marker byte, zero fill up to byte 56, then the message length in
 * bits as eight big-endian bytes. When the marker leaves no room for the
 * length field, the current block is zero-filled and compressed first. */
static void sha1_pad_message(sha1_context* context)
{
    int shift;
    int pos;

    context->message_block[context->message_block_index++] = 0x80;

    if (context->message_block_index > 56) {
        /* No space left for the 8 length bytes: flush this block first.
         * sha1_process_message_block() resets the index to 0. */
        while (context->message_block_index < 64) {
            context->message_block[context->message_block_index++] = 0;
        }
        sha1_process_message_block(context);
    }

    while (context->message_block_index < 56) {
        context->message_block[context->message_block_index++] = 0;
    }

    /* Bytes 56..59 carry length_high, bytes 60..63 carry length_low. */
    pos = 56;
    for (shift = 24; shift >= 0; shift -= 8) {
        context->message_block[pos++] = (context->length_high >> shift) & 0xFF;
    }
    for (shift = 24; shift >= 0; shift -= 8) {
        context->message_block[pos++] = (context->length_low >> shift) & 0xFF;
    }

    sha1_process_message_block(context);
}

#define SHA1_SIZE   128
/* Compute the SHA-1 digest of the NUL-terminated string @source.
 *
 * Returns a newly allocated, NUL-terminated string holding the digest as 40
 * upper-case hexadecimal characters; the caller must free() it. Returns NULL
 * when @source is NULL, the digest cannot be computed, or allocation fails. */
char *sha1_hash(const char *source)
{
    sha1_context    sha;
    char            *buf = NULL;

    if (NULL == source) {
        printf("source is NULL.\n");
        return NULL;
    }

    sha1_reset(&sha);
    sha1_input(&sha, source, strlen(source));

    if (!sha1_result(&sha)){
        printf("SHA1 ERROR: Could not compute message digest");
        return NULL;
    }

    /* calloc() zeroes the whole buffer; the old memset(sizeof(SHA1_SIZE))
     * only cleared sizeof(int) bytes. */
    buf = (char *)calloc(1, SHA1_SIZE);
    if (NULL == buf) {
        printf("buf is NULL.\n");
        return NULL;
    }

    /* Bounded formatting: five 32-bit words, 8 hex digits each. */
    snprintf(buf, SHA1_SIZE, "%08X%08X%08X%08X%08X", sha.message_digest[0],sha.message_digest[1],
            sha.message_digest[2],sha.message_digest[3],sha.message_digest[4]);

    return buf;
}

sha1.h

/********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng
 *                  All rights reserved.
 *
 *       Filename:  sha1.h
 *    Description:  This head file 
 *
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:11:08 PM"
 *                 
 ********************************************************************************/
#ifndef __SHA1_H__
#define __SHA1_H__

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Running state of one SHA-1 computation. */
typedef struct _sha1_context{
	unsigned        message_digest[5];      /* five digest words; hold the result once finalised */
	unsigned        length_low;             /* message length in bits, low 32 bits */
	unsigned        length_high;            /* message length in bits, high 32 bits */
	unsigned char   message_block[64]; /* 64-byte block currently being filled */
	int             message_block_index;         /* next free position in message_block */
	int             computed;                    /* 1 once the digest has been finalised */
	int             corrupted;                   /* 1 after input following finalisation or a length overflow */
}sha1_context;

/* Rotate @word left by @bits positions, keeping the left-shifted part to
 * 32 bits. Both arguments are evaluated twice, so pass expressions without
 * side effects. Assumes @word is an unsigned value of at most 32 bits. */
#define SHA1_CIRCULAR_SHIFT(bits,word) ((((word) << (bits)) & 0xFFFFFFFF) | ((word) >> (32-(bits))))

/* Return the SHA-1 digest of the NUL-terminated string @source as a newly
 * allocated upper-case hexadecimal string, or NULL on failure. The caller
 * owns the returned buffer. */
char *sha1_hash(const char *source);
#endif


intlib.c

/*********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng  
 *                  All rights reserved.
 *
 *       Filename:  intlib.c
 *    Description:  This file 
 *                 
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:09:51 PM"
 *                 
 ********************************************************************************/
/* Map an ASCII upper-case letter to its lower-case counterpart; any other
 * value is returned unchanged. */
int tolower(int c)
{
    const int is_upper = (c >= 'A' && c <= 'Z');

    return is_upper ? c + ('a' - 'A') : c;
}

/* Convert up to @len hexadecimal digits of the string @s to an integer.
 *
 * A leading "0x"/"0X" at the very start of @s is skipped first; @start is
 * then the offset, counted from the first digit after that prefix, at which
 * parsing begins. Parsing stops at the first non-hex character or after
 * @len digits, whichever comes first.
 *
 * Returns the parsed value, or 0 when @s is NULL, @start is negative or lies
 * beyond the end of the string, @len is not positive, or no digit is found.
 * More than 8 significant digits wrap around modulo 2^32. */
int htoi(const char s[],int start,int len)
{
    unsigned value = 0;
    int pos = 0;
    int count;

    if (!s || start < 0 || len <= 0)
    {
        return 0;
    }

    if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X'))
    {
        pos = 2;
    }

    /* Walk to the start offset one character at a time so that an offset
     * past the terminator is rejected instead of read out of bounds. */
    for (count = 0; count < start; count++, pos++)
    {
        if (s[pos] == '\0')
        {
            return 0;
        }
    }

    /* The length bound is tested before s[pos] is touched. */
    for (count = 0; count < len; count++, pos++)
    {
        const char c = s[pos];
        unsigned digit;

        if (c >= '0' && c <= '9')
        {
            digit = (unsigned)(c - '0');
        }
        else if (c >= 'a' && c <= 'f')
        {
            digit = 10u + (unsigned)(c - 'a');
        }
        else if (c >= 'A' && c <= 'F')
        {
            digit = 10u + (unsigned)(c - 'A');
        }
        else
        {
            break;
        }

        /* Accumulate unsigned so long inputs wrap instead of overflowing
         * a signed int. */
        value = value * 16u + digit;
    }

    return (int)value;
}





intlib.h

/********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng
 *                  All rights reserved.
 *
 *       Filename:  intlib.h
 *    Description:  This head file 
 *
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:10:46 PM"
 *                 
 ********************************************************************************/

#ifndef __INT_LIB_H__
#define __INT_LIB_H__

int tolower(int c);
int htoi(const char s[],int start,int len);
#endif


base64.c

/*********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng  
 *                  All rights reserved.
 *
 *       Filename:  base64.c
 *    Description:  This file 
 *                 
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:09:12 PM"
 *                 
 ********************************************************************************/
#include "base64.h"

/* Base64 alphabet: indices 0..63 are the data symbols, index 64 is the '='
 * padding character. */
const char base[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";

/* Encode @data_len bytes starting at @data as base64.
 *
 * Returns a newly allocated NUL-terminated string (padded with '=' to a
 * multiple of four characters) that the caller must free(), or NULL when the
 * allocation fails. */
char *base64_encode(const char* data, int data_len)
{
    int     groups = data_len / 3;
    int     out_size;
    int     pos = 0;
    char    *encoded = NULL;
    char    *cursor = NULL;

    if (data_len % 3 > 0)
        groups += 1;

    /* Four output characters per input group plus the terminator. */
    out_size = groups * 4 + 1;
    encoded = (char *)malloc(out_size);
    if (encoded == NULL) {
        printf("ret alloc failure.\n");
        return NULL;
    }
    memset(encoded, 0, out_size);

    cursor = encoded;
    while (pos < data_len)
    {
        int taken = 0;
        int bits = 0;
        int k;

        /* Gather up to three input bytes into one 24-bit value. */
        for (; taken < 3 && pos < data_len; taken++, pos++)
            bits = (bits << 8) | (data[pos] & 0xFF);

        /* Left-align a short final group inside the 24 bits. */
        bits = bits << ((3 - taken) * 8);

        /* Emit four symbols; positions beyond the available input become
         * the padding symbol at index 0x40. */
        for (k = 0; k < 4; k++) {
            unsigned char symbol;

            if (taken < k)
                symbol = 0x40;
            else
                symbol = (bits >> ((3 - k) * 6)) & 0x3F;

            *cursor++ = base[symbol];
        }
    }
    *cursor = '\0';

    return encoded;
}

/* Return the 6-bit value (0..63) of base64 symbol @ch, or -1 when @ch is not
 * one of the 64 data symbols. The mapping follows the usual base64 alphabet
 * order (A-Z, a-z, 0-9, '+', '/') and assumes an ASCII character set. */
static int find_pos(char ch)
{
    if (ch >= 'A' && ch <= 'Z')
        return ch - 'A';
    if (ch >= 'a' && ch <= 'z')
        return ch - 'a' + 26;
    if (ch >= '0' && ch <= '9')
        return ch - '0' + 52;
    if (ch == '+')
        return 62;
    if (ch == '/')
        return 63;
    return -1;
}

/* Decode the @data_len base64 characters starting at @data.
 *
 * Up to three trailing '=' characters are treated as padding and ignored.
 * Returns a newly allocated, NUL-terminated buffer holding the decoded bytes
 * (the caller must free() it), or NULL when @data is NULL, @data_len is
 * negative, the input contains a character outside the base64 alphabet, or
 * the allocation fails. The decoded length is not reported separately, so
 * binary payloads containing zero bytes cannot be measured with strlen(). */
char *base64_decode(const char *data, int data_len)
{
    int     equal_count = 0;
    int     payload_len;
    int     tmp = 0;
    char    *ret = NULL;
    char    *f = NULL;

    if (NULL == data || data_len < 0) {
        printf("data is invalid.\n");
        return NULL;
    }

    /* Count trailing padding without ever indexing before the buffer. */
    while (equal_count < 3 && equal_count < data_len
            && data[data_len - 1 - equal_count] == '=')
        equal_count++;
    payload_len = data_len - equal_count;

    /* At most three bytes per (possibly partial) group of four symbols,
     * plus the terminator. */
    ret = (char *)calloc((size_t)(payload_len / 4 + 1) * 3 + 1, 1);
    if (NULL == ret) {
        printf("ret alloc failure.\n");
        return NULL;
    }

    f = ret;
    while (tmp < payload_len)
    {
        unsigned    prepare = 0;
        int         temp = 0;
        int         out_bytes;
        int         i;

        /* Gather up to four symbols into one 24-bit value. */
        while (temp < 4 && tmp < payload_len)
        {
            int pos = find_pos(data[tmp]);

            if (pos < 0) {
                printf("invalid base64 character.\n");
                free(ret);
                return NULL;
            }
            prepare = (prepare << 6) | (unsigned)pos;
            temp++;
            tmp++;
        }

        /* Left-align a short final group inside the 24 bits. */
        prepare = prepare << ((4 - temp) * 6);

        /* k symbols carry k*6 bits, i.e. k*6/8 whole bytes. */
        out_bytes = (temp * 6) / 8;
        for (i = 0; i < out_bytes; i++) {
            *f = (char)((prepare >> ((2 - i) * 8)) & 0xFF);
            f++;
        }
    }
    *f = '\0';
    return ret;
}



base64.h

/********************************************************************************
 *      Copyright:  (C) 2017 Yang Zheng
 *                  All rights reserved.
 *
 *       Filename:  base64.h
 *    Description:  This head file 
 *
 *        Version:  1.0.0(08/17/2017~)
 *         Author:  Yang Zheng 
 *      ChangeLog:  1, Release initial version on "08/17/2017 02:11:15 PM"
 *                 
 ********************************************************************************/
#ifndef __BASE64_H__
#define __BASE64_H__
 
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

char* base64_encode(const char* data, int data_len); 
char *base64_decode(const char* data, int data_len); 
#endif




测试server的方法

下载浏览器websocket插件,我用chrome浏览器测试,具体操作可参考:
linux 下websocket server demo例程_第1张图片



你可能感兴趣的:(C语言,socket,Linux,c,websocket)