实现基于 paramiko 的 SSH/SFTP 文件传输类封装。
#!/usr/bin/env python
# coding: utf-8
import hashlib
import os
import posixpath
import shlex
import time
import warnings
from typing import List, Text

import paramiko
from easydict import EasyDict
# Module-level alias for os.path.
path = os.path
# Ignore every warning category for the whole process.
# NOTE(review): presumably meant to hide deprecation noise from paramiko's
# crypto dependencies; consider narrowing the category/module instead.
warnings.simplefilter("ignore")
SSHConn 基础类:对 paramiko 原生上传、下载及远程命令执行逻辑的封装:
class SSHConn(object):
    """Password-authenticated SSH connection built on a paramiko ``Transport``.

    One transport is shared by SFTP uploads/downloads and remote command
    execution. It is closed by :meth:`close`, when leaving a ``with`` block,
    or when the object is garbage-collected.
    """

    def __init__(self, ip: str, port: int, username: str, password: str):
        """Connect to ``ip:port`` and authenticate as ``username`` with ``password``."""
        self._transport = paramiko.Transport((ip, port))
        try:
            self._transport.connect(username=username, password=password)
        except Exception:
            # Do not leave a half-open transport behind on handshake/auth failure.
            self._transport.close()
            raise

    def upload(self, src: str, dst: str):
        """Upload local file ``src`` to remote path ``dst``.

        If the first attempt fails with an I/O error (e.g. the remote parent
        directory does not exist), that directory is created with
        ``mkdir -p`` and the upload is retried once.
        """
        # The SFTP session is closed on exit instead of leaking one per call.
        with paramiko.SFTPClient.from_transport(self._transport) as sftp:
            try:
                sftp.put(src, dst, self.cal_size)
            except IOError:
                # Remote paths are POSIX regardless of the local OS.
                remote_dir = posixpath.dirname(dst)
                if not remote_dir:
                    raise  # nothing to create; surface the original error
                # Quote so spaces / shell metacharacters in the path are safe.
                self.cmd("mkdir -p %s" % shlex.quote(remote_dir))
                sftp.put(src, dst, self.cal_size)

    def download(self, src: str, dst: str):
        """Download remote file ``src`` to local path ``dst``.

        The local parent directory is created first if it does not exist, so
        a genuine transfer error is never masked by a directory-creation error.
        """
        local_dir = os.path.dirname(dst)
        if local_dir:
            os.makedirs(local_dir, exist_ok=True)
        with paramiko.SFTPClient.from_transport(self._transport) as sftp:
            sftp.get(src, dst, self.cal_size)

    def cmd(self, command: str):
        """Run ``command`` on the remote host and return its stdout as text.

        stderr is not returned. If the command cannot be started, the error is
        printed and ``""`` is returned.
        """
        channel = None
        try:
            # Public API equivalent of SSHClient.exec_command, without
            # reaching into SSHClient's private _transport attribute.
            channel = self._transport.open_session()
            channel.exec_command(command)
        except Exception as e:
            if channel is not None:
                channel.close()
            print("execute command %s error, error message is %s" % (command, e))
            return ""
        with channel:
            # Reads until EOF, i.e. until the remote side closes stdout.
            output = channel.makefile("rb").read()
        return output.decode("utf-8", errors="replace")

    @staticmethod
    def cal_size(*args, **kwargs):
        """No-op progress callback handed to ``SFTPClient.put`` / ``get``.

        paramiko calls it with (bytes transferred so far, total bytes).
        """
        pass

    def close(self):
        """Close the SSH transport; safe to call more than once."""
        # getattr: __init__ may have failed before _transport was assigned.
        transport = getattr(self, "_transport", None)
        if transport is not None:
            transport.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # Last-resort cleanup when the instance is garbage-collected.
        self.close()
SFTPTransfer 类:文件夹上传与文件同步功能的封装:
class SFTPTransfer(SSHConn):
    """Folder upload and polling-based one-way file sync built on SSHConn.

    Exclusion patterns set via ``exclude`` apply to both ``uploads`` and
    ``sync``. They are compared against each file's path relative to the
    source root, using ``/`` separators:

    * ``"name"`` excludes the file or directory ``name`` and everything below it;
    * ``"*.ext"`` excludes every file whose path ends with ``.ext``.
    """

    def __init__(self, ip: str, port: int, username: str, password: str):
        super(SFTPTransfer, self).__init__(ip, port, username, password)
        # Normalized exclusion patterns (see class docstring).
        self.__excludes = []
        # Last seen size/mtime/md5 per local file, keyed by normalized path.
        self.__file_caches = {}
        # Memoized filter() decisions, keyed by (root_dir, file_path).
        self.__filter_caches = {}

    def exclude(self, excludes: List[Text]):
        """Replace the exclusion patterns.

        Backslashes become ``/`` and trailing slashes are dropped, so
        ``"logs"``, ``"logs/"`` and ``"logs\\"`` are equivalent.
        """
        self.__excludes = [v.replace("\\", "/").rstrip("/") for v in excludes]
        # Decisions cached under the previous pattern list are now stale.
        self.__filter_caches.clear()

    def filter(self, root_dir, file_path):
        """Return True if ``file_path`` (located under ``root_dir``) should be transferred."""
        cache_key = (root_dir, file_path)
        if cache_key in self.__filter_caches:
            return self.__filter_caches[cache_key]
        # relpath copes with a trailing separator on root_dir, unlike
        # slicing off len(root_dir) + 1 characters.
        sub_path = os.path.relpath(file_path, root_dir).replace("\\", "/")
        keep = not any(self._matches(sub_path, ex) for ex in self.__excludes)
        self.__filter_caches[cache_key] = keep
        return keep

    @staticmethod
    def _matches(sub_path, pattern):
        """Return True if the relative, ``/``-separated ``sub_path`` matches ``pattern``."""
        if pattern.startswith("*."):
            return sub_path.endswith(pattern[1:])
        # Compare whole path components so "logs" does not exclude "logsetup.py".
        return sub_path == pattern or sub_path.startswith(pattern + "/")

    @staticmethod
    def _remote_path(src_root, src_path, dst):
        """Map local ``src_path`` under ``src_root`` to the POSIX path under remote ``dst``."""
        rel = os.path.relpath(src_path, src_root).replace("\\", "/")
        return posixpath.join(dst.replace("\\", "/"), rel)

    def uploads(self, src: str, dst: str):
        """Upload a file, or every non-excluded file of a directory tree, to ``dst``."""
        if not os.path.isdir(src):
            self.upload(src, dst)
            return
        for src_path in self.walk_file(src):
            if not self.filter(src, src_path):
                continue
            dst_path = self._remote_path(src, src_path, dst)
            print("upload:", dst_path)
            self.upload(src_path, dst_path)

    def sync(self, src: str, dst: str, interval: float = 1):
        """Poll ``src`` forever, uploading new or modified files to ``dst``.

        ``src`` may be a single file or a directory (excludes apply). The
        tree is rescanned every ``interval`` seconds. Local deletions are not
        propagated to the remote side.
        """
        while True:
            if not os.path.isdir(src):
                if self.is_modify(src):
                    self.upload(src, dst)
            else:
                trans_cnt = 0
                for src_path in self.walk_file(src):
                    if not self.filter(src, src_path):
                        continue
                    if not self._modified_if_present(src_path):
                        continue
                    dst_path = self._remote_path(src, src_path, dst)
                    print("upload:", dst_path)
                    self.upload(src_path, dst_path)
                    trans_cnt += 1
                if trans_cnt:
                    print("sync total file %d." % trans_cnt)
            time.sleep(interval)

    def _modified_if_present(self, file_path):
        """Like is_modify, but a file deleted between walk and stat counts as unchanged."""
        try:
            return self.is_modify(file_path)
        except FileNotFoundError:
            return False

    def is_modify(self, file_path):
        """Return True if ``file_path`` is new or its content changed since the last call.

        A cheap size/mtime comparison runs first; only when it differs is the
        MD5 recomputed, so touching a file without changing its content is
        not reported as a modification.
        """
        # The raw (normalized) path is collision-free, unlike replacing
        # "/", "\\" and "." with "_".
        key = os.path.normpath(file_path)
        stat = os.stat(file_path)
        info = self.__file_caches.get(key)
        if info is None:
            self.__file_caches[key] = EasyDict(st_size=stat.st_size,    # file size
                                               st_mtime=stat.st_mtime,  # last modification time
                                               md5=self.md5(file_path))
            return True
        if stat.st_mtime == info.st_mtime and stat.st_size == info.st_size:
            return False
        digest = self.md5(file_path)
        modified = digest != info.md5
        # Always refresh the snapshot: storing the NEW digest is what lets a
        # later revert be detected, and refreshing mtime avoids re-hashing an
        # unchanged file on every poll.
        info.st_mtime = stat.st_mtime
        info.st_size = stat.st_size
        info.md5 = digest
        return modified

    @staticmethod
    def walk_file(dir_path: str) -> list:
        """Return the paths of all regular files below ``dir_path`` (recursive)."""
        file_paths = []
        for dirpath, _dirnames, filenames in os.walk(dir_path):
            for filename in filenames:
                file_paths.append(os.path.join(dirpath, filename))
        return file_paths

    @staticmethod
    def md5(file_path, bytes=1024):
        """Return the hex MD5 digest of ``file_path``, reading ``bytes`` bytes at a time."""
        digest = hashlib.md5()
        with open(file_path, "rb") as f:  # binary mode: hash raw bytes
            for chunk in iter(lambda: f.read(bytes), b""):
                digest.update(chunk)
        return digest.hexdigest()
测试代码:
if __name__ == "__main__":
    # Connection settings come from the environment so no credentials are
    # hard-coded (the original `SFTPTransfer(ip, port, user, pass)` was a
    # SyntaxError: `pass` is a keyword and the other names were undefined).
    host = os.environ["SFTP_HOST"]
    port = int(os.environ.get("SFTP_PORT", "22"))
    user = os.environ["SFTP_USER"]
    password = os.environ["SFTP_PASSWORD"]
    sftp = SFTPTransfer(host, port, user, password)
    sftp.exclude(["datasets", "logs", "weights", ".idea", "*.pyc"])
    try:
        # sync() polls forever; stop it with Ctrl+C.
        sftp.sync(r"H:\EfficientDet", "/home/EfficientDet")
    except KeyboardInterrupt:
        print("sync stopped.")