socketserver模块使用与源码分析

socketserver模块使用与源码分析

前言

  在前面的学习中我们其实已经可以通过socket模块来建立我们的服务端,并且还介绍了关于TCP协议的粘包问题。但是还有一个非常大的问题就是我们所编写的Server端是不支持并发性服务的,在我们之前的代码中只能加入一个通信循环来进行排队式的单窗口一对一服务。那么这一篇文章将主要介绍如何使用socketserver模块来建立具有并发性的Server端。

 

基于TCP协议的socketserver服务端

  我们先看它的一段代码,对照代码来看功能。

#!/usr/bin/env python3
# _*_ coding:utf-8 _*_

# ==== 使用socketserver创建支持多并发性的服务器 TCP协议 ====

import socketserver

class MyServer(socketserver.BaseRequestHandler):
    """自定义类"""

    def handle(self):
        """handle处理请求"""
        print("双向链接通道建立完成:", self.request)  # 对于TCP协议来说,self.request相当于双向链接通道conn,即accept()的第一部分
        print("客户端的信息是:", self.client_address)  # 对于TCP协议来说,相当于accept()的第二部分,即客户端的ip+port

        while 1:  # 开始内层通信循环
            try:  # # bug修复:针对windows环境
                data = self.request.recv(1024)

                if not data:
                    break  # bug修复:针对类UNIX环境

                print("收到客户机[{0}]的消息:[{1}]".format(self.client_address, data))
                self.request.sendall(data.upper())  # #sendall是重复调用send.

            except Exception as e:
                break

        self.request.close()  # 当出现异常情况下一定要关闭链接


if __name__ == '__main__':
    s1 = socketserver.ThreadingTCPServer(("0.0.0.0", 6666), MyServer) # 公网服务器绑定 0.0.0.0 私网测试为 127.0.0.1
    s1.serve_forever()  # 启动服务

 

  1.导入socketserver模块

  2.创建一个新的类,并继承socketserver.BaseRequestHandler

  3.覆写handle方法,对于TCP协议来说,self.request 相当于双向链接通道connself.client_address相当于被服务方的ip和port信息,也就是addr,而整个handle方法相当于链接循环。

  4.写入收发逻辑规则

  5.防止客户端发送空的消息已致双方卡死

  6.防止客户端突然断开已致服务端崩溃

  7.粘包优化(可选)

  8.实例化 socketserver.ThreadingTCPServer类,并传入IP+port,以及刚写好的类名

  9.使用socketserver.ThreadingTCPServer实例化对象中的server_forever( )方法启动服务

 

  它其实是这样的:

    我们不用管链接循环,因为在执行handle方法之前内部已经帮我们做好了。当我们使用serve_forever()方法的时候便开始监听链接描述符对象,一旦有链接请求就创建一个子线程来处理该链接。

 

 

基于UDP协议的socketserver服务端

  基于UDP协议的socketserver服务端与基于TCP协议的socketserver服务端大相径庭,但是还是有几点不太一样的地方。

 

  对TCP来说:

    self.request = 双向链接通道(conn)

  对UDP来说:

    self.request = (client_data_byte,udp的套接字对象)

 

#!/usr/bin/env python3
# _*_ coding:utf-8 _*_

# ==== 使用socketserver创建支持多并发性的服务器 UDP协议 ====

import socketserver

class MyServer(socketserver.BaseRequestHandler):
    """自定义类"""

    def handle(self):
        """handle处理请求"""

        # 由于UDP是基于消息的协议,故根本不用通信循环

        data = self.request[0]  # 对于UDP协议来说,self.request其实是个元组。第一个元素是消息内容主题(Bytes类型),相当于recvfrom()的第一部分
        server = self.request[1]    # 第二个元素是服务端本身,即自己

        print("客户端的信息是:", self.client_address)   # 对于UDP协议来说,相当于recvfrom()的第二部分,即客户端的ip+port

        print("收到客户机[{0}]的消息:[{1}]".format(self.client_address, data))
        server.sendto(data.upper(),self.client_address)


if __name__ == '__main__':
    s1 = socketserver.ThreadingUDPServer(("0.0.0.0", 6666), MyServer)  # 公网服务器绑定 0.0.0.0 私网测试为 127.0.0.1
    s1.serve_forever()  # 启动服务

 

扩展:socketserver源码分析

探索socketserver中的继承关系


  好了,接下来我们开始剖析socketserver模块中的源码部分。在Pycharm下使用CTRL+鼠标左键,可以进入源码进行查看。

  我们在查看源码前一定要首先要明白两点:

 

  socketserver类分为两部分,其一是server类主要是负责处理链接方面,另一类是request类主要负责处理通信方面。

 

  好了,请在脑子里记住这个概念。我们来看一些socketserver模块的实现用了哪些其他的基础模块。

 

  注意,接下来的源码注释部分我并没有在源代码中修改,也请读者不要修改源代码的任何内容。

import socket  # 这模块挺熟悉吧
import selectors  # 这个是一个多线程模块,主要支持I/O多路复用。
import os # 老朋友了
import sys  # 老朋友
import threading  # 多线程模块
from io import BufferedIOBase  # 读写相关的模块
from time import monotonic as time  # 老朋友time模块
socketserver中用到的基础模块

 

  好了,让我们接着往下走。可以看到一个变量__all__,是不是觉得很熟悉?就是我们使用 from xxx import xxx 能导入进的东西全是被__all__控制的,我们看一下它包含了哪些内容。

__all__ = ["BaseServer", "TCPServer", "UDPServer",
           "ThreadingUDPServer", "ThreadingTCPServer",
           "BaseRequestHandler", "StreamRequestHandler",
           "DatagramRequestHandler", "ThreadingMixIn"]
           
# 这个是我们原本的 __all__ 中的值。
if hasattr(os, "fork"):
    __all__.extend(["ForkingUDPServer","ForkingTCPServer", "ForkingMixIn"])
if hasattr(socket, "AF_UNIX"):
    __all__.extend(["UnixStreamServer","UnixDatagramServer",
                    "ThreadingUnixStreamServer",
                    "ThreadingUnixDatagramServer"])
                    
# 上面两个if判断是给__all__添加内容的,os.fork()这个方法是创建一个新的进程,并且只在类UNIX平台下才有效,Windows平台下是无效的,所以这里对于Windows平台来说就from socketserver import xxx 肯定少了三个类,这三个类的作用我们接下来会聊到。而关于socket中的AF_UNIX来说我们其实已经学习过了,是基于文件的socket家族。这在Windows上也是不支持的,只有在类UNIX平台下才有效。所以Windows平台下的导入又少了4个类。
​
​
# poll/select have the advantage of not requiring any extra file descriptor,
# contrarily to epoll/kqueue (also, they require a single syscall).
if hasattr(selectors, 'PollSelector'):
    _ServerSelector = selectors.PollSelector
else:
    _ServerSelector = selectors.SelectSelector
    
# 这两个if还是做I/O多路复用使用的,Windows平台下的结果是False,而类Unix平台下的该if结果为True,这关乎I/O多路复用的性能选择。到底是select还是poll或者epoll。
socketserver模块对于from xxx import * 导入的处理

 

  我们接着向下看源码,会看到许许多多的类。先关掉它来假设自己是解释器一行一行往下走会去执行那个部分。首先是一条if判断

if hasattr(os, "fork"):
    class ForkingMixIn:
        pass # 这里我自己省略了
 
 # 我们可以看见这条代码是接下来执行的,它意思还是如果在类Unix环境下,则会去创建该类。如果在Windows平台下则不会创建该类
处理点一

 

  继续走,其实这种if判断再创建类的地方还有两处。我这里全部列出来:

if hasattr(os, "fork"):
    class ForkingUDPServer(ForkingMixIn, UDPServer): pass
    class ForkingTCPServer(ForkingMixIn, TCPServer): pass
 

if hasattr(socket, 'AF_UNIX'):
​
    class UnixStreamServer(TCPServer):
        address_family = socket.AF_UNIX
​
    class UnixDatagramServer(UDPServer):
        address_family = socket.AF_UNIX
​
    class ThreadingUnixStreamServer(ThreadingMixIn, UnixStreamServer): passclass ThreadingUnixDatagramServer(ThreadingMixIn, UnixDatagramServer): pass
处理点二 and 三

 

  好了,说完了大体粗略的一个流程,我们该来研究这里面的类都有什么作用,这里可以查看每个类的文档信息。大致如下:

 

  前面已经说过,socketserver模块中主要分为两大类,我们就依照这个来进行划分。

 

socketserver模块源码内部class功能一览 
处理链接相关  
BaseServer 基础链接类
TCPServer TCP协议类
UDPServer UDP协议类
UnixStreamServer 文件形式字节流类
UnixDatagramServer 文件形式数据报类
处理通信相关  
BaseRequestHandler 基础请求处理类
StreamRequestHandler 字节流请求处理类
DatagramRequestHandler 数据报请求处理类
多线程相关  
ThreadingMixIn 线程方式
ThreadingUDPServer 多线程UDP协议服务类
ThreadingTCPServer 多线程TCP协议服务类
多进程相关  
ForkingMixIn 进程方式
ForkingUDPServer 多进程UDP协议服务类
ForkingTCPServer 多进程TCP协议服务类

 

  他们的继承关系如下:

ForkingUDPServer(ForkingMixIn, UDPServer)
​
ForkingTCPServer(ForkingMixIn, TCPServer)
​
ThreadingUDPServer(ThreadingMixIn, UDPServer)
​
ThreadingTCPServer(ThreadingMixIn, TCPServer)
​
StreamRequestHandler(BaseRequestHandler)
​
DatagramRequestHandler(BaseRequestHandler)

 

  处理链接相关

socketserver模块使用与源码分析_第1张图片

处理通信相关

socketserver模块使用与源码分析_第2张图片

多线程相关

socketserver模块使用与源码分析_第3张图片

 

总继承关系(处理通信相关的不在其中,并且不包含多进程)

socketserver模块使用与源码分析_第4张图片

 

  最后补上一个多进程的继承关系,就不放在总继承关系中了,容易图形造成混乱。

 

多进程相关

socketserver模块使用与源码分析_第5张图片

 

实例化过程分析


  有了继承关系我们可以来模拟实例化的过程,我们以TCP协议为准:

 

socketserver.ThreadingTCPServer(("0.0.0.0", 6666), MyServer)

 

  我们点进(选中上面代码的ThradingTCPServer部分,CTRL+鼠标左键)源码部分,查找其 __init__ 方法:

class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass

 

  看来没有,那么就找第一父类有没有,我们点进去可以看到第一父类ThreadingMixIn也没有__init__方法,看上面的继承关系图可以看出是普通多继承,那么就是广度优先的查找顺序。我们来看第二父类TCPServer中有没有,看来第二父类中是有__init__方法的,我们详细来看。

class TCPServer(BaseServer):

   """注释全被我删了,影响视线"""

    address_family = socket.AF_INET  #  基于网络的套接字家族

    socket_type = socket.SOCK_STREAM  # TCP(字节流)协议

    request_queue_size = 5  # 消息队列最大为5,可以理解为backlog,即半链接池的大小

    allow_reuse_address = False  # 端口重用默认关闭

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Constructor.  May be extended, do not override."""
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family,
                                    self.socket_type)
                                    
       # 可以看见,上面先是调用了父类的__init__方法,然后又实例化出了一个socket对象!所以我们先不着急往下看,先看其父类中的__init__方法。
       
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except:
                self.server_close()
                raise
TCPServer中的__init__()

 

  来看一下,BaseServer类中的__init__方法。

class BaseServer:

    """注释依旧全被我删了"""

    timeout = None  # 这个变量可以理解为超时时间,先不着急说他。先看 __init__ 方法

    def __init__(self, server_address, RequestHandlerClass):
        """Constructor.  May be extended, do not override."""
        self.server_address = server_address  # 即我们传入的 ip+port ("0.0.0.0", 6666)
        self.RequestHandlerClass = RequestHandlerClass  # 即我们传入的自定义类 MyServer
        self.__is_shut_down = threading.Event()  # 这里可以看到执行了该方法,这里先不详解,因为它是一个事件锁,所以不用管
        self.__shutdown_request = False 
BaseServer中的__init__()

 

  BaseServer中执行了thrading模块下的Event()方法。我这里还是提一嘴这个方法是干嘛用的,它会去控制线程的启动顺序,这里实例化出的self.__is_shut_down其实就是一把锁,没什么深究的,接下来的文章中我也会写到。我们继续往下看,现在是该回到TCPServer__init__方法中来了。

class TCPServer(BaseServer):

   """注释全被我删了,影响视线"""
   
    address_family = socket.AF_INET  #  基于网络的套接字家族

    socket_type = socket.SOCK_STREAM  # TCP(字节流)协议

    request_queue_size = 5  # 消息队列最大为5,可以理解为backlog,即半链接池的大小
    
    allow_reuse_address = False  # 端口重用默认关闭

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True): # 看这里!!!!
    """Constructor.  May be extended, do not override."""
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family,
                                self.socket_type)                        
   
        if bind_and_activate:  # 在创建完socket对象后就会进行该判断。默认参数bind_and_activate就是为True
            try:
                self.server_bind() # 现在进入该方法查看细节
                self.server_activate()
            except:
                self.server_close()
                raise
TCPServer中的__init__()
 
 

  好了,需要找这个self.bind()方法,还是从头开始找。实例本身没有,第一父类ThreadingMixIn也没有,所以现在我们看的是TCPServerserver_bind()方法:

def server_bind(self):
    """Called by constructor to bind the socket.

    May be overridden.

    """
    if self.allow_reuse_address:  # 这里的变量对应 TCPServer.__init__ 上面定义的类方法,端口重用这个。由于是False,所以我们直接往下执行。
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self.socket.bind(self.server_address)  # 绑定 ip+port 即 ("0.0.0.0", 6666)
    self.server_address = self.socket.getsockname() # 获取socket的名字 其实还是 ("0.0.0.0", 6666)
TCPServer中的server_bind()

 

  现在我们该看TCPServer下的server_activate()方法了。

def server_activate(self):
    """Called by constructor to activate the server.

    May be overridden.

    """
    self.socket.listen(self.request_queue_size)  # 其实就是监听半链接池,backlog为5
TCPServer中的server_activate()

 

  这个时候没有任何异常会抛出的,所以我们已经跑完了整个实例化的流程。并将其赋值给s1

  现在我们看一下s1__dict__字典,再接着进行源码分析。

{'server_address': ('0.0.0.0', 6666), 'RequestHandlerClass': <class '__main__.MyServer'>, '_BaseServer__is_shut_down': , '_BaseServer__shutdown_request': False, 'socket': '0.0.0.0', 6666)>}
s1的__dict__

 

server_forever()启动服务分析


  我们接着来看下一条代码。

s1.serve_forever()

 

  还是老规矩,由于s1ThreadingTCPServer类的实例对象,所以我们去一层层的找serve_forever(),最后在BaseServer类中找到了。

def serve_forever(self, poll_interval=0.5):
    """注释被我删了"""
    self.__is_shut_down.clear()  # 上面说过了那个Event锁,控制子线程的启动顺序。这里的clear()代表清除,这个不是重点,往下看。
    try:
        # XXX: Consider using another file descriptor or connecting to the
        # socket to wake this up instead of polling. Polling reduces our
        # responsiveness to a shutdown request and wastes cpu at all other
        # times.
        with _ServerSelector() as selector:  
            selector.register(self, selectors.EVENT_READ)# 这里是设置了一个监听类型为读取事件。也就是说当有请求来的时候当前socket对象就会发生反应。

            while not self.__shutdown_request: # 为False,会执行,注意!下面都是死循环了!!!
                ready = selector.select(poll_interval)  # 设置最大监听时间为0.5s
                # bpo-35017: shutdown() called during select(), exit immediately.
                if self.__shutdown_request: # BaseServer类中的类方法,为False,所以不执行这个。
                    break
                if ready: # 代表有链接请求会执行下面的方法
                    self._handle_request_noblock()  # 这儿是比较重要的一个点。我们先来看。

                self.service_actions() 
    finally:
        self.__shutdown_request = False
        self.__is_shut_down.set() # 这里是一个释放锁的行为
BaseServer中的serve_forever()

 

  如果有链接请求,则会执行self._handle_request_noblock()方法,它在哪里呢?刚好这个方法就在BaseServerserve_forever()方法的正下方第4个方法的位置。

def _handle_request_noblock(self):
    """注释被我删了"""
    try:
        request, client_address = self.get_request()  # 这里的这个方法在TCPServer中,它的return值是 self.socket.accept(),就是就是返回了元组然后被解压赋值了。其实到这一步三次握手监听已经开启了。
    except OSError:
        return
    if self.verify_request(request, client_address): # 这个是验证ip和port,返回的始终是True
        try:
            self.process_request(request, client_address) # request 双向链接通道,client_address客户端ip+port。现在我们来找这个方法。
        except Exception:
            self.handle_error(request, client_address)
            self.shutdown_request(request)
        except:
            self.shutdown_request(request)
            raise
    else:
        self.shutdown_request(request)
BaseServer中的_handle_request_noblock()

 

  现在开始查找self.process_request(request, client_address)该方法,还是先从实例对象本身找,找不到去第一父类找。他位于第一父类ThreadingMixIn中。

def process_request(self, request, client_address):
    """Start a new thread to process the request."""
    t = threading.Thread(target = self.process_request_thread,
                         args = (request, client_address))  # 创建子线程!!看这里!
    t.daemon = self.daemon_threads # ThreadingMixIn的类属性,为False
    if not t.daemon and self.block_on_close:  # 第一个值为False,第二个值为True。他们都是ThreadingMixIn的类属性
        if self._threads is None:  # 会执行
            self._threads = []  # 创建了空列表
        self._threads.append(t) # 将当前的子线程添加至空列表中
    t.start()  # 开始当前子线程的运行,即运行self.process_request_thread方法
ThreadingMixIn中的process_request()

 

  我们可以看到,这里的target参数中指定了一个方法self.process_request_thread,其实意思就是说当这个线程tstart的时候会去执行该方法。我们看一下它都做了什么,这个方法还是在ThreadingMixIn类中。

def process_request_thread(self, request, client_address):
    """Same as in BaseServer but as a thread.

    In addition, exception handling is done here.

    """
    try:
        self.finish_request(request, client_address) # 可以看到又执行该方法了,这里我再标注一下,别弄头晕了。request 双向链接通道,client_address客户端ip+port。
    except Exception:
        self.handle_error(request, client_address)
    finally:
        self.shutdown_request(request)  # 它不会关闭这个线程,而是将其设置为wait()状态。
ThreadingMixIn中的 process_request_thread()

 

  self.finish_request()方法,它在BaseServer类中

def finish_request(self, request, client_address):
    """Finish one request by instantiating RequestHandlerClass."""
    self.RequestHandlerClass(request, client_address, self)  # 这里是干嘛?其实就是在进行实例化!
BaseServer中的finish_request
 
 

  self.RequestHandlerClass(request, client_address, self),我们找到self__dict__字典,看看这个到底是什么东西

{'server_address': ('0.0.0.0', 6666), 'RequestHandlerClass': <class '__main__.MyServer'>, '_BaseServer__is_shut_down': , '_BaseServer__shutdown_request': False, 'socket': '0.0.0.0', 6666)>}
s1的__dict__

 

  可以看到,它就是我们传入的那个类,即自定义的MyServer类。我们把request,client_address,以及整个是实例self传给了MyServer的__init__方法。但是我们的MyServer类没有__init__,怎么办呢?去它父类BaseRequestHandler里面找呗。

class BaseRequestHandler:

    """注释被我删了"""

    def __init__(self, request, client_address, server):
        self.request = request  # request 双向链接通道
        self.client_address = client_address  # 客户端ip+port
        self.server = server # 即 实例对象本身。上面的__dict__就是它的__dict__
        self.setup() # 钩子函数,我们可以自己写一个类然后继承`BaseRequestHandler`并覆写其setup方法即可。
        try:
            self.handle()  # 看,自动执行handle
        finally:
            self.finish()  # 钩子函数

    def setup(self):
        pass

    def handle(self):
        pass

    def finish(self):
        pass
BaseRequestHandler中的__init__

 

 

  现在我们知道了,为什么一定要覆写handle方法了吧。

 

socketserver内部调用顺序流程图(基于TCP协议)


 

实例化过程图解

socketserver模块使用与源码分析_第6张图片

 

server_forever()启动服务图解

socketserver模块使用与源码分析_第7张图片

 

扩展:验证链接合法性

  在很多时候,我们的TCP服务端为了防止网络泛洪可以设置一个三次握手验证机制。那么这个验证机制的实现其实也是非常简单的,我们的思路在于进入通信循环之前,客户端和服务端先走一次链接认证,只有通过认证的客户端才能够继续和服务端进行链接。

  下面就来看一下具体的实现步骤。

 

#_*_coding:utf-8_*_
__author__ = 'Linhaifeng'
from socket import *
import hmac,os

secret_key=b'linhaifeng bang bang bang'
def conn_auth(conn):
    '''
    认证客户端链接
    :param conn:
    :return:
    '''
    print('开始验证新链接的合法性')
    msg=os.urandom(32)  # 新方法,生成32位随机Bytes类型的值
    conn.sendall(msg)
    h=hmac.new(secret_key,msg)
    digest=h.digest()
    respone=conn.recv(len(digest))
    return hmac.compare_digest(respone,digest) # 对比结果为True或者为False

def data_handler(conn,bufsize=1024):
    if not conn_auth(conn):
        print('该链接不合法,关闭')
        conn.close()
        return
    print('链接合法,开始通信')
    while True:
        data=conn.recv(bufsize)
        if not data:break
        conn.sendall(data.upper())

def server_handler(ip_port,bufsize,backlog=5):
    '''
    只处理链接
    :param ip_port:
    :return:
    '''
    tcp_socket_server=socket(AF_INET,SOCK_STREAM)
    tcp_socket_server.bind(ip_port)
    tcp_socket_server.listen(backlog)
    while True:
        conn,addr=tcp_socket_server.accept()
        print('新连接[%s:%s]' %(addr[0],addr[1]))
        data_handler(conn,bufsize)

if __name__ == '__main__':
    ip_port=('127.0.0.1',9999)
    bufsize=1024
    server_handler(ip_port,bufsize)
Server端
#_*_coding:utf-8_*_
__author__ = 'Linhaifeng'
from socket import *
import hmac,os

secret_key=b'linhaifeng bang bang bang'
def conn_auth(conn):
    '''
    验证客户端到服务器的链接
    :param conn:
    :return:
    '''
    msg=conn.recv(32) # 拿到随机位数
    h=hmac.new(secret_key,msg) # 掺盐
    digest=h.digest()
    conn.sendall(digest)

def client_handler(ip_port,bufsize=1024):
    tcp_socket_client=socket(AF_INET,SOCK_STREAM)
    tcp_socket_client.connect(ip_port)

    conn_auth(tcp_socket_client)

    while True:
        data=input('>>: ').strip()
        if not data:continue
        if data == 'quit':break

        tcp_socket_client.sendall(data.encode('utf-8'))
        respone=tcp_socket_client.recv(bufsize)
        print(respone.decode('utf-8'))
    tcp_socket_client.close()

if __name__ == '__main__':
    ip_port=('127.0.0.1',9999)
    bufsize=1024
    client_handler(ip_port,bufsize)
Client端

 

  到这里已经很简单了,服务器将随机数给客户机发过去,客户机收到后也用自家的盐与随机数加料,再使用digest()将它转化为字节,直接发送了回来然后客户端通过hmac.compare_digest()方法验证两个的值是否相等,如果不等就说明盐不对。客户机不合法服务端将会关闭与该客户机的双向链接通道。

你可能感兴趣的:(socketserver模块使用与源码分析)