下面是主文件完整代码:
代码符合c99标准可在linux环境下使用gcc编译通过,其他编译器请自行测试 支持同一IP使用不同端口建立多次连接,但不支同一IP同一端口建立多次连接否则服务端会报错。
#include
#include
#include
#include
#include
#include
#include
#include
#include "base64.h"
#include "sha1.h"
#include "intlib.h"
#define MAX_EVENTS 10240 //最大连接数
#define PER_LINE_MAX 256 //client key值最大长度
#define REQUEST_LEN_MAX 2048 //request包的最大字节数
#define DEFEULT_SERVER_PORT 8000 //程序默认使用端口可更改
#define WEB_SOCKET_KEY_LEN_MAX 256 //websocket key值最大长度
#define RESPONSE_HEADER_LEN_MAX 2048 //response包的最大字节数
/*
* Function Name: extract_client_key
* Description: 提取客户端发送的handshake key值
* Input Args: @buffer 客户端发送的握手数据
* Output Args: 输出客户端发来handshake key
* Return Value: server_key 客户端发来的handshake key
*/
static char *extract_client_key(const char * buffer)
{
char *key = NULL; //保存key值
char *start = NULL; // 要提取字符串的起始地址
char *flag = "Sec-WebSocket-Key: ";
int i = 0; //初始化循环使用变量
int buf_len = 0; //初始化buffer长度
if(NULL == buffer) {
printf("buffer is NULL.\n");
return NULL;
}
key=(char *)malloc(WEB_SOCKET_KEY_LEN_MAX); //分配内存
if (NULL == key) {
printf("key alloc failure.\n");
return NULL;
}
memset(key,0, WEB_SOCKET_KEY_LEN_MAX); //将key重置为0
start = strstr(buffer, flag); //获取flag在buffer中的起始位置
if(NULL == start) {
printf("start is NULL.\n");
return NULL;
}
start += strlen(flag); //将指针移至key起始位置
buf_len = strlen(buffer); //获取buffer长度
for(i=0;i 0x7F){ //判断mask标志是否为1,0没有掩码,1有掩码
printf("no mask.\n");
return NULL; //没有掩码则丢弃
}
payloadLen = buffer[1] & 0x7F; //获取payloadLen数值
if (payloadLen == 126) { //如果payloadLen为126则真实长度为buffer第3,4字节
memcpy(masks, buffer+4, 4); //获取掩码(payloadLen结束后跟4字节mask)
payloadLen =(buffer[2]&0xFF) << 8 | (buffer[3]&0xFF); //(将buffer第3字节与0xFF进行与运算)后左移8位在进行或运算(buffer第4字节与0xFF进行与运算)
payload_data=(char *)malloc(payloadLen); //给payload_data分配内存
memset(payload_data,0,payloadLen); //将payload_data重置为0
memcpy(payload_data,buffer+4+strlen(masks),payloadLen); //获取buffer第8(2+2+4)字节之后的内容(数据部分)
} else if (payloadLen == 127) { //如果payloadLen为126则真实长度为buffer第3-10字节
memcpy(masks,buffer+10,4); //获取掩码(payloadLen结束后跟4字节mask)
for ( i = 0; i < 8; i++)
temp[i] = buffer[9 - i]; //获取buffer数据长度(第3-10字节)
memcpy(&n,temp,8); //将数据长度赋值给n
payload_data=(char *)malloc(n); //给payload_data分配内存
memset(payload_data,0,n); //将payload_data重置为0
memcpy(payload_data,buffer+10+strlen(masks),n); //将buffer第14(2+8+4)字节之后的n字节内容赋值给payload_data
payloadLen=n; //设置payloadLen为n
} else { //如果payloadLen为0-125则payloadLen为真实数据长度
memcpy(masks,buffer+2,4); //获取掩码(payloadLen结束后跟4字节mask)
payload_data=(char *)malloc(payloadLen); //给payload_data分配内存
memset(payload_data,0,payloadLen); //将payload_data重置为0
memcpy(payload_data,buffer+2+strlen(masks),payloadLen); //将buffer第6(2+4)字节之后的n字节内容赋值给payload_data
}
for (i = 0; i < payloadLen; i++)
payload_data[i] = (char)(payload_data[i] ^ masks[i % 4]); //将数据与掩码进行异或运算,获得原始数据
printf("data(%ld):\n%s\n\n", payloadLen, payload_data);
return payload_data;
} /* ----- End of deal_data() ----- */
/*
* Function Name: construct_packet_data
* Description: 组建websocket数据包
* Input Args: @message 发送的数据
* @len 发送数据长度
* Output Args: 无
* Return Value: data 返回组建后的包
*/
static char *construct_packet_data(const char *message, unsigned long *len)
{
char *data = NULL;
unsigned long n;
if (NULL == message) { //判断message是否为空
printf("message is NULL.\n");
return NULL;
}
n = strlen(message); //获取message长度
if (n < 126) { //判断n是否小于126,小于126则payload len长度位7位
data=(char *)malloc(n+2); //给data分配内存
if (NULL == data) { //判断data是否为NULL
printf("data is NULL.\n");
return NULL;
}
memset(data,0,n+2); //重置data为0
data[0] = 0x81; //设置第0-7位为1000 0001(FIN为1,Opcode为1)
data[1] = n; //设置第8位为0,9-15位为n(第8位为mask,9-15位为数据长度,客户端发送mask为1,服务端发送mask为0)
memcpy(data+2,message,n); //将message添加到第2个字节之后
*len=n+2; //将指针指向message首地址
} else if (n < 0xFFFF) { //当n小于0xFFFF则为126,后2字节为真实长度
data=(char *)malloc(n+4); //给data分配内存
if (NULL == data) { //判断data是否为NULL
printf("data is NULL.\n");
return NULL;
}
memset(data,0,n+4); //重置data为0
data[0] = 0x81; //设置第0-7位为1000 0001(FIN为1,Opcode为1)
data[1] = 126; //设置第8-15位为0111 1110
data[2] = (n>>8 & 0xFF); //设置第16-23位为n-128(将n右移8位在与1111 1111做与运算)
data[3] = (n & 0xFF); //设置第24-31位为n的右8(0-7)位
memcpy(data+4,message,n); //将message添加到第4个字节之后
*len=n+4; //将指针指向message首地址
} else { //当n大于0xFFFF则payload len前7位为127,后8字节为真实长度
data=(char *)malloc(n+10); //给data分配内存
if (NULL == data) { //判断data是否为NULL
printf("data is NULL.\n");
return NULL;
}
memset(data,0,n+10); //重置data为0
data[0] = 0x81; //设置第0-7位为1000 0001(FIN为1,Opcode为1)
data[1] = 127; //设置第8-15位为0111 1111
data[2] = (n>>56 & 0xFF); //设置第16-23位为n-128(将n右移8位在与1111 1111做与运算)
data[3] = (n>>48 & 0xFF); //设置第24-31位为n-128(将n右移8位在与1111 1111做与运算)
data[4] = (n>>40 & 0xFF); //设置第32-39位为n-128(将n右移8位在与1111 1111做与运算)
data[5] = (n>>32 & 0xFF); //设置第40-47位为n-128(将n右移8位在与1111 1111做与运算)
data[6] = (n>>24 & 0xFF); //设置第48-55位为n-128(将n右移8位在与1111 1111做与运算)
data[7] = (n>>16 & 0xFF); //设置第56-63位为n-128(将n右移8位在与1111 1111做与运算)
data[8] = (n>>8 & 0xFF); //设置第64-71位为n-128(将n右移8位在与1111 1111做与运算)
data[9] = (n & 0xFF); //设置第72-79位为n的右8(0-7)位
memcpy(data+10,message,n); //将message添加到第10个字节之后
*len=n+10; //将指针指向message首地址
}
return data;
} /* ----- End of construct_packet_data() ----- */
/*
* Function Name: response
* Description: 响应客户端
* Input Args: @conn_fd 连接句柄
* @message 发送的数据
* Output Args: 无
* Return Value: 无
*/
void response(int conn_fd, const char *message)
{
char *data = NULL;
unsigned long n=0;
if(!conn_fd) { //判断套接字是否错误
printf("conn_fd is error.\n");
return ;
}
if(NULL == message) { //判断message是否为NULL
printf("message is NULL.\n");
return ;
}
data = construct_packet_data(message, &n); //传入message获取data(数据段)
if(NULL == data || n <= 0) //判断data是否为NULL或数据长度是否为0
{
printf("data is empty!\n");
return ;
}
write(conn_fd, data, n); //将数据写入套接字
if (NULL != data) { //如果data不为NULL则释放内存并赋值为NULL
free(data);
data = NULL;
}
} /* ----- End of response() ----- */
/*
* Function Name:main
* Description:主函数
* Input Args:prot 监听端口(不输默认为8000)
* Output Args:无
* Return Value:无
*/
int main(int argc, char *argv[])
{
int n; //客户端发送数据长度
int conn_fd; //要读取的socket文件描述符
int listen_fd; //服务端套接字
int port = DEFEULT_SERVER_PORT; //初始化默认端口为8000
char *data = NULL; //保存最终数据的指针
char buf[REQUEST_LEN_MAX]; //声明存储缓冲区大小2048
char str[INET_ADDRSTRLEN]; //存储客户端IP
char *sec_websocket_key = NULL; //保存服务端key的指针
struct sockaddr_in servaddr; //初始化sockaddr_in结构体变量
struct sockaddr_in cliaddr; //初始化sockaddr_in结构体变量
socklen_t cliaddr_len; //存储客户端套接字长度
if(argc > 1) //argc: 整数,用来统计你运行程序时送给main函数的命令行参数的个数
port = atoi(argv[1]); //argv[0]指向程序运行的全路径名,argv[n]指向在DOS命令行中执行程序名后的第n个字符串
if(port<=0 || port>0xFFFF) { //判断用户输入端口是否超出(1-65535)范围(0-1023为保留端口,不建议使用)
printf("Port(%d) is out of range(1-%d)\n", port, 0xFFFF);
return -1;
}
/*
*创建套接字(IP地址类型AF_INET为ipv4,AF_INET6为ipv6;
*数据传输方式TCP为SOCK_STREAM,UDP为SOCK_DGRAM;
*IPPROTO_TCP:IPPROTO_UDP:0:传入0系统自动选择传输协议
*/
listen_fd = socket(AF_INET, SOCK_STREAM, 0);
if(listen_fd == -1){ //正常返回0,异常-1
printf("创建套接字失败!\n");
return -1;
}
memset(&servaddr, 0, sizeof(servaddr)); //servaddr每个字节都用0填充
servaddr.sin_family = AF_INET; //使用IPv4地址
/*
*htonl将32位的主机字节顺序转化为32位的网络字节顺序,ip地址是32位的;
* htons将16位的主机字节顺序转化为32位的网络字节顺序,端口号是16位的
* inet_addr("127.0.0.1")将一个十进制的数转化为二进制的数,多用于ipv4的IP转化
* inet_ntoa(servaddr.sin_addr.s_addr)输出IP地址127.0.0.1
*/
servaddr.sin_addr.s_addr = htonl(INADDR_ANY); //INADDR_ANY,所有网卡地址
servaddr.sin_port = htons(port); //端口;
bind(listen_fd, (struct sockaddr *)&servaddr, sizeof(servaddr)); //将套接字和IP、端口绑定,正常返回0,异常-1
listen(listen_fd, 50); //监听套接字,backlog 为请求队列的最大长度
cliaddr_len = sizeof(cliaddr); //cliaddr客户端套接字长度
printf("Listen %d\nAccepting connections ...\n",port); //打印正在监听的端口
int epoll_fd=epoll_create(MAX_EVENTS); //创建一个epoll句柄
if(epoll_fd==-1) //判断句柄是否创建成功
{
perror("epoll_create failed\n");
exit(EXIT_FAILURE);
}
struct epoll_event ev; //epoll事件结构体
struct epoll_event events[MAX_EVENTS]; //事件监听队列
ev.events=EPOLLIN|EPOLLET; //表示对应的文件描述符可读(包括对端SOCKET正常关闭)
ev.data.fd=listen_fd; //将listen_fd设置为要读取的文件描述符
if(epoll_ctl(epoll_fd,EPOLL_CTL_ADD,listen_fd,&ev)==-1) //注册新的listen_fd到epoll_fd中
{
perror("epll_ctl:servaddr register failed\n");
exit(EXIT_FAILURE);
}
int nfds; //epoll监听事件发生的个数
while(1) //循环接受客户端请求
{
nfds=epoll_wait(epoll_fd,events,MAX_EVENTS,-1); //等待事件发生
if(nfds==-1)
{
perror("start epoll_wait failed\n");
continue; //跳过当次循环
}
int i;
for(i=0;i