c/c++自定义通讯协议(TCP/UDP)

 前言:TCP与UDP是大家耳熟能详的两种传输层通信协议,本质区别在于传输控制策略不相同:使用TCP协议,可以保证传输层数据包能够有序地被接收方接收到,依赖其内部一系列复杂的机制,比如握手协商,ACK确认,超时重传,拥塞控制等; 而UDP基本上没有额外的控制策略,所以接收方能不能接收到传输层数据包是无法保证的。正是因为不能保证每一个数据包有序到达,UDP数据包与包之间,必须是相互独立的,每一个都应该是有意义的可以被解析出完整应用层报文的数据块,因此UDP又被称为面向(单个)报文的协议;而每一个TCP数据包则可以是应用层报文的某一部分,多个有序的数据包就可以拼接出完整的应用层报文,因此TCP被称作面向流的协议。

        我们知道,网络层(即IP层)数据包的长度受链路层最大传输单元MTU的限制(以太网通常为1500字节),所以不论是发送UDP包还是TCP包,如果超过了该限制,数据包将会被IP层分片,接收方的IP层会根据分片标识(Identification字段)对传输层数据片段进行重组,分片和重组都会占用CPU和内存资源,降低通讯效率。如果通信双方采用TCP通信,在握手连接阶段会协商MSS,即一个TCP报文段最多能携带的数据量,有了MSS约定,TCP层交付给IP层的数据包通常就不会超过MTU限制,也就是说数据切分工作在TCP传输层就完成了。而UDP本身不做切分,超过MTU的UDP数据报会被IP层分片;只要其中任何一个分片丢失,接收方就无法重组出完整的UDP数据报,整个数据报都会被丢弃,发送方也不会重发,因此大数据报的丢失概率明显更高。所以使用UDP时应尽量避免IP层分片(如果设置了DF“禁止分片”标志,超过路径MTU的数据报则会被直接丢弃)。

        所以,如果我们的应用需要传输大的数据包,就没办法使用单纯的UDP协议传输了,除非基于UDP在应用层自行实现一种类似于TCP内部的分片控制机制,完成数据的可靠传输。

一. 自定义网络协议

        假设我们的应用层报文一般比较小,不超过底层的MTU限制,这样一来,我们既可以使用TCP,也可以使用UDP来进行传输。下面是一个最简单的协议定义示例,包含消息头定义和消息体定义:

// Fixed-size header that starts every message; a message may consist of the
// header alone (e.g. a heartbeat).
// NOTE(review): plain `int` fields are sent as raw bytes, so both peers are
// assumed to agree on int size and byte order — confirm for cross-platform use.
struct SmHeader
{
    int m_length;// total message length in bytes: header + body
    int m_request_type;// request type
    int m_reply_type;// reply type
    int m_body_type;// body type: identifies which SmBody layout follows the header
};

// Body layout: one int, one float and one char, padded explicitly so the
// struct size is a multiple of 4 (12 bytes, assuming 4-byte int/float).
struct Body1
{
  int m_int_b1;
  float m_float_b1;
  char m_char_b1;
  char m_reserve[3];// explicit padding for 4-byte alignment
};

// Body layout: 12 ints followed by 15 floats (108 bytes, assuming 4-byte
// int/float).
struct Body2
{
  int m_int_array_b2[12];
  float m_float_array_b2[15];
};

// Body layout: a raw 512-byte character payload.
struct Body3
{
  char m_char_array_b3[512];
};

// Variable-length body: an element count followed by m_count Body1 items.
// The real size on the wire is given by SmHeader::m_length, not by sizeof().
// NOTE(review): `m_b1[0]` is a zero-length array (GCC/MSVC extension, not
// standard C++); m_count comes from the peer and must be checked against the
// message length before indexing m_b1.
struct Body1Assemble
{
    int m_count;
    Body1 m_b1[0];
};

// All possible body layouts overlaid in one union; which member is valid is
// decided by SmHeader::m_body_type. sizeof(SmBody) is the size of the largest
// member, which is unrelated to the actual body length (see m_length).
// NOTE(review): embedding Body1Assemble (zero-length array) in a union relies
// on a compiler extension.
union SmBody
{
  Body1 b1;
  Body2 b2;  
  Body3 b3;
  Body1Assemble b4;
};

// A complete message as overlaid on a receive buffer: header followed by body.
// Only the first SmHeader::m_length bytes are meaningful.
struct SmMessage
{
  SmHeader head;
  SmBody body;
};

由上述协议的定义可知,这样一个完整的消息最少有sizeof(SmHeader) = 16字节(一个消息可以没有消息体,比如PING/PONG心跳包,只有消息头即可),由于Body1Assemble类型的数据体长度不确定,因此用TCP的话,可以传递很长的消息。

二. 消息的接收

TCP是面向字节流的协议,并不保留应用层消息的边界:一次recv()可能只读到半条消息,也可能一次读到多条消息。再加上发送方的发送速率与接收方的接收速率很难匹配,接收方的接收缓冲区内会形成数据累积,所以我们需要借助上述定义的消息头(其中m_length给出整条消息的长度)完成消息的提取,有效处理接收端的粘包/半包问题。

char buffer[1024];// application-level receive buffer; must hold the largest message
int bfsize = 1024;// capacity of buffer
int legacy_bytes = 0;// bytes left over from the previous parse, kept at the start of buffer
bool skip_recv = false;// true when the leftover bytes already contain a complete message

while(1)
{
    int current_size = 0;
    if(!skip_recv)
    {
        // Append freshly received bytes after the leftover ones.
        current_size = recv(fd,buffer+legacy_bytes,bfsize-legacy_bytes,0);
        if(current_size<=0)
            break;// 0: peer closed the connection, <0: error
        current_size += legacy_bytes;
        legacy_bytes = 0;
    }
    else
    {
        // A complete message is already buffered: parse it without calling recv().
        current_size = legacy_bytes;
        legacy_bytes = 0;
        skip_recv = false;
    }
    int expected_size = -1;
    if(current_size>=(int)sizeof(SmHeader))
    {
        const SmHeader* head = (SmHeader*)buffer;
        expected_size = head->m_length;
        // m_length counts header + body: a value smaller than the header is
        // corrupt (the memmove below would then advance by too little), and a
        // value larger than the buffer can never be completed.
        if(expected_size<(int)sizeof(SmHeader) || expected_size>bfsize)
        {
            // legacy_bytes stays 0, so the buffered bytes are dropped.
            printf("Invalid message header or buffer insufficient.\r\n");
        }
        else if(current_size>=expected_size)
        {
            /***********process a complete message******/
            // The first expected_size bytes of buffer form one message.
            SmMessage* msg = (SmMessage*)buffer;
            // process the message ...
            // (assumes BODY1ASSEMBLE is the body-type id of Body1Assemble)
            if(head->m_body_type == BODY1ASSEMBLE)
            {
                // m_count comes from the peer: only visit elements that lie
                // inside this message, otherwise we would read past it.
                int items_offset = (int)((char*)msg->body.b4.m_b1 - buffer);
                if(expected_size>=items_offset)
                {
                    int count = msg->body.b4.m_count;
                    int max_count = (expected_size-items_offset)/(int)sizeof(Body1);
                    if(count>=0 && count<=max_count)
                    {
                        struct Body1* data;
                        for(int i=0;i<count;i++)
                        {
                            data = &msg->body.b4.m_b1[i];
                            //printf("data id: %d, data member1: %d",i,data->m_int_b1);
                        }
                    }
                    else
                    {
                        printf("Invalid element count.\r\n");
                    }
                }
            }
            /********************end********************/
            if(current_size>expected_size)
            {
                // Move the bytes of the following message(s) to the front.
                legacy_bytes = current_size-expected_size;
                memmove(buffer,buffer+expected_size,legacy_bytes);
                if(legacy_bytes>=(int)sizeof(SmHeader))
                {
                    const SmHeader* next_head = (SmHeader*)buffer;
                    int next_expected_size = next_head->m_length;
                    if(next_expected_size>0)
                        skip_recv = legacy_bytes>=next_expected_size;
                }
            }
            else
            {
                legacy_bytes = 0;
            }
        }
        else
        {
            // Header is valid but the body has not fully arrived: keep the bytes.
            legacy_bytes = current_size;
            printf("Incompleted message.\r\n");
        }
    }
    else
    {
        // Not even a full header yet: keep the bytes and receive more.
        legacy_bytes = current_size;
    }
}

三. python版本的TCP/UDP报文收发

import socket
import struct
import copy
import threading

# Receive-side state shared by communicate_buffer_init() and get_reply_from_server().
buffer, buffer_size = bytearray(), 0  # receive buffer and its capacity in bytes
skip_recv = False  # True when the leftover bytes already hold a complete packet
legacy_Bytes_count = 0  # number of leftover bytes carried over from the previous read


def case_reply_1():
    """Return the struct format of a reply whose body is Body1.

    Layout: 4 header ints + Body1 (int, float, char, char[3] reserve).
    Python's ``struct`` adds no trailing padding, so the 3 reserve bytes
    must be skipped explicitly with ``3x``; without them calcsize() is 25
    instead of the 28 bytes the peer sends.
    """
    fmt = '5ifc3x'
    return fmt


def case_reply_2():
    """Return the struct format of a reply whose body is Body2.

    4 header ints plus Body2's 12 ints give 16 ints, followed by 15 floats.
    """
    return '16i' + '15f'


def case_reply_3():
    """Return the struct format of a reply whose body is Body3.

    4 header ints followed by 512 single characters.
    """
    return '4i' + '512c'

def case_reply_4():
    """Return the fixed prefix format of a Body1Assemble reply.

    4 header ints plus the m_count field; unpack_messages() appends the
    per-item formats once m_count is known.
    """
    return '5i'


def case_reply_default():
    """Fallback for unknown reply body types: report it and provide no format (None)."""
    print("No implementation for processing this type of message.")


def case_request_0():
    """Build request 0: header followed by a RequestBody1.

    C layout of the body::

        struct RequestBody1
        {
            char name[8];
            int idx;
            float account;
        };

    Returns:
        The packed request (32 bytes with native alignment); the header's
        length field equals the packed size.
    """
    reqType = 0
    repType = -1
    msgType = 1
    # Pad the fixed-size char[8] with NUL bytes (not '\n') so the C side sees
    # a properly terminated string.
    name = b"Jhon\x00\x00\x00\x00"
    idx = 101
    account = 1361.58
    st = struct.Struct("4i8sif")
    # Take the length from the format instead of counting bytes by hand.
    length = st.size
    return st.pack(length, reqType, repType, msgType, name, idx, account)

def case_request_1():
    """Build request 1: header followed by a RequestBody2.

    C layout of the body::

        struct RequestBody2
        {
            bool female;
            char reserve[3];
            int age;
        };

    The header's length field is the packed size (24 bytes).
    """
    packer = struct.Struct("4i?3xi")
    # header: (length, reqType, repType, msgType), then female, age
    return packer.pack(packer.size, 1, -1, 2, True, 25)


def case_request_default():
    """Return an empty payload for request numbers that have no builder."""
    return b""


def unpack_messages(num:int, msg:bytearray)->tuple:
    """Decode one complete reply packet according to its body type.

    Args:
        num: the body-type field of the header; 0..3 select
            case_reply_1..case_reply_4, anything else case_reply_default.
        msg: the complete packet, header included.

    Returns:
        The flat tuple produced by ``struct.unpack`` (header fields first),
        or an empty tuple when the body type is unknown or the packet size
        does not match the expected layout.
    """
    numbers = {
        0:   case_reply_1,
        1:   case_reply_2,
        2:   case_reply_3,
        3:   case_reply_4
    }
    fmt = numbers.get(num, case_reply_default)()
    if not fmt:
        # Unknown body type: case_reply_default() offers no format.
        print("message parse error.")
        return tuple()
    cnt = struct.calcsize(fmt)
    if num == 3:
        # Body1Assemble: fixed prefix ending in m_count, then m_count Body1 items.
        if len(msg) < cnt:
            return tuple()
        count = struct.unpack(fmt, msg[:cnt])[4]
        extra_fmt = 'ifc3s'
        # m_count comes from the peer. Check it against the real packet size
        # before building the format string: a huge value would otherwise
        # create an enormous format string, and a negative one would be
        # silently accepted as an empty list.
        if count < 0 or cnt + count * struct.calcsize(extra_fmt) != len(msg):
            print("message parse error.")
            return tuple()
        fmt += count * extra_fmt
        cnt = struct.calcsize(fmt)
    # Python's struct alignment differs from C/C++ (e.g. no trailing padding),
    # so a layout mismatch shows up here as a size mismatch; such packets must
    # be decoded field by field according to their actual byte layout.
    if cnt != len(msg):
        print("message parse error.")
        return tuple()
    return struct.unpack(fmt, msg)


def connect_to_motion_server(IP:str, PORT:int, LPORT:int=-1):
    """Open a TCP connection to (IP, PORT).

    Args:
        IP: server address.
        PORT: server port.
        LPORT: when > 0, bind the client to this local port before connecting.

    Returns:
        The connected socket, or None if creating, binding or connecting failed
        (the error is printed).
    """
    client = None
    try:
        client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if LPORT > 0:
            client.bind(('',LPORT))  # bind the client to a fixed local port
        client.connect((IP, PORT))
        print('Connect Success!')
    except socket.error as msg:
        print(msg)
        # socket() itself may have failed, in which case there is nothing to close.
        if client is not None:
            client.close()
        client = None
    return client


def pack_messages(num:int)->bytes:
    """Build the request packet for request number *num*.

    Unknown numbers fall back to case_request_default(), i.e. an empty
    payload, which send_to_socket() does not send.
    """
    builder = {0: case_request_0, 1: case_request_1}.get(num, case_request_default)
    return builder() if builder else bytes()


def send_message_to_server(_client, msgtyp:int, scktype:str="TCP", address: tuple = None):
    """Build request *msgtyp* and send it; return the byte count from send_to_socket()."""
    return send_to_socket(_client, pack_messages(msgtyp), scktype, address)


def send_to_socket(_client, data: bytes, sock_type: str = 'TCP', address: tuple = None):
    """Send *data* through *_client*.

    Args:
        _client: a connected TCP socket, or a UDP socket.
        data: the packet to send; an empty packet is not sent.
        sock_type: 'TCP' or 'UDP'; any other value sends nothing.
        address: destination (host, port), required for UDP.

    Returns:
        The number of bytes sent (0 when nothing was sent).
    """
    if len(data) < 1:
        return 0
    if sock_type == 'TCP':
        # send() may transmit only part of the buffer; sendall() keeps writing
        # until the whole packet is out (or raises), so the peer never sees a
        # truncated message.
        _client.sendall(data)
        return len(data)
    if sock_type == 'UDP' and address:
        return _client.sendto(data, address)
    # UDP without a destination, or an unknown socket type: nothing is sent.
    return 0

    
def process_message(c):
    """Print every reply received on *c* until the connection is closed."""
    leng, packet = get_reply_from_server(c)
    while leng >= 0:
        print("receive message: ",packet)
        leng, packet = get_reply_from_server(c)


def get_reply_from_server(c) -> (int, tuple):
    """Read at most one reply from *c* using the module-level receive state.

    Returns:
        ``(packet_length, decoded)`` as produced by read_from_socket():
        length -1 when the peer disconnected, 0 when no complete packet was
        available yet; ``decoded`` may be None.
    """
    global buffer, buffer_size, skip_recv, legacy_Bytes_count
    outcome = read_from_socket(c, buffer, buffer_size, legacy_Bytes_count, skip_recv,
                               min_size=16, min_fmt='iiii')
    # Keep the carry-over state for the next call.
    pkt_len, pkt, skip_recv, legacy_Bytes_count = outcome
    return pkt_len, pkt


def read_from_socket(_client, recv_buffer:bytearray, bfsize:int, legacy_size:int, skip_flag:bool, sock_type:str='TCP', min_size=4,min_fmt='i'):
    """Receive into *recv_buffer* and extract at most one complete packet.

    Every packet starts with a header described by *min_fmt* (*min_size*
    bytes) whose first field is the total packet length (header + body).

    Args:
        _client: socket to read from.
        recv_buffer: reusable buffer of *bfsize* bytes; leftover bytes from
            the previous call are at its start.
        bfsize: capacity of *recv_buffer*; also the largest accepted packet.
        legacy_size: number of leftover bytes at the start of *recv_buffer*.
        skip_flag: True when the leftover bytes already contain a complete
            packet, so no receive is needed this time.
        sock_type: 'TCP' uses recv(), anything else recvfrom().
        min_size: header size in bytes.
        min_fmt: struct format of the header.

    Returns:
        ``(packet_len, decoded, skip_flag, legacy_size)``: packet_len is -1
        when the peer closed the connection, 0 when no complete packet was
        extracted, otherwise the packet length; ``decoded`` is the result of
        unpack_messages() or None. Pass the last two values back in on the
        next call.
    """
    assert bfsize > 0 and legacy_size >= 0, "please initialize recv_buffer first."
    complete_pkg = None
    header = tuple()
    if not skip_flag:
        extra_to_read = bfsize - legacy_size
        # Zero the free tail in place. Assigning a single b'\x00' to this
        # slice would replace the whole tail with ONE byte, shrinking the buffer.
        recv_buffer[legacy_size:bfsize] = bytes(extra_to_read)
        if sock_type == 'TCP':
            tmp = _client.recv(extra_to_read)
        else:
            tmp, _addr = _client.recvfrom(extra_to_read)
        csize = len(tmp)
        if csize <= 0:
            return -1, None, False, 0 # socket disconnect
        csize += legacy_size
        recv_buffer[legacy_size:csize] = tmp
        legacy_size = 0
    else:
        # Leftover bytes already hold a complete packet: parse without receiving.
        csize = legacy_size
        legacy_size = 0
        skip_flag = False

    if csize >= min_size:
        hst = struct.Struct(min_fmt)
        prefix = hst.unpack(recv_buffer[:min_size])
        # The length covers header + body: shorter than the header is corrupt,
        # longer than the buffer can never complete (the buffer would fill up
        # and recv(0) would then be misreported as a disconnect).
        if len(prefix) == 0 or prefix[0] < min_size or prefix[0] > bfsize:
            print("Invalid message header, drop it.")
        elif csize >= prefix[0]:
            esize = prefix[0]
            # Slicing a bytearray already makes an independent copy.
            complete_pkg = recv_buffer[:esize]
            if csize > esize:
                # Move the bytes of the following packet(s) to the front.
                legacy_size = csize - esize
                recv_buffer[:legacy_size] = recv_buffer[esize:csize]
                if legacy_size >= min_size:
                    next_prefix = hst.unpack(recv_buffer[:min_size])
                    skip_flag = legacy_size >= next_prefix[0]
            else:
                legacy_size = 0
            header = prefix
        else:
            legacy_size = csize
            print("Incomplete message. Ignore it.")
    else:
        legacy_size = csize
    if header and complete_pkg:
        # header = (length, reqType, resType, msgType)
        st_data = unpack_messages(header[3], complete_pkg)
        return len(complete_pkg), st_data, skip_flag, legacy_size
    else:
        return 0, None, skip_flag, legacy_size


def communicate_buffer_init(size):
    """(Re)initialize the module-level receive state used by get_reply_from_server().

    Args:
        size: capacity of the receive buffer in bytes; also the largest
            packet that can be received.
    """
    global buffer,buffer_size, skip_recv, legacy_Bytes_count
    buffer_size = size
    skip_recv = False
    legacy_Bytes_count = 0
    # bytearray(n) is exactly n zero bytes. bytearray.zfill() pads with ASCII
    # '0' and returns the old contents unchanged when they are already long
    # enough, so re-initialization would keep stale data and the old size.
    buffer = bytearray(buffer_size)


if __name__ == '__main__':
    communicate_buffer_init(1024)
    cli = connect_to_motion_server("127.0.0.1", 12289)
    _inputText = 'c'
    if cli:
        motion_thread = threading.Thread(target=process_message, args=(cli,))
        motion_thread.start()
        try:
            while _inputText != 'q':
                _inputText = input(r'please type a request:').lower()
                sz = 0
                if _inputText == '0':
                    sz = send_message_to_server(cli, 0)
                elif _inputText == '1':
                    sz = send_message_to_server(cli, 1)
                if sz < 0:
                    break
        except OSError as err:
            # e.g. the server closed the connection while we were sending
            print(err)
        finally:
            # The receiver thread is blocked in recv(); shutting the socket
            # down makes recv() return b'', so process_message() ends and
            # join() can return instead of hanging forever.
            try:
                cli.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # best effort: the peer may already have disconnected
            motion_thread.join()
            cli.close()
    

四. 补充一个C++实现的UDP收发类(基于Windows Winsock)

/***************************************udpclient.h******************************/

#ifndef UDP_CLIENT_H
#define UDP_CLIENT_H
#include 

class Clientudp
{
public:
	Clientudp();
	bool InitializeClient(const char* ip, int local_port, int remote_port = -1);
	void SetReadTimeout(int timeout_ms);
	int PushToWriteBuffer(const char* msg, unsigned int size);
	int PullFromReadBuffer(char* msg, unsigned int size);
	bool GetClientStatus();
    bool GetRemoteAddress(sockaddr_in* addr);
	~Clientudp();
private:
	int m_fd_;
	int m_read_timeout = -1;
	bool m_socket_avaliable_ = false;
	sockaddr_in m_remote_addr_;
	int m_remote_port_ = -1;
	sockaddr_in m_bind_addr_;
};

#endif



/*************************************udpclient.cpp***************************/

#include <winsock2.h>
#include <cstdio>
#include <cstring>
#include "UdpClient.h"

// Nothing to set up here: the socket is created by InitializeClient().
Clientudp::Clientudp() = default;

Clientudp::~Clientudp()
{
	// Release the socket only if one is currently open.
	if (!m_socket_avaliable_)
		return;
	closesocket(m_fd_);
}

void Clientudp::SetReadTimeout(int timeout_ms)
{
	m_read_timeout = timeout_ms;
}

bool Clientudp::InitializeClient(const char* ip, int local_port, int remote_port)
{
	if (m_socket_avaliable_)
	{
		closesocket(m_fd_);
		m_socket_avaliable_ = false;
	}
	if ((m_fd_ = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) < 0)
	{
		// create socket failed.
		m_socket_avaliable_ = false;
		return m_socket_avaliable_;
	}
	if (m_read_timeout > 0)
	{
		struct timeval read_timeout = { m_read_timeout, 0 };
		setsockopt(m_fd_, SOL_SOCKET, SO_RCVTIMEO, (char *)&read_timeout, sizeof(struct timeval));
	}
	memset(&m_remote_addr_, 0, sizeof(m_remote_addr_));
	m_remote_addr_.sin_family = AF_INET;
	m_remote_addr_.sin_addr.s_addr = inet_addr(ip);
	m_remote_addr_.sin_port = htons(remote_port);
	m_remote_port_ = remote_port;
	memset(&m_bind_addr_, 0, sizeof(m_bind_addr_));
	m_bind_addr_.sin_family = AF_INET;
	m_bind_addr_.sin_addr.s_addr = htonl(INADDR_ANY);
	m_bind_addr_.sin_port = htons(local_port);

	if (bind(m_fd_, (SOCKADDR*)&m_bind_addr_, sizeof(m_bind_addr_)) == SOCKET_ERROR)
	{
		int error_code = WSAGetLastError();
		if (error_code == WSAEADDRINUSE)
			printf("The port %d on this machine has been Occqupied.\n", local_port);
		printf("Bind Client to Fixed address Failed, Then you should settle a correct remote port and send data to remote machine before recv.\n");
		if (m_remote_port_ <= 0)
		{
			printf("InitializeClient Failed, Remote port = %d seems not to be a valid port number.\n", m_remote_port_);
			closesocket(m_fd_);
			m_socket_avaliable_ = false;
			return false;
		}
	}
	m_socket_avaliable_ = true;
	return m_socket_avaliable_;
}

int Clientudp::PushToWriteBuffer(const char* msg, unsigned int size)
{
	if (!m_socket_avaliable_)
	{
		return -1;
	}
	int ret = sendto(m_fd_, msg, size, 0, (sockaddr*)&m_remote_addr_, sizeof(m_remote_addr_));
	if (ret == SOCKET_ERROR)
	{
		int error_code = WSAGetLastError();
		//printf("error_code: %d\n",error_code);
		if (m_remote_port_ < 0 && error_code == WSAEINVAL)
		{
			printf("You should recv data from peer firstly before send.\n");
			return 0;
		}
		closesocket(m_fd_);
		m_socket_avaliable_ = false;
		m_remote_port_ = -1;
		return -1;
	}
	else
	{
		return ret;
	}
}

int Clientudp::PullFromReadBuffer(char* msg, unsigned int size)
{
	if (!m_socket_avaliable_)
	{
		return -1;
	}
	int addr_len = sizeof(m_remote_addr_);
	int ret = recvfrom(m_fd_, msg, size, 0, (sockaddr*)&m_remote_addr_, &addr_len);
	if (ret == SOCKET_ERROR)
	{
		int error_code = WSAGetLastError();
		if (error_code == WSAEMSGSIZE || error_code == WSAEINTR || error_code == WSAETIMEDOUT || error_code == WSAEWOULDBLOCK)
		{
            if(error_code == WSAEMSGSIZE)//datagram is too large to put into msg buffer.
                memset(msg,0x00,size);
			return 0;
		}
		else
		{
			closesocket(m_fd_);
			m_socket_avaliable_ = false;
			return -1;
		}
	}
	else
	{
		return ret;
	}
}

bool Clientudp::GetClientStatus()
{
	// Reports whether the socket is currently open and usable.
	const bool usable = m_socket_avaliable_;
	return usable;
}

bool Clientudp::GetRemoteAddress(sockaddr_in* addr)
{
	// Copies the current remote address into *addr; returns false if the
	// socket is unusable, addr is null, or no valid remote port is known.
	if (!m_socket_avaliable_ || addr == nullptr)
		return false;
	if (m_remote_port_ <= 0)
	{
		// namelen is in/out and must hold the buffer size on input.
		sockaddr_in peer;
		int asize = sizeof(peer);
		if (getpeername(m_fd_, (sockaddr*)&peer, &asize) == 0)
			m_remote_addr_ = peer;
		// Otherwise (e.g. an unconnected UDP socket) keep the address stored by
		// InitializeClient() or the last recvfrom().
		m_remote_port_ = ntohs(m_remote_addr_.sin_port);
		if (m_remote_port_ <= 0)
			return false;
	}
	*addr = m_remote_addr_;
	return true;
}

udp测试:

#include <winsock2.h>
#include <cstdio>
#include <cstring>
#include "udpclient.h"

int main()
{
    WSADATA ws;
    if (WSAStartup(MAKEWORD(2,2),&ws) != 0)
        return 1;
    {
        // Scoped so that ~Clientudp() closes the socket before WSACleanup().
        Clientudp udp;
        if (udp.InitializeClient("127.0.0.1", 1234, -1))
        {
            char RxBuffer[1024];
            char TxBuffer[1024];
            memset(RxBuffer,0x00,1024);
            memset(TxBuffer,0x00,1024);
            /* With remote_port == -1 the destination port is unknown, so receive
               first and then send. If the peer's port is known, the order does not matter. */
            int readLen= udp.PullFromReadBuffer(RxBuffer,1024);
            printf("Read message length: %d",readLen);
            sprintf_s(TxBuffer,1023,"hello world");
            udp.PushToWriteBuffer(TxBuffer,1024);
            getchar();
        }
        else
        {
            printf("InitializeClient failed.\n");
        }
    }
    WSACleanup();
    return 0;
}

你可能感兴趣的:(网络,tcp/ip,udp,python)