C++socket基础进阶:Select与封装

之前用C#做服务器没搞明白于是从笔者比较熟悉的C++开始入手从头学了一遍,整理一下笔记。
资料来源于《网络多人游戏架构与编程》第三章,这本书讲的很明白,比起网上每篇博客都在介绍的原理,这本书更偏向于代码实现。
代码应该没什么大问题,看懂逻辑的话能自己封装的更好。

上一篇介绍Socket基础函数的在这。

阻塞和非阻塞I/O

开启非阻塞模式

默认情况下,socket操作是阻塞模式,但可以由如下函数转变为非阻塞模式,非阻塞模式的socket被要求执行一个需要阻塞的操作时,它将立刻返回-1,还设置了错误代码errno或WSAGetLastError,分别返回EAGAIN或WSAEWOULDBLOCK。这个代码表示之前的socket行为已经阻塞,没有发生就被终止了。

//Windows系统
int ioctlsocket(
	SOCKET s,
	long cmd, //用于控制socket参数,在这种情况下,输入FIONBIO
	u_long* argp //是这个参数的取值,任意非零值将开启非阻塞模式,0将阻止开启
);
//POSIX兼容系统
int fcntl(
    int sock,
    int cmd, //发给socket的命令
    ...
);

在更新的POSIX系统上,必须首先使用F_GETEL获取当前与socket相关的标志,让它们与常数O_NONBLOCK按位或运算之后,使用F_SETEL命令更新socket上的标志。

当socket出于非阻塞模式,调用任何阻塞函数都是安全的,因为我们知道如果它不能在没有阻塞的情况下完成,它会立刻返回。

Select

socket库提供了同时检查多个socket的方式,只要其中有一个socket准备好了就开始执行:

int select(
	int nfds,
	fd_set* readfds,
	fd_set* writefds,
	fd_set* exceptfds,
	const timeval* timeout
);

参数:

  1. nfds在POSIX平台,是待检查的编号最大的socket的标识符。在POSIX平台,每一个socket只是一个整数,所以直接将所有socket的最大值传入这个函数。在Windows平台,socket表示为指针,而不是整数,所以这个参数不起作用,可以忽略。
  2. readfds是指向socket集合的指针,称为fd_set,包含要检查可读性的socket。
  3. writefds是指向fd_set的指针,存储待检查可写性的socket。当select函数返回时,保留在writefds中所有socket都保证可写,不会引起调用线程的阻塞。给writefds传入nullptr来跳过任何socket可写性的检查。通常,只有当socket的输出缓冲区有太多数据时,socket才会阻塞写操作。
  4. excptfds是指向fd_set的指针,这个fd_set存储待检查错误的socket。当select函数返回,保留在exceptfds中的所有socket都已经发生了错误。给exceptfds传入nullptr来跳过任何错误的检查。
  5. timeout是指向超时之前可以等待最长时间的指针。如果在readfds中的任意一个socket可读,writefds中的任意一个socket可写,或者exceptfds中的任意一个socket发生错误之前发生超时,清空所有集合,select函数将控制返回给调用线程。给timeout输入nullptr来表名没有超时限制。

select函数返回执行之后保留在readfds、writefds和exceptfds中socket的数量。如果发生超时,这个值是0.

fd_set:

fd_set myReadSet;
FD_ZERO(&myReadSet);	//初始化一个空的fd_set
FD_SET(mySocket, &myReadSet);	//给fd_set添加一个socket
FD_INSET(mySocket, &myReadSet);	//检查在select函数返回值后,一个socket是否在fd_set中

封装

类型安全的SocketAddress类

class SocketAddress {
public:
	SocketAddress(uint32_t inAddress, uint16_t inPort) {
		GetAsSockAddrIn()->sin_family = AF_INET;
		GetAsSockAddrIn()->sin_addr.S_un.S_addr = htonl(inAddress);
		GetAsSockAddrIn()->sin_port = htons(inPort);
	}
	SocketAddress(const sockaddr& inSockAddr) {
		memcpy(&mSockAddr, &inSockAddr, sizeof(sockaddr));
	}
	size_t GetSize() const { return sizeof(sockaddr); }
private:
	sockaddr mSockAddr;
	sockaddr_in* GetAsSockAddrIn() {
		return reinterpret_cast<sockaddr_in*> (&mSockAddr);
	}
};
typedef std::shared_ptr<SocketAddress> SocketAddressPtr;

使用SocketAddressFactory类的域名解析

class SocketAddressFactory {
public:
	static SocketAddressPtr CreateIPv4FromString(const std::string& inString) {
		auto pos = inString.find_last_of(':');
		std::string host, service;
		if (pos != std::string::npos) {
			host = inString.substr(0, pos);
			service = inString.substr(pos + 1);
		}
		else {
			host = inString;
			//use default port
			service = "0";
		}
		addrinfo hint;
		memset(&hint, 0, sizeof(hint));
		hint.ai_family = AF_INET;
		addrinfo* result;
		int error = getaddrinfo(host.c_str(), service.c_str(), &hint, &result);
		if (error != 0) {
			freeaddrinfo(result);
			return nullptr;
		}
		while (!result->ai_addr && result->ai_next)result = result->ai_next;
		if (!result->ai_addr) {
			freeaddrinfo(result);
			return nullptr;
		}
		auto toRet = std::make_shared<SocketAddress>(*result->ai_addr);
		freeaddrinfo(result);
		return toRet;
	}
};

类型安全的UDP Socket

class UDPSocket {
public:
	~UDPSocket();
	int Bind(const SocketAddress& inToAddress);
	int SendTo(const void* inData, int inLen, const SocketAddress& inTo);
	int ReceiveFrom(void* inBuffer, int inLen, SocketAddress& outFrom);
private:
	UDPSocket(SOCKET inSocket) : mSocket(inSocket) {}
	SOCKET mSocket;
};
typedef std::shared_ptr<UDPSocket> UDPSocketPtr;

int UDPSocket::Bind(const SocketAddress& inBindAddress) {
	int err = bind(mSocket, &inBindAddress.mSockAddr, inBindAddress.GetSize());
	if (err != 0) {
		//return error from UDPSocket::Bind
		return -1;
	}
	return 0;
}
int UDPSocket::SendTo(const void* inData, int inLen, const SocketAddress& inTo) {
	int byteSentCount = sendto(mSocket, static_cast<const char*>(inData), inLen, 0, &inTo.mSockAddr, inTo.GetSize());
	if (byteSentCount >= 0) return byteSentCount;
	//return error from UDPSocket::SendTo
	return -1;
}
int UDPSocket::ReceiveFrom(void* inBuffer, int inLen, SocketAddress& outFrom) {
	int fromLength = outFrom.GetSize();
	int readByteCount = recvfrom(
		mSocket,
		static_cast<char*> (inBuffer),
		inLen,
		0,
		&outFrom.mSockAddr,
		&fromLength
	);
	if (readByteCount >= 0) return readByteCount;
	//return error from UDPSocket::ReceiveFrom
	return -1;
}
UDPSocket::~UDPSocket() {
	closesocket(mSocket);
}

类型安全的TCP Socket

class TCPSocket {
public:
	~TCPSocket();
	int Connect(const SocketAddress& inAddress);
	int Bind(const SocketAddress& inToAddress);
	int Listen(int inBackLog = 32);
	std::shared_ptr<TCPSocket> Accept(SocketAddress& inFromAddress);
	int Send(const void* inData, int inLen);
	int Receive(void* inBuffer, int inLen);
private:
	TCPSocket(SOCKET inSocket) : mSocket(inSocket) {}
	SOCKET mSocket;
};
typedef std::shared_ptr<TCPSocket> TCPSocketPtr;

int TCPSocket::Connect(const SocketAddress& inAddress) {
	int err = connect(mSocket, &inAddress.mSockAddr, inAddress.GetSize());
	if (err < 0) return -1;
	return NO_ERROR;
}
int TCPSocket::Listen(int inBackLog) {
	int err = listen(mSocket, inBackLog);
	if (err < 0)return -1;
	return NO_ERROR;
}
int TCPSocket::Send(const void* inData, int inLen) {
	int bytesSendCount = send(
		mSocket,
		static_cast<const char*> (inData),
		inLen, 0
	);
	if (bytesSendCount < 0)return -1;
	return bytesSendCount;
}
int TCPSocket::Receive(void* inBuffer, int inLen) {
	int bytesReceiveCount = recv(
		mSocket,
		static_cast<char*> (inBuffer),
		inLen, 0
	);
	if (bytesReceiveCount < 0)return -1;
	return bytesReceiveCount;
}
TCPSocket::~TCPSocket() {
	closesocket(mSocket);
}

与类型安全的TCPSocket一起使用的select函数

fd_set* FillSetFromVector(fd_set& outSet,
	const std::vector<TCPSocketPtr>* inSockets) {
	if (inSockets) {
		FD_ZERO(&outSet);
		for (const TCPSocketPtr& socket : *inSockets)
			FD_SET(socket->mSocket, &outSet);
		return &outSet;
	}
	return nullptr;
}

void FillVectorFromSet(
	std::vector<TCPSocketPtr>* outSockets,
	const std::vector<TCPSocketPtr>* inSockets,
	const fd_set& inSet
) {
	if (inSockets && outSockets) {
		outSockets->clear();
		for (const TCPSocketPtr& socket : *inSockets)
			if (FD_ISSET(socket->mSocket, &inSet))
				outSockets->push_back(socket);
	}
}

int Select(
	const std::vector<TCPSocketPtr>* inReadSet,
	std::vector<TCPSocketPtr>* outReadSet,
	const std::vector<TCPSocketPtr>* inWriteSet,
	std::vector<TCPSocketPtr>* outWriteSet,
	const std::vector<TCPSocketPtr>* inExceptSet,
	std::vector<TCPSocketPtr>* outExceptSet
) {
	fd_set read, write, except;
	fd_set* readPtr = FillSetFromVector(read, inReadSet);
	fd_set* writePtr = FillSetFromVector(read, inWriteSet);
	fd_set* exceptPtr = FillSetFromVector(read, inExceptSet);

	int toRet = select(0, readPtr, writePtr, exceptPtr, nullptr);
	if (toRet > 0) {
		FillVectorFromSet(outReadSet, inReadSet, read);
		FillVectorFromSet(outWriteSet, inWriteSet, write);
		FillVectorFromSet(outExceptSet, inExceptSet, except);
	}
	return toRet;
}

运行一个TCP服务器循环

void DoTCPLoop() {
	TCPSocketPtr listenSocket = CreateTCPSocket(INET);
	SocketAddress receivingAddres(INADDR_ANY, 48000);
	if (listenSocket->Bind(receivingAddres) != NO_ERROR)return;
	std::vector<TCPSocketPtr> readBlockSockets;
	readBlockSockets.push_back(listenSocket);
	std::vector<TCPSocketPtr> readableSockets;
	while (true) {
		if (Select(&readBlockSockets, &readableSockets,
			nullptr, nullptr,
			nullptr, nullptr
		)) {
			for (const TCPSocketPtr& socket : readableSockets) {
				if (socket == listenSocket) {
					SocketAddress newCilentAddress;
					TCPSocketPtr newSocket = listenSocket->Accept(newCilentAddress);
					readBlockSockets.push_back(newSocket);
					// 在这了处理新连入 ProcessNewCilent(newSocket, newCilentAddress);
				}
				else {
					char segment[MAX_SEGMENT_SIZE];
					int dataReceived = socket->Receive(segment, MAX_SEGMENT_SIZE);
					if (dataReceived > 0) {
						//在这里处理新数据 ProcessDataFromClient(socket, segment, dataReceived);
					}
				}
			}
		}
	}
}

你可能感兴趣的:(c++,开发语言)