WebSocket 是在 TCP 的基础上增加了一次 HTTP Upgrade 握手（即 TCP 连接建立后的"二次握手"），所以 WebSocket 服务端的处理和之前的 IOCP 流程其实是一样的。
1.CSingleton.h
#ifndef CSINGLETON_H
#define CSINGLETON_H
#pragma once
// Mutual-exclusion lock backed by a Win32 critical section.
// Non-copyable: a CRITICAL_SECTION must not be copied, and a copy would be deleted twice.
class CThreadLockCs
{
public:
    // Initialize the critical section object.
    CThreadLockCs() { InitializeCriticalSection(&m_cs); }
    // Release the critical section object.
    ~CThreadLockCs() { DeleteCriticalSection(&m_cs); }
    CThreadLockCs(const CThreadLockCs&) = delete;
    CThreadLockCs& operator=(const CThreadLockCs&) = delete;
    // Enter the critical section; other threads calling lock() block until unlock().
    void lock() { EnterCriticalSection(&m_cs); }
    // Leave the critical section so a waiting thread can proceed.
    void unlock() { LeaveCriticalSection(&m_cs); }
private:
    // Underlying critical section object.
    CRITICAL_SECTION m_cs;
};
/************************************************************************
singleton模式类模板
1:延迟创建类实例 2:double check 3:互斥访问 4:模板
************************************************************************/
template
class CSingleton
{
private:
static T* _instance;
CSingleton(void);
static CThreadLockCs lcs;
public:
static T* Instance(void);
static void Close(void);
};
//模板类static变量
template
T* CSingleton::_instance = NULL;
template
CThreadLockCs CSingleton::lcs;
//模板类方法实现
template
CSingleton::CSingleton(void)
{
}
template
T* CSingleton::Instance(void)
{
//double-check
//延迟创建,只有调用方访问Instance才会创建类实例
if (_instance == NULL)
{
//互斥访问锁,用CriticalSection实现
lcs.lock();
if (_instance == NULL)
{
_instance = new T;
}
lcs.unlock();
}
return _instance;
}
template
void CSingleton::Close(void)
{
if (_instance)
{
delete _instance;
}
}
#endif
2.CIOCP.h
#ifndef CIOCP_H
#define CIOCP_H
#include
#include
/******************************************************************************
Module: IOCP.h
Notices: Copyright (c) 2007 Jeffrey Richter & Christophe Nasarre
Purpose: This class wraps an I/O Completion Port.
Revise: IOCP封装类,由《windows核心编程》第10章示例程序源码改编所得
******************************************************************************/
#pragma once
// Thin wrapper around a Win32 I/O completion port handle.
// Owns the handle (closed in the destructor), so copying is disabled.
class CIOCP
{
private:
    HANDLE m_hIOCP; // completion port handle; NULL when not created or after CloseIOCP()
public:
    // nMaxConcurrency == -1 (default): do not create the port yet; otherwise CreateIOCP(nMaxConcurrency).
    CIOCP(int nMaxConcurrency = -1);
    ~CIOCP();
    CIOCP(const CIOCP&) = delete;
    CIOCP& operator=(const CIOCP&) = delete;
    // Create the port; nMaxConcurrency is the maximum number of concurrently running threads (0 = number of CPUs).
    bool CreateIOCP(int nMaxConcurrency = 0);
    // Close the port.
    bool CloseIOCP();
    // Associate a device handle with the port; CompKey is delivered with each of its completions.
    bool AsscciateDevice(HANDLE hDevice, ULONG_PTR CompKey);
    // Associate a socket with the port.
    bool AsscciateScoket(SOCKET hSocket, ULONG_PTR CompKey);
    // Post a completion packet to the port.
    bool PostStatus(ULONG_PTR CompKey, DWORD dwNumBytes = 0, OVERLAPPED* po = NULL);
    // Dequeue a completion packet; blocks (up to dwMilliseconds) while the queue is empty.
    bool GetStatus(ULONG_PTR* pCompKey, PDWORD pdwNumBytes, OVERLAPPED** ppo, DWORD dwMilliseconds = INFINITE);
    // Raw port handle.
    const HANDLE GetIOCP();
};
#endif // !CIOCP_H
//////////////////////////////// End of File /////////////////////////////////
3.CIOCP.cpp
#include "stdafx.h"
#include "CIOCP.h"
// ASSERT(T): in _DEBUG builds, check T with assert().
// In other builds the expression is still evaluated (only the check is dropped), so calls
// with side effects such as ASSERT(CloseHandle(h)) still run.
// NOTE(review): assumes NDEBUG is not also defined in _DEBUG builds, otherwise T is not evaluated there.
#ifdef _DEBUG
#define ASSERT(T) assert(T)
#else
#define ASSERT(T) (T)
#endif
// Start without a port; create one immediately unless nMaxConcurrency is -1 (deferred creation).
CIOCP::CIOCP(int nMaxConcurrency)
    : m_hIOCP(NULL)
{
    if (nMaxConcurrency == -1)
        return;
    CreateIOCP(nMaxConcurrency);
}
// Release the port if one is still open.
CIOCP::~CIOCP()
{
    if (m_hIOCP == NULL)
        return;
    ASSERT(CloseHandle(m_hIOCP));
}
//创建IOCP,nMaxConcurrency指定最大线程并发数量,0默认为cpu数量
bool CIOCP::CreateIOCP(int nMaxConcurrency )
{
//创建一个完成端口
m_hIOCP = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, nMaxConcurrency);
//效验
ASSERT(m_hIOCP != NULL);
return (m_hIOCP != NULL);
}
//关闭IOCP
bool CIOCP::CloseIOCP()
{
//关闭完成端口
bool bResult = CloseHandle(m_hIOCP);
m_hIOCP = NULL;
return(bResult);
}
// Associate a device handle with this completion port; CompKey is returned with each completion.
bool CIOCP::AsscciateDevice(HANDLE hDevice, ULONG_PTR CompKey)
{
    // On success CreateIoCompletionPort hands back the existing port handle we passed in.
    HANDLE hResult = CreateIoCompletionPort(hDevice, m_hIOCP, CompKey, 0);
    const bool fOk = (hResult == m_hIOCP);
    ASSERT(fOk);
    return fOk;
}
// Associate a socket with this completion port (a SOCKET is usable as a HANDLE here).
bool CIOCP::AsscciateScoket(SOCKET hSocket, ULONG_PTR CompKey)
{
    HANDLE hDevice = reinterpret_cast<HANDLE>(hSocket);
    return AsscciateDevice(hDevice, CompKey);
}
// Queue a completion packet by hand (e.g. a control key that wakes a worker thread).
bool CIOCP::PostStatus(ULONG_PTR CompKey, DWORD dwNumBytes , OVERLAPPED* po)
{
    const bool fOk = (PostQueuedCompletionStatus(m_hIOCP, dwNumBytes, CompKey, po) != FALSE);
    ASSERT(fOk);
    return fOk;
}
// Dequeue one completion packet, blocking up to dwMilliseconds while the queue is empty.
// Outputs: *pdwNumBytes bytes transferred, *pCompKey the key given at association time,
// *ppo the OVERLAPPED of the completed operation.
bool CIOCP::GetStatus(ULONG_PTR* pCompKey, PDWORD pdwNumBytes, OVERLAPPED** ppo, DWORD dwMilliseconds)
{
    const BOOL dequeued = GetQueuedCompletionStatus(m_hIOCP, pdwNumBytes, pCompKey, ppo, dwMilliseconds);
    return dequeued != FALSE;
}
// Raw completion port handle (NULL if not created or already closed).
const HANDLE CIOCP::GetIOCP()
{
    const HANDLE hPort = m_hIOCP;
    return hPort;
}
4.OverlappedIOInfo.h
#ifndef OVERLAPPEDIOINFO_H
#define OVERLAPPEDIOINFO_H
#pragma once
#include
#include
#define MAXBUF (1024 * 8)  // size of each per-I/O buffer; parenthesized so it expands safely inside expressions
/******************************************************************************
Module: OverlappedIOInfo.h
Notices: Copyright (c) 20161201 whg
Purpose:
IOCP网络编程模型中,需要用到GetQueuedCompletionStatus函数获取已完成事件。
但该函数的返回参数无socket或buffer的描述信息。
一个简单的解决办法,创建一个新的结构,该结构第一个参数是OVERLAPPED。
由于AcceptEx、WSASend等重叠IO操作传入的是Overlapped结构体的地址,调用AcceptEx等重叠IO操作,
在Overlapped结构体后面开辟新的空间,写入socket或buffer的信息,即可将socket或buffer的信息由
GetQueuedCompletionStatus带回。
参考《windows核心编程》和CSDN PiggyXP
******************************************************************************/
// Operation kinds, also used as IOCP completion keys.
enum IOOperType {
    TYPE_ACP = 0,      // AcceptEx completed: a new connection request arrived
    TYPE_RECV = 1,     // data received
    TYPE_SEND = 2,     // data sent
    TYPE_CLOSE = 3,    // close event
    TYPE_NO_OPER = 4   // no operation
};
// Per-I/O context. OVERLAPPED is the base subobject, so the OVERLAPPED* handed back by
// GetQueuedCompletionStatus can be cast to COverlappedIOInfo* to recover the socket,
// buffers and peer address. The object owns m_sSocket (closed in the destructor) and its
// WSABUFs point into its own arrays, so copying is disabled.
class COverlappedIOInfo:public OVERLAPPED
{
public:
    SOCKET m_sSocket;          // socket this context operates on
    WSABUF m_recvBuf;          // receive descriptor for AcceptEx / WSARecv, points at m_cRecvBuf
    char m_cRecvBuf[MAXBUF];
    WSABUF m_sendBuf;          // send descriptor for WSASend, points at m_cSendBuf
    char m_cSendBuf[MAXBUF];
    sockaddr_in m_addr;        // peer address
public:
    COverlappedIOInfo();
    ~COverlappedIOInfo();
    COverlappedIOInfo(const COverlappedIOInfo&) = delete;
    COverlappedIOInfo& operator=(const COverlappedIOInfo&) = delete;
    // Reset the OVERLAPPED part before reuse.
    void ResetOverlapped();
    // Clear the receive buffer and re-point m_recvBuf at it.
    void ResetRecvBuffer();
    // Clear the send buffer and re-point m_sendBuf at it.
    void ResetSendBuffer();
};
#endif // !OVERLAPPEDIOINFO_H
5.OverlappedIOInfo.cpp
#include "stdafx.h"
#include "OverlappedIOInfo.h"
// Start with no socket, a cleared OVERLAPPED, empty buffers and a zeroed peer address.
COverlappedIOInfo::COverlappedIOInfo()
{
    m_sSocket = INVALID_SOCKET;
    // m_addr was previously left uninitialized until DoAccept filled it in.
    ZeroMemory(&m_addr, sizeof(m_addr));
    ResetOverlapped();
    ResetRecvBuffer();
    ResetSendBuffer();
}
// Close the owned socket, if any.
COverlappedIOInfo::~COverlappedIOInfo()
{
    if (m_sSocket == INVALID_SOCKET)
        return;
    closesocket(m_sSocket);
    m_sSocket = INVALID_SOCKET;
}
void COverlappedIOInfo::ResetOverlapped()
{
Internal = InternalHigh = 0;
Offset = OffsetHigh = 0;
hEvent = NULL;
}
// Point m_recvBuf at the full receive array and zero it.
void COverlappedIOInfo::ResetRecvBuffer()
{
    m_recvBuf.len = MAXBUF;
    m_recvBuf.buf = m_cRecvBuf;
    ZeroMemory(m_recvBuf.buf, m_recvBuf.len);
}
// Point m_sendBuf at the full send array and zero it.
void COverlappedIOInfo::ResetSendBuffer()
{
    m_sendBuf.len = MAXBUF;
    m_sendBuf.buf = m_cSendBuf;
    ZeroMemory(m_sendBuf.buf, m_sendBuf.len);
}
6.TaskService.h
#ifndef WHG_CTASKSVC
#define WHG_CTASKSVC
#include
#include
class CTaskService
{
public:
//Activate用于激活一定数量的工作者线程,默认激活数量为1。返回当前线程队列大小
UINT Activate(int num = 1);
//获取线程队列大小
UINT GetThreadsNum(void);
protected:
//只有子类才可以构造父类,拒绝外部访问构造类实例
CTaskService(void);
~CTaskService(void);
//子类应重定义工作线程细节
virtual void svc();
//Close用于等待线程结束并关闭线程,退出线程由子类控制
void Close();
private:
//工作者线程访问接口
static UINT WorkThread(LPVOID param);
//线程队列
std::vector vec_threads;
};
#endif
7.TaskService.cpp
#include "stdafx.h"
#include "TaskService.h"
// Nothing to set up: threads are created on demand by Activate().
CTaskService::CTaskService(void) = default;
// Wait for and release any remaining worker threads.
CTaskService::~CTaskService(void)
{
    this->Close();
}
// Start `num` worker threads running WorkThread(this); returns the pool size afterwards.
UINT CTaskService::Activate(int num)
{
    for (int started = 0; started < num; ++started)
    {
        // Created suspended so auto-delete can be switched off before the thread runs;
        // Close() then owns and deletes the CWinThread object.
        CWinThread* pThread = AfxBeginThread(WorkThread, this, THREAD_PRIORITY_NORMAL, 0, CREATE_SUSPENDED);
        if (pThread == NULL)
            continue;
        pThread->m_bAutoDelete = false;
        pThread->ResumeThread();
        vec_threads.push_back(pThread);
    }
    return GetThreadsNum();
}
// Number of threads currently held by the pool.
UINT CTaskService::GetThreadsNum(void)
{
    return static_cast<UINT>(vec_threads.size());
}
// Thread entry point: param is the owning CTaskService; run its (virtual) svc().
UINT CTaskService::WorkThread(LPVOID param)
{
    CTaskService* const self = static_cast<CTaskService*>(param);
    if (self != NULL)
        self->svc();
    return 0;
}
// Default thread body: returns immediately. Derived classes override it.
void CTaskService::svc()
{
    // intentionally empty
}
void CTaskService::Close()
{
int cnt = vec_threads.size();
if (cnt > 0)
{
std::vector::iterator iter = vec_threads.begin();
for (; iter != vec_threads.end(); iter++)
{
CWinThread* pwt = *iter;
WaitForSingleObject(pwt->m_hThread, INFINITE);
delete pwt;
}
vec_threads.clear();
}
}
8.WebSocket.h
#pragma once
// Buffer sizes (bytes) used by the WebSocket helper.
#define RESPONSELEN 512       // handshake response
#define ACCEPTKEYLEN 512      // connection (accept) key
#define PACKDATALEN 1024      // outgoing packed frame
#define ACCEPTDATALEN 1024    // connection setup data
#define UNPACKDATA 1024       // incoming unpacked data

// State of the SHA-1 computation used by the handshake.
typedef struct SHA1Context {
    unsigned Message_Digest[5];       // running hash words H0..H4
    unsigned Length_Low;              // message length in bits, low 32 bits
    unsigned Length_High;             // message length in bits, high 32 bits
    unsigned char Message_Block[64];  // current 512-bit input block
    int Message_Block_Index;          // next free position in Message_Block
    int Computed;                     // nonzero once the digest is finalized
    int Corrupted;                    // nonzero after an input error
} SHA1Context;

// WebSocket server-side helper: opening handshake, framing and unframing of messages.
class WebSocket
{
private:
    char m_ResponseHeader[RESPONSELEN];  // handshake response
    char m_AcceptKey[ACCEPTKEYLEN];      // connection key
    char m_PackData[PACKDATALEN];        // outgoing frame
    char m_AcceptData[ACCEPTDATALEN];    // connection setup data
    char m_UnpackData[UNPACKDATA];       // incoming frame
public:
    WebSocket();
    ~WebSocket();

    // Perform the handshake for a new connection.
    bool WebAccept(int sock, char * buf, int len);
    // Send a message as a frame.
    int WebSend(int sock, const char* buf, int bufLen);
    // Decode a received frame in place.
    int WebRecv(char* buf, int bufLen);
private:
    // Handshake helpers.
    bool GetAcceptKey(int sock,char * buf, int len);
    void shakeHand(int connfd, char *serverKey);
    // Frame a payload.
    char* packData(const char * message, unsigned long * len, unsigned long n);

    // Text helpers: ASCII lower-casing and hex-string parsing.
    int tolower(int c);
    int htoi(const char s[], int start, int len);

    // Base64 encoding.
    char *base64_encode(const char* data, int data_len);

    // SHA-1: reset, finalize, feed input, process a block, pad, and the string entry point.
    void SHA1Reset(SHA1Context *);
    int SHA1Result(SHA1Context *);
    void SHA1Input(SHA1Context *, const char *, unsigned int);
    void SHA1ProcessMessageBlock(SHA1Context *);
    void SHA1PadMessage(SHA1Context *);
    char * sha1_hash(const char *source);
};
9.WebSocket.cpp
#include "stdafx.h"
#include "WebSocket.h"
#include
#include
#include
// Rotate the 32-bit word `word` left by `bits` (word is expected to be an unsigned 32-bit value).
#define SHA1CircularShift(bits,word) ((((word) >> (32-(bits)))) | (((word) << (bits)) & 0xFFFFFFFF))
// Base64 alphabet; index 64 ('=') is the padding character.
static const char base[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
// No setup or teardown: the member buffers are fixed-size arrays left as-is here.
WebSocket::WebSocket() = default;
WebSocket::~WebSocket() = default;
// Entry point for establishing a connection: all work is delegated to GetAcceptKey(),
// whose result is returned unchanged.
bool WebSocket::WebAccept(int sock,char * buf,int len)
{
    return GetAcceptKey(sock, buf, len);
}
// Handshake: locate the "Sec-WebSocket-Key: " header in the request in buf and build the
// accept key (GUID below is the fixed value RFC 6455 appends to the client key).
// NOTE(review): this copy of the file is damaged from the `for` line below onward - text
// between a '<' and a later '>' was stripped, which removed the rest of GetAcceptKey,
// shakeHand and the beginning of WebRecv. The lines after that look like the tail of
// WebRecv (payload-length decoding and unmasking). Restore from the original source.
bool WebSocket::GetAcceptKey( int sock,char * buf,int len)
{
// NOTE(review): binding a string literal to non-const char* is ill-formed since C++11; should be const char*.
char *flag = "Sec-WebSocket-Key: ";
const char * GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
if (!buf||!len)
return false;
memset(m_AcceptKey, 0, sizeof(m_AcceptKey));
// NOTE(review): strstr returns NULL when the header is absent; keyBegin is used without a check.
char * keyBegin = strstr((char *)buf, flag);
keyBegin += strlen(flag);
int bufLen = strlen(buf);
// NOTE(review): corrupted line - see the note above the function.
for (int i = 0; i payloadLen ? payloadLen : bufLen;
// (presumably the 16-bit length branch) payload starts at byte 8: 2 header + 2 length + 4 mask.
memset(buf, 0, payloadLen);
memcpy(buf, m_UnpackData + 8, payloadLen);
}
else if (payloadLen == 127)
{
// 64-bit extended length: bytes 2..9 (big-endian) are reversed into temp; mask at bytes 10..13.
char temp[8] = {0};
memcpy(masks, m_UnpackData + 10, 4);
for (int i = 0; i < 8; i++)
{
temp[i] = m_UnpackData[9 - i];
}
// NOTE(review): unsigned long is 32 bits on Windows, so copying 8 bytes into n overflows it;
// it should be a 64-bit type (unsigned long long / uint64_t).
unsigned long n = 0;
memcpy(&n, temp, 8);
payloadLen = bufLen > n ? n : bufLen;
memset(buf, 0, payloadLen);
memcpy(buf, m_UnpackData + 14, payloadLen);// NOTE(review): crashes (core dump) if the data is too long - the copy size is not checked against the bytes actually received.
}
else
{
// 7-bit payload length: mask at bytes 2..5, payload from byte 6.
memcpy(masks, m_UnpackData + 2, 4);
payloadLen = bufLen > payloadLen ? payloadLen : bufLen;
memset(buf, 0, payloadLen);
memcpy(buf, m_UnpackData + 6, payloadLen);
}
// Unmask the payload in place: byte i is XORed with masks[i % 4].
for (int i = 0; i < payloadLen; i++)
{
buf[i] = (char)(buf[i] ^ masks[i % 4]);
}
// NOTE(review): strlen stops at the first 0 byte, so binary payloads are reported short.
return strlen(buf);
}
// Build one final, unmasked binary frame (FIN + opcode 0x2) around `message` in m_PackData.
// n is the payload length; *len receives the total frame length, or 0 when the payload does
// not fit in m_PackData. Payloads of 126 bytes or more use the 16-bit extended length form.
// Returns m_PackData.
char* WebSocket::packData(const char * message, unsigned long * len, unsigned long n)
{
    memset(m_PackData, 0, sizeof(m_PackData));
    const unsigned char kFinBinary = 0x82;  // FIN bit + binary opcode
    if (n < 126)
    {
        m_PackData[0] = (char)kFinBinary;
        m_PackData[1] = (char)n;
        memcpy(m_PackData + 2, message, n);
        *len = n + 2;
    }
    else if (n <= PACKDATALEN - 4)
    {
        // The 4-byte header must fit in m_PackData together with the payload; the previous
        // bound (n < PACKDATALEN) let header + payload overflow the buffer by up to 3 bytes.
        m_PackData[0] = (char)kFinBinary;
        m_PackData[1] = 126;
        m_PackData[2] = (char)((n >> 8) & 0xFF);
        m_PackData[3] = (char)(n & 0xFF);
        memcpy(m_PackData + 4, message, n);
        *len = n + 4;
    }
    else
    {
        // Too long for the fixed buffer (the 64-bit length form is not implemented).
        *len = 0;
    }
    return m_PackData;
}
//发送消息
int WebSocket::WebSend(int sock, const char* buf, int bufLen)
{
if (!sock)
return 0;
unsigned long n = 0;
char * data = packData(buf, &n, bufLen);
if (!data || n <= 0)
return 0;
return send(sock, data, n,0);
}
// ASCII-only lower-casing: maps 'A'..'Z' to 'a'..'z' and returns any other value unchanged.
int WebSocket::tolower(int c)
{
    const bool isUpper = (c >= 'A') && (c <= 'Z');
    return isUpper ? c - 'A' + 'a' : c;
}
// Parse up to `len` hex digits from s, beginning `start` characters in. The offset is applied
// after an optional "0x"/"0X" prefix, which is only recognized at the very start of s.
// Parsing stops at the first non-hex character; returns the accumulated value.
int WebSocket::htoi(const char s[], int start, int len)
{
    const bool hasPrefix = s[0] == '0' && (s[1] == 'x' || s[1] == 'X');
    int pos = (hasPrefix ? 2 : 0) + start;
    int value = 0;
    for (int consumed = 0; consumed < len; ++consumed, ++pos)
    {
        const int ch = tolower(s[pos]);
        int digit;
        if (ch >= '0' && ch <= '9')
            digit = ch - '0';
        else if (ch >= 'a' && ch <= 'f')
            digit = 10 + ch - 'a';
        else
            break;
        value = 16 * value + digit;
    }
    return value;
}
// Base64-encode data[0..data_len) into a malloc'ed, NUL-terminated string (with '=' padding).
// The caller owns the result. Prints a message and exits the process if allocation fails.
char *WebSocket::base64_encode(const char* data, int data_len)
{
    // Every started group of 3 input bytes becomes 4 output characters, plus the terminator.
    const int groups = data_len / 3 + (data_len % 3 > 0 ? 1 : 0);
    const int outSize = groups * 4 + 1;
    char *out = (char *)malloc(outSize);
    if (out == NULL)
    {
        printf("No enough memory.\n");
        exit(0);
    }
    memset(out, 0, outSize);
    char *cursor = out;
    for (int pos = 0; pos < data_len; )
    {
        // Pack up to 3 bytes into a 24-bit value, left-aligned when fewer are available.
        int taken = 0;
        int triple = 0;
        for (; taken < 3 && pos < data_len; ++taken, ++pos)
            triple = (triple << 8) | (data[pos] & 0xFF);
        triple <<= (3 - taken) * 8;
        // Emit four 6-bit symbols; positions beyond the available input become '=' (index 64).
        for (int k = 0; k < 4; ++k)
        {
            const int index = (k > taken) ? 0x40 : ((triple >> ((3 - k) * 6)) & 0x3F);
            *cursor++ = base[index];
        }
    }
    *cursor = '\0';
    return out;
}
// Prepare a context for a new digest: clear lengths and flags, then load the SHA-1 initial hash values.
void WebSocket::SHA1Reset(SHA1Context * context)
{
    static const unsigned kInitialHash[5] = { 0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0 };
    for (int k = 0; k < 5; ++k)
        context->Message_Digest[k] = kInitialHash[k];
    context->Length_Low = context->Length_High = 0;
    context->Message_Block_Index = 0;
    context->Computed = context->Corrupted = 0;
}
// Finalize the digest once (padding + length block). Returns 0 if the context is corrupted, else 1.
int WebSocket::SHA1Result(SHA1Context * context)
{
    if (context->Corrupted)
        return 0;
    if (context->Computed)
        return 1;
    SHA1PadMessage(context);
    context->Computed = 1;
    return 1;
}
// Append `length` bytes to the message being hashed, processing each full 64-byte block.
void WebSocket::SHA1Input(SHA1Context * context, const char *message_array, unsigned int length)
{
    if (length == 0)
        return;
    // Feeding data after finalization (or after an error) poisons the context.
    if (context->Computed || context->Corrupted)
    {
        context->Corrupted = 1;
        return;
    }
    for (const char *p = message_array, *end = message_array + length; p != end && !context->Corrupted; ++p)
    {
        context->Message_Block[context->Message_Block_Index++] = (unsigned char)(*p & 0xFF);
        // Message length in bits, kept as a 64-bit count split over two 32-bit words.
        context->Length_Low = (context->Length_Low + 8) & 0xFFFFFFFF;
        if (context->Length_Low == 0)
        {
            context->Length_High = (context->Length_High + 1) & 0xFFFFFFFF;
            if (context->Length_High == 0)
                context->Corrupted = 1;   // bit count wrapped around
        }
        if (context->Message_Block_Index == 64)
            SHA1ProcessMessageBlock(context);
    }
}
// Run the SHA-1 compression function over the 64-byte Message_Block, add the result into
// Message_Digest, and reset Message_Block_Index to 0.
void WebSocket::SHA1ProcessMessageBlock(SHA1Context * context)
{
// Round constants for rounds 0-19, 20-39, 40-59 and 60-79.
const unsigned K[] = { 0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6 };
int t;
unsigned temp;
unsigned W[80];
unsigned A, B, C, D, E;
// Message schedule: the first 16 words are the block read big-endian...
for (t = 0; t < 16; t++) {
W[t] = ((unsigned)context->Message_Block[t * 4]) << 24;
W[t] |= ((unsigned)context->Message_Block[t * 4 + 1]) << 16;
W[t] |= ((unsigned)context->Message_Block[t * 4 + 2]) << 8;
W[t] |= ((unsigned)context->Message_Block[t * 4 + 3]);
}
// ...and the remaining 64 words are derived by XOR and a 1-bit left rotation.
for (t = 16; t < 80; t++) W[t] = SHA1CircularShift(1, W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16]);
A = context->Message_Digest[0];
B = context->Message_Digest[1];
C = context->Message_Digest[2];
D = context->Message_Digest[3];
E = context->Message_Digest[4];
// Rounds 0-19: choose function (B & C) | (~B & D).
for (t = 0; t < 20; t++) {
temp = SHA1CircularShift(5, A) + ((B & C) | ((~B) & D)) + E + W[t] + K[0];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = SHA1CircularShift(30, B);
B = A;
A = temp;
}
// Rounds 20-39: parity function B ^ C ^ D.
for (t = 20; t < 40; t++) {
temp = SHA1CircularShift(5, A) + (B ^ C ^ D) + E + W[t] + K[1];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = SHA1CircularShift(30, B);
B = A;
A = temp;
}
// Rounds 40-59: majority function.
for (t = 40; t < 60; t++) {
temp = SHA1CircularShift(5, A) + ((B & C) | (B & D) | (C & D)) + E + W[t] + K[2];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = SHA1CircularShift(30, B);
B = A;
A = temp;
}
// Rounds 60-79: parity function again.
for (t = 60; t < 80; t++) {
temp = SHA1CircularShift(5, A) + (B ^ C ^ D) + E + W[t] + K[3];
temp &= 0xFFFFFFFF;
E = D;
D = C;
C = SHA1CircularShift(30, B);
B = A;
A = temp;
}
// Add this block's result into the running digest (mod 2^32).
context->Message_Digest[0] = (context->Message_Digest[0] + A) & 0xFFFFFFFF;
context->Message_Digest[1] = (context->Message_Digest[1] + B) & 0xFFFFFFFF;
context->Message_Digest[2] = (context->Message_Digest[2] + C) & 0xFFFFFFFF;
context->Message_Digest[3] = (context->Message_Digest[3] + D) & 0xFFFFFFFF;
context->Message_Digest[4] = (context->Message_Digest[4] + E) & 0xFFFFFFFF;
context->Message_Block_Index = 0;
}
// Apply SHA-1 padding: a 0x80 byte, zeros up to byte 56, then the 64-bit big-endian bit
// length; processes one extra block when the 0x80 marker leaves no room for the length.
void WebSocket::SHA1PadMessage(SHA1Context * context)
{
    const bool needsExtraBlock = context->Message_Block_Index > 55;
    context->Message_Block[context->Message_Block_Index++] = 0x80;
    if (needsExtraBlock)
    {
        while (context->Message_Block_Index < 64)
            context->Message_Block[context->Message_Block_Index++] = 0;
        SHA1ProcessMessageBlock(context);   // resets Message_Block_Index to 0
    }
    while (context->Message_Block_Index < 56)
        context->Message_Block[context->Message_Block_Index++] = 0;
    // Bytes 56..59 = high word, 60..63 = low word, most significant byte first.
    for (int i = 0; i < 4; ++i)
    {
        context->Message_Block[56 + i] = (unsigned char)((context->Length_High >> (24 - 8 * i)) & 0xFF);
        context->Message_Block[60 + i] = (unsigned char)((context->Length_Low >> (24 - 8 * i)) & 0xFF);
    }
    SHA1ProcessMessageBlock(context);
}
// Compute the SHA-1 digest of the NUL-terminated string `source` and return it as a
// malloc'ed, NUL-terminated string of 40 upper-case hex digits (caller frees it).
// Returns NULL if the digest cannot be computed or memory cannot be allocated.
char * WebSocket::sha1_hash(const char *source)
{
    const size_t kHexBufSize = 128;   // 40 hex digits + NUL fit comfortably
    SHA1Context sha;
    SHA1Reset(&sha);
    SHA1Input(&sha, source, (unsigned int)strlen(source));
    if (!SHA1Result(&sha))
    {
        printf("SHA1 ERROR: Could not compute message digest");
        return NULL;
    }
    char *buf = (char *)malloc(kHexBufSize);
    if (buf == NULL)
        return NULL;
    // Clear the whole allocation (the old memset used sizeof(buf), i.e. only the pointer's size).
    memset(buf, 0, kHexBufSize);
    sprintf(buf, "%08X%08X%08X%08X%08X", sha.Message_Digest[0], sha.Message_Digest[1],
        sha.Message_Digest[2], sha.Message_Digest[3], sha.Message_Digest[4]);
    return buf;
}
10.Server.h
#ifndef SERVER_H
#define SERVER_H
#pragma once
#include "TaskService.h"
#include "OverlappedIOInfo.h"
#include "CSingleton.h"
#include "CIOCP.h"
#include "WebSocket.h"
class CServer :public CTaskService
{
#define ACCEPT_SOCKET_NUM 10
private:
WSAData m_wsaData; //winsock版本类型
SOCKET m_sListen; //端口监听套接字
std::vector m_vecAcps; //等待accept的套接字
WebSocket m_WebSocket; //网页长连接
//已建立连接的信息,每个结构含有一个套接字、发送缓冲和接收缓冲,以及对端地址
std::vector m_vecContInfo;
//操作vector的互斥访问锁
CThreadLockCs m_lsc;
//IOCP封装类
CIOCP m_iocp;
//AcceptEx函数指针
LPFN_ACCEPTEX m_lpfnAcceptEx;
//GetAcceptSockAddrs函数指针
LPFN_GETACCEPTEXSOCKADDRS m_lpfnGetAcceptSockAddrs;
public:
CServer(void);
~CServer(void);
bool StartListen(unsigned short port, std::string ip);
protected:
virtual void svc();
private:
//启动CPU*2个线程,返回已启动线程个数
UINT StartThreadPull();
//获取AcceptEx和GetAcceptExSockaddrs函数指针
bool GetLPFNAcceptEXAndGetAcceptSockAddrs();
//利用AcceptEx监听accept请求
bool PostAccept(COverlappedIOInfo* ol);
//处理accept请求,NumberOfBytes=0表示没有收到第一帧数据,>0表示收到第一帧数据
bool DoAccept(COverlappedIOInfo* ol, DWORD NumberOfBytes = 0);
//投递recv请求
bool PostRecv(COverlappedIOInfo* ol);
//处理recv请求
bool DoRecv(COverlappedIOInfo* ol);
//从已连接socket列表中移除socket及释放空间
bool DeleteLink(SOCKET s);
//释放3个部分步骤:
//1:清空IOCP线程队列,退出线程
//2: 清空等待accept的套接字m_vecAcps
//3: 清空已连接的套接字m_vecContInfo并清空缓存
void CloseServer();
};
typedef CSingleton<CServer> SERVER;  // process-wide server instance: SERVER::Instance()
#endif
11.Server.cpp
#include "stdafx.h"
#include "Server.h"
#include <algorithm>
// Initialize Winsock 2.2; the server stays idle until StartListen().
CServer::CServer()
{
    // Defined value even if StartListen() is never called.
    m_sListen = INVALID_SOCKET;
    m_lpfnAcceptEx = NULL;
    m_lpfnGetAcceptSockAddrs = NULL;
    const int wsaErr = WSAStartup(MAKEWORD(2, 2), &m_wsaData);
    if (wsaErr != 0)
    {
        // m_wsaData is not valid when WSAStartup fails.
        cout << "WSAStartup error,error code " << wsaErr << endl;
        return;
    }
    printf("%d\n", m_wsaData.iMaxSockets);
}
// Stop the workers and release all sockets, then balance the constructor's WSAStartup().
CServer::~CServer()
{
    this->CloseServer();
    ::WSACleanup();
}
bool CServer::StartListen(unsigned short port, std::string ip)
{
//listen socket需要将accept操作投递到完成端口,因此,listen socket属性必须有重叠IO
m_sListen = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
if (m_sListen == INVALID_SOCKET)
{
cout << "WSASocket create socket error" << endl;
return false;
}
//创建并设置IOCP并发线程数量
if (m_iocp.CreateIOCP() == FALSE)
{
cout << "IOCP create error,error code " << WSAGetLastError() << endl;
return false;
}
//将listen socket绑定至iocp
if (!m_iocp.AsscciateScoket(m_sListen, TYPE_ACP))
{
cout << "iocp Associate listen Socket error" << endl;
return false;
}
sockaddr_in service;
service.sin_family = AF_INET;
service.sin_port = htons(port);
if (ip.empty())
{
service.sin_addr.s_addr = INADDR_ANY;
}
else
{
service.sin_addr.s_addr = inet_addr(ip.c_str());
}
if (bind(m_sListen, (sockaddr*)&service, sizeof(service)) == SOCKET_ERROR)
{
cout << "bind() error,error code " << WSAGetLastError() << endl;
return false;
}
cout << "bind ok!" << endl;
if (listen(m_sListen, SOMAXCONN) == SOCKET_ERROR)
{
cout << "listen() error,error code " << WSAGetLastError() << endl;
return false;
}
cout << "listen ok!" << endl;
//启动工作者线程
int threadnum = StartThreadPull();
cout << "启动工作者线程,num=" << threadnum << endl;
//获取AcceptEx和GetAcceptSockAddrs函数指针
if (!GetLPFNAcceptEXAndGetAcceptSockAddrs())
{
return false;
}
//创建10个acceptex
for (int i = 0; i < ACCEPT_SOCKET_NUM; i++)
{
//用accept
COverlappedIOInfo* ol = new COverlappedIOInfo;
if (!PostAccept(ol))
{
delete ol;
return false;
}
}
}
void CServer::svc()
{
while (true)
{
DWORD NumberOfBytes = 0;
unsigned long CompletionKey = 0;
OVERLAPPED* ol = NULL;
if (FALSE != GetQueuedCompletionStatus(m_iocp.GetIOCP(), &NumberOfBytes, &CompletionKey, &ol, WSA_INFINITE))
{
COverlappedIOInfo* olinfo = (COverlappedIOInfo*)ol;
if (CompletionKey == TYPE_CLOSE)
{
break;
}
if (NumberOfBytes == 0 && (CompletionKey == TYPE_RECV || CompletionKey == TYPE_SEND))
{
//客户端断开连接
cout << "客户端断开连接,ip=" << inet_ntoa(olinfo->m_addr.sin_addr) << ",port=" << olinfo->m_addr.sin_port << endl;
DeleteLink(olinfo->m_sSocket);
continue;
}
switch (CompletionKey)
{
case TYPE_ACP:
{
DoAccept(olinfo, NumberOfBytes);
PostAccept(olinfo);
}
break;
case TYPE_RECV:
{
DoRecv(olinfo);
PostRecv(olinfo);
}
break;
case TYPE_SEND:
{
}
break;
default:
break;
}
}
else
{
int res = WSAGetLastError();
switch (res)
{
case ERROR_NETNAME_DELETED:
{
COverlappedIOInfo* olinfo = (COverlappedIOInfo*)ol;
if (olinfo)
{
cout << "客户端异常退出,ip=" << inet_ntoa(olinfo->m_addr.sin_addr) << ",port=" << olinfo->m_addr.sin_port << endl;
DeleteLink(olinfo->m_sSocket);
}
}
break;
default:
cout << "workthread GetQueuedCompletionStatus error,error code " << WSAGetLastError() << endl;
break;
}
continue;
}
}
cout << "workthread stop" << endl;
}
// Start two worker threads per processor reported by GetSystemInfo; returns the pool size.
UINT CServer::StartThreadPull()
{
    SYSTEM_INFO sysInfo;
    GetSystemInfo(&sysInfo);
    const DWORD workerCount = sysInfo.dwNumberOfProcessors * 2;
    return Activate(workerCount);
}
//获取AcceptEx和GetAcceptExSockaddrs函数指针
bool CServer::GetLPFNAcceptEXAndGetAcceptSockAddrs()
{
DWORD BytesReturned = 0;
//获取AcceptEx函数指针
GUID GuidAcceptEx = WSAID_ACCEPTEX;
if (SOCKET_ERROR == WSAIoctl(
m_sListen,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&GuidAcceptEx,
sizeof(GuidAcceptEx),
&m_lpfnAcceptEx,
sizeof(m_lpfnAcceptEx),
&BytesReturned,
NULL, NULL))
{
cout << "WSAIoctl get AcceptEx function error,error code " << WSAGetLastError() << endl;
return false;
}
//获取GetAcceptexSockAddrs函数指针
GUID GuidGetAcceptexSockAddrs = WSAID_GETACCEPTEXSOCKADDRS;
if (SOCKET_ERROR == WSAIoctl(
m_sListen,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&GuidGetAcceptexSockAddrs,
sizeof(GuidGetAcceptexSockAddrs),
&m_lpfnGetAcceptSockAddrs,
sizeof(m_lpfnGetAcceptSockAddrs),
&BytesReturned,
NULL, NULL))
{
cout << "WSAIoctl get GetAcceptexSockAddrs function error,error code " << WSAGetLastError() << endl;
return false;
}
return true;
}
// Post an AcceptEx on the listen socket using context `ol`: reset ol and create a fresh
// overlapped socket for the next incoming connection.
// The socket ol held before is INVALID_SOCKET on the first post; otherwise it was handed to a
// connection context (or closed) by DoAccept. It is not closed here, only dropped from m_vecAcps.
// Returns false if the socket cannot be created or AcceptEx fails immediately.
bool CServer::PostAccept(COverlappedIOInfo* ol)
{
    if (m_lpfnAcceptEx == NULL)
    {
        cout << "m_lpfnAcceptEx is NULL" << endl;
        return false;
    }
    const SOCKET previous = ol->m_sSocket;
    ol->ResetRecvBuffer();
    ol->ResetOverlapped();
    ol->ResetSendBuffer();
    const SOCKET acceptSocket = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
    const int createErr = (acceptSocket == INVALID_SOCKET) ? WSAGetLastError() : 0;
    ol->m_sSocket = acceptSocket;

    // Update m_vecAcps BEFORE posting AcceptEx: once posted, the completion may be handled
    // (and ol->m_sSocket replaced again) by another worker thread at any moment.
    // The old loop never stopped at a match, so it always appended and left stale sockets behind.
    m_lsc.lock();
    auto prevIt = std::find(m_vecAcps.begin(), m_vecAcps.end(), previous);
    if (prevIt != m_vecAcps.end())
        m_vecAcps.erase(prevIt);
    if (acceptSocket != INVALID_SOCKET)
        m_vecAcps.push_back(acceptSocket);
    m_lsc.unlock();

    if (acceptSocket == INVALID_SOCKET)
    {
        cout << "WSASocket error ,error code " << createErr << endl;
        return false;
    }
    // The accept socket is associated with the IOCP only after the connection is accepted
    // (DoAccept); this AcceptEx completes through the listen socket's association (key TYPE_ACP).
    // Each address slot must be at least 16 bytes larger than the address structure.
    const DWORD addrLen = sizeof(SOCKADDR_IN) + 16;
    DWORD byteReceived = 0;
    if (FALSE == m_lpfnAcceptEx(
        m_sListen,
        acceptSocket,
        ol->m_recvBuf.buf,
        ol->m_recvBuf.len - addrLen * 2,
        addrLen,
        addrLen,
        &byteReceived,
        ol))
    {
        const DWORD res = WSAGetLastError();
        if (ERROR_IO_PENDING != res)
        {
            cout << "AcceptEx error , error code " << res << endl;
            // No accept is pending on this socket; stop tracking it (ol still owns and closes it).
            m_lsc.lock();
            auto it = std::find(m_vecAcps.begin(), m_vecAcps.end(), acceptSocket);
            if (it != m_vecAcps.end())
                m_vecAcps.erase(it);
            m_lsc.unlock();
            return false;
        }
    }
    return true;
}
// Handle a completed AcceptEx for context `ol`:
// - Determine the peer address.
// - When request bytes arrived with the accept (NumberOfBytes > 0), run the WebSocket handshake on them.
// - Associate the socket with the IOCP, register a connection context in m_vecContInfo and post its first WSARecv.
// Returns false if the handshake or the IOCP association fails.
bool CServer::DoAccept(COverlappedIOInfo* ol, DWORD NumberOfBytes )
{
    // Give the accepted socket the listen socket's context; without this, getpeername and
    // similar calls fail on a socket accepted with AcceptEx.
    setsockopt(ol->m_sSocket, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, (char*)&m_sListen, sizeof(m_sListen));

    // Local storage for the peer address. The old code dereferenced a NULL SOCKADDR_IN*
    // whenever the address was not produced by GetAcceptExSockaddrs.
    SOCKADDR_IN ClientAddr;
    ZeroMemory(&ClientAddr, sizeof(ClientAddr));
    if (NumberOfBytes > 0)
    {
        // AcceptEx buffer layout: received data, then local address, then remote address.
        if (m_lpfnGetAcceptSockAddrs)
        {
            const DWORD addrLen = sizeof(SOCKADDR_IN) + 16;
            SOCKADDR_IN* LocalAddr = NULL;
            SOCKADDR_IN* RemoteAddr = NULL;
            int localLen = sizeof(SOCKADDR_IN);
            int remoteLen = sizeof(SOCKADDR_IN);
            m_lpfnGetAcceptSockAddrs(
                ol->m_recvBuf.buf,
                ol->m_recvBuf.len - addrLen * 2,
                addrLen,
                addrLen,
                (LPSOCKADDR*)&LocalAddr,
                &localLen,
                (LPSOCKADDR*)&RemoteAddr,
                &remoteLen);
            if (RemoteAddr != NULL)
                ClientAddr = *RemoteAddr;
            cout << "收到新的连接请求,ip=" << inet_ntoa(ClientAddr.sin_addr) << ",port=" << ntohs(ClientAddr.sin_port) <<
                "数据为:" << ol->m_recvBuf.buf << endl;
            // Per-call helper: WebSocket keeps its work buffers as members, so a shared
            // instance would be raced on by concurrently running worker threads.
            WebSocket ws;
            // Pass the number of bytes actually received, not the buffer capacity.
            if (!ws.WebAccept(ol->m_sSocket, ol->m_recvBuf.buf, NumberOfBytes))
            {
                cout <<"websockt连接失败"<< endl;
                // Do not leak the connection. The value stays in ol->m_sSocket so PostAccept
                // can drop it from m_vecAcps when it reuses ol.
                closesocket(ol->m_sSocket);
                return false;
            }
        }
    }
    else
    {
        // No request data yet: ask the socket for the peer address.
        int remoteLen = sizeof(ClientAddr);
        if (SOCKET_ERROR == getpeername(ol->m_sSocket, (sockaddr*)&ClientAddr, &remoteLen))
        {
            cout << "getpeername error,error code " << WSAGetLastError() << endl;
        }
        else
        {
            cout << "收到新的连接请求,ip=" << inet_ntoa(ClientAddr.sin_addr) << ",port=" << ntohs(ClientAddr.sin_port) << endl;
        }
    }
    // pol takes over the accepted socket; ol keeps the value only until PostAccept replaces it.
    COverlappedIOInfo* pol = new COverlappedIOInfo;
    pol->m_sSocket = ol->m_sSocket;
    pol->m_addr = ClientAddr;
    // Connection sockets use key TYPE_RECV; the server only posts receives on them.
    if (!m_iocp.AsscciateScoket(pol->m_sSocket, TYPE_RECV))
    {
        delete pol;   // closes the socket
        return false;
    }
    // Register before posting the receive: its completion (even a disconnect that calls
    // DeleteLink) may be handled by another worker before this function returns.
    m_lsc.lock();
    m_vecContInfo.push_back(pol);
    m_lsc.unlock();
    PostRecv(pol);
    return true;
}
// Post an overlapped WSARecv on ol's socket into ol's receive buffer (buffer and OVERLAPPED
// are reset first). Returns false, and drops the connection, if WSARecv fails with anything
// other than WSA_IO_PENDING. In that case no completion will ever arrive for it.
bool CServer::PostRecv(COverlappedIOInfo* ol)
{
    DWORD BytesRecvd = 0;
    DWORD dwFlags = 0;
    ol->ResetOverlapped();
    ol->ResetRecvBuffer();
    if (WSARecv(ol->m_sSocket, &ol->m_recvBuf, 1, &BytesRecvd, &dwFlags, (OVERLAPPED*)ol, NULL) == SOCKET_ERROR)
    {
        const int res = WSAGetLastError();
        if (WSA_IO_PENDING != res)
        {
            cout << "WSARecv error,error code " << res << endl;
            // Without a pending receive the connection would linger until shutdown.
            // DeleteLink frees ol, so ol must not be used after this call.
            DeleteLink(ol->m_sSocket);
            return false;
        }
    }
    return true;
}
// Handle a completed WSARecv: unpack the WebSocket frame in ol's receive buffer in place,
// log it, and reply with a fixed demo payload (a `data` struct copied byte-for-byte) as a
// binary frame.
bool CServer::DoRecv(COverlappedIOInfo* ol)
{
    // Per-call helper: WebSocket keeps its frame buffers as members, so sharing m_WebSocket
    // between concurrently running worker threads would race.
    WebSocket ws;
    // NOTE(review): the frame length is taken with strlen, which stops at the first 0 byte
    // (e.g. inside the masking key); the completion's byte count would be more reliable.
    ws.WebRecv(ol->m_recvBuf.buf, strlen(ol->m_recvBuf.buf));
    cout << "收到客户端数据:ip=" << inet_ntoa(ol->m_addr.sin_addr) << ",port=" << ntohs(ol->m_addr.sin_port) <<
        ";内容=" << ol->m_recvBuf.buf << endl;
    struct data
    {
        int a;
        char b[20];
        long c;
    };
    data aa;
    memset(&aa, 0, sizeof(aa));   // the struct is sent raw, so clear any padding bytes too
    aa.a = 5;
    strcpy(aa.b, "hello Websocket");
    aa.c = 314159;
    memcpy(ol->m_cSendBuf, &aa, sizeof(data));
    ws.WebSend(ol->m_sSocket, ol->m_cSendBuf, sizeof(data));
    return true;
}
// Remove the connection context for socket s from m_vecContInfo and free it; the context's
// destructor closes the socket. Always returns true (also when s is not registered).
bool CServer::DeleteLink(SOCKET s)
{
    COverlappedIOInfo* found = NULL;
    m_lsc.lock();
    auto it = std::find_if(m_vecContInfo.begin(), m_vecContInfo.end(),
                           [s](const COverlappedIOInfo* info) { return info->m_sSocket == s; });
    if (it != m_vecContInfo.end())
    {
        found = *it;
        m_vecContInfo.erase(it);
    }
    m_lsc.unlock();
    // Only the destructor closes the socket. The old explicit closesocket plus the destructor
    // closed the handle twice, and by the second close the value may already belong to a new
    // socket created on another thread.
    delete found;
    return true;
}
//释放3个部分步骤:
//1:清空IOCP线程队列,退出线程
//2: 清空等待accept的套接字m_vecAcps
//3: 清空已连接的套接字m_vecContInfo并清空缓存
void CServer::CloseServer()
{
//1:清空IOCP线程队列,退出线程,有多少个线程发送多少个PostQueuedCompletionStatus信息
int threadnum = GetThreadsNum();
for (int i = 0; i < threadnum; i++)
{
if (FALSE == m_iocp.PostStatus(TYPE_CLOSE))
{
cout << "PostQueuedCompletionStatus error,error code " << WSAGetLastError() << endl;
}
}
//2:清空等待accept的套接字m_vecAcps
std::vector::iterator iter = m_vecAcps.begin();
for (; iter != m_vecAcps.end(); iter++)
{
SOCKET s = *iter;
closesocket(s);
}
m_vecAcps.clear();
//3:清空已连接的套接字m_vecContInfo并清空缓存
std::vector::iterator iter2 = m_vecContInfo.begin();
for (; iter2 != m_vecContInfo.end(); iter2++)
{
COverlappedIOInfo* ol = *iter2;
closesocket(ol->m_sSocket);
iter2 = m_vecContInfo.erase(iter2);
delete ol;
}
m_vecContInfo.clear();
}