在讲解 socketserver 模块之前先补充一下上一章节的一个示例:
实现客户端从服务端下载文件的功能，并支持 hash 校验（已在 Windows 和 Linux 上测试通过，代码较为简陋，仅供参考）。
# coding=utf-8
# ---------------------------------------------------------------------------
# Server: send a requested file to the client, prefixed by a JSON header.
#
# Wire protocol for one request:
#   client -> server: the file path, UTF-8 encoded, in one send (<= 1024 bytes)
#   server -> client: b"4044" if the path is not a regular file; otherwise a
#                     4-byte big-endian header length, the JSON header
#                     {"file_name", "file_size", "hash"}, then exactly
#                     file_size raw bytes of file content.
# "hash" is the hex MD5 digest of the file's CONTENT, so the client can detect
# a corrupted or truncated transfer.
# NOTE(review): the server will serve ANY path the client names (no sandbox);
# only run it on a trusted network.
# ---------------------------------------------------------------------------
from socket import *
import json
import struct
import os
import hashlib


def _file_md5(path, chunk_size=64 * 1024):
    """Return the hex MD5 digest of the file at *path*, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


server = socket(AF_INET, SOCK_STREAM)
# server.bind(("192.168.12.222",8090))
server.bind(("127.0.0.1", 8090))
server.listen(5)
while True:  # connection loop: serve one client at a time
    print("connection...")
    conn, addr = server.accept()
    print(f"from {addr} conn")
    with conn:  # the client socket is closed whenever the session ends
        while True:  # communication loop: one file request per iteration
            try:
                request = conn.recv(1024)
                if not request:  # b"" means the client disconnected
                    break
                file_path = os.path.normpath(request.decode("utf-8"))
                if not os.path.isfile(file_path):
                    conn.sendall(b"4044")  # "not found" marker, 4 bytes
                    continue
                header_dic = {
                    "file_name": os.path.basename(file_path),
                    "file_size": os.path.getsize(file_path),
                    "hash": _file_md5(file_path),
                }
                header_bytes = json.dumps(header_dic).encode("utf-8")
                # Fixed-size, explicit-byte-order length prefix so the client
                # knows where the header ends (TCP has no message boundaries).
                conn.sendall(struct.pack("!I", len(header_bytes)))
                conn.sendall(header_bytes)
                with open(file_path, "rb") as f:
                    for chunk in iter(lambda: f.read(64 * 1024), b""):
                        conn.sendall(chunk)  # send() may write only part
            except (OSError, UnicodeDecodeError):
                # Connection errors (reset, broken pipe), unreadable files and
                # undecodable paths end this client's session.
                break
# coding=utf-8
# ---------------------------------------------------------------------------
# Client: ask the server for a file, save it, and verify its MD5 digest.
# Speaks the protocol described at the top of the server script.
# ---------------------------------------------------------------------------
from socket import *
import json
import struct
import os
import hashlib


def progress(percent, symbol='█', width=40):
    """Redraw a one-line progress bar in place.

    Args:
        percent: fraction completed; values above 1 are clamped to 1 so the
            bar never overflows.
        symbol: character used to fill the bar.
        width: bar width in characters.
    """
    percent = min(percent, 1)
    show_progress = ("▌%%-%ds▌" % width) % (int(percent * width) * symbol)
    print("\r%s %.2f%%" % (show_progress, percent * 100), end='')


def _recv_exact(sock, size):
    """Read exactly *size* bytes from *sock*.

    Raises:
        ConnectionError: if the peer closes the connection first.
    """
    chunks = []
    remaining = size
    while remaining:
        chunk = sock.recv(min(remaining, 64 * 1024))
        if not chunk:
            raise ConnectionError("server closed the connection")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


with socket(AF_INET, SOCK_STREAM) as client:  # closed however the loop ends
    # client.connect(("192.168.12.222",8090))
    client.connect(("127.0.0.1", 8090))
    while True:
        file_path = input("Please enter the file path(q/exit)>>").strip()
        if file_path.lower() == "q":
            break
        if not file_path:
            continue
        to_path = input("Please enter the save directory(q/back)>>").strip()
        if to_path.lower() == "q":
            continue
        if not os.path.isdir(to_path):
            print("not find")
            continue
        file_name = input("Please enter filename(q/back)>>").strip()
        if file_name.lower() == "q":
            continue
        if not file_name:  # os.path.join(dir, "") would be the directory itself
            continue
        goal_path = os.path.join(to_path, file_name)

        client.sendall(file_path.encode("utf-8"))
        bytes_4 = _recv_exact(client, 4)
        # Compare raw bytes: a binary length prefix need not be valid UTF-8.
        if bytes_4 == b"4044":
            print("not find")
            continue
        header_bytes_len = struct.unpack("!I", bytes_4)[0]
        header_dic = json.loads(_recv_exact(client, header_bytes_len).decode("utf-8"))
        data_len = header_dic["file_size"]
        hash_md5 = header_dic["hash"]

        m = hashlib.md5()  # digest of the bytes actually received
        recv_len = 0
        with open(goal_path, "wb") as f:
            # Loop condition (not `while 1`) so an empty file needs no recv().
            while recv_len < data_len:
                data = client.recv(min(64 * 1024, data_len - recv_len))
                if not data:
                    raise ConnectionError("server closed the connection mid-transfer")
                f.write(data)
                m.update(data)
                recv_len += len(data)
                progress(recv_len / data_len, width=40)  # bar width 40
        if hash_md5 == m.hexdigest():  # content hash check
            print("\nHash auth succeed\nFile saved...")
        else:
            os.remove(goal_path)  # discard the corrupted download
            print("Hash auth failed!!")
基于 TCP 的套接字编程，关键在于两个循环：一个是连接循环，另一个是通信循环，二者分别负责两件不同的事。
socketserver 模块中有两大类：一类是 server 类，专门负责建立连接；另一类是 request 类（请求处理类，例如下面用到的 BaseRequestHandler），专门负责通信。
这里只是简单地使用 socketserver 模块来实现并发效果，后面的章节再深入研究。
import socketserver
class MyRequestHandler(socketserver.BaseRequestHandler):  # must subclass this to plug into socketserver
    """Per-connection echo handler: reply to every message in upper case.

    For a TCP server, ``self.request`` is the connected client socket.
    """

    def handle(self):
        """Communication loop for one client; returns when the client leaves."""
        while True:
            try:
                data = self.request.recv(1024)
                if not data:  # b"" means the client closed its side
                    break
                # sendall(): a plain send() may transmit only part of the data.
                self.request.sendall(data.upper())
            except ConnectionError:  # reset / aborted / broken pipe
                break
        # socketserver also closes the request after handle() returns; closing
        # here too is harmless and keeps the handler self-contained.
        self.request.close()
# Bind the IP and port and start listening ("bind_and_activate" defaults to
# True). ThreadingTCPServer hands each accepted connection to a new thread.
# The with-statement closes the listening socket when serve_forever() exits.
with socketserver.ThreadingTCPServer(("127.0.0.1", 8089), MyRequestHandler, bind_and_activate=True) as s:
    s.serve_forever()  # accept connections; each one is handled by MyRequestHandler.handle
from socket import *
# Interactive TCP echo client: send each input line, print the server's reply.
# The with-statement closes the socket however the loop ends (server hang-up,
# Ctrl+C, end of input).
with socket(AF_INET, SOCK_STREAM) as client:
    client.connect(("127.0.0.1", 8089))
    while True:
        msg = input(">>").strip()
        if not msg:
            # Sending b"" over TCP transmits nothing, so recv() would block.
            continue
        client.sendall(msg.encode("utf-8"))  # send() may write only part
        data = client.recv(1024)
        if not data:  # b"" means the server closed the connection
            break
        print(data.decode("utf-8"))
import socketserver
class MyRequestHandle(socketserver.BaseRequestHandler):
    """UDP echo handler: send each datagram back to its sender in upper case.

    For datagram servers, ``self.request`` is a ``(data, socket)`` pair, and
    ``handle()`` is invoked once per received datagram. It must therefore
    process that single datagram and return, not loop over the same request.
    """

    def handle(self):
        data, sock = self.request
        print(f"来自[{self.client_address}]的信息 : {data.decode('utf-8')}")
        sock.sendto(data.upper(), self.client_address)
# Serve UDP echo requests; ThreadingUDPServer runs MyRequestHandle.handle in a
# new thread for each received datagram. The with-statement closes the server
# socket when serve_forever() exits.
with socketserver.ThreadingUDPServer(("127.0.0.1", 8080), MyRequestHandle) as s:
    s.serve_forever()
from socket import *
# Interactive UDP client: each input line goes out as one datagram, and the
# single reply datagram (up to 1024 bytes) is printed.
server_addr = ("127.0.0.1", 8080)
client = socket(AF_INET, SOCK_DGRAM)
while True:
    line = input(">>").strip()
    payload = line.encode("utf-8")
    client.sendto(payload, server_addr)
    reply, _sender = client.recvfrom(1024)
    text = reply.decode('utf-8')
    print(f"来自服务端的消息 : {text}")