基于ssh实现文件自动上传

实现ssh上传文件类封装。

#!/usr/bin/env python
# coding: utf-8

import os
import time
import hashlib
import warnings
import paramiko
from easydict import EasyDict
from typing import List, Text

path = os.path
warnings.filterwarnings("ignore")

SSHConn 基础类:封装 paramiko 原生的连接、命令执行与文件上传下载逻辑:

class SSHConn(object):
    """Thin wrapper around a paramiko Transport: upload, download, remote exec.

    The transport is opened eagerly in __init__ and closed in __del__.
    """

    def __init__(self, ip: str, port: int, username: str, password: str):
        # Password auth only; paramiko raises on connection/auth failure.
        self._transport = paramiko.Transport((ip, port))
        self._transport.connect(username=username, password=password)

    def upload(self, src: str, dst: str):
        """Upload local file `src` to remote path `dst`.

        If the first put fails (typically because the remote parent directory
        does not exist), create it with `mkdir -p` and retry once.
        """
        sftp = paramiko.SFTPClient.from_transport(self._transport)
        try:
            try:
                sftp.put(src, dst, self.cal_size)
            except OSError:
                # Remote parent directory probably missing: create it, retry.
                self.cmd("mkdir -p %s" % os.path.split(dst)[0])
                sftp.put(src, dst, self.cal_size)
        finally:
            # BUGFIX: the SFTP channel used to leak on every call.
            sftp.close()

    def download(self, src: str, dst: str):
        """Download remote file `src` to local path `dst`.

        If the first get fails (typically because the local parent directory
        does not exist), create it and retry once.
        """
        sftp = paramiko.SFTPClient.from_transport(self._transport)
        try:
            try:
                sftp.get(src, dst, self.cal_size)
            except OSError:
                # BUGFIX: exist_ok so a pre-existing directory does not turn
                # an unrelated transfer error into a FileExistsError.
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                sftp.get(src, dst, self.cal_size)
        finally:
            sftp.close()

    def cmd(self, command: str):
        """Run `command` on the remote host; return its stdout decoded as
        UTF-8, or "" if execution failed."""
        ssh = paramiko.SSHClient()
        # HACK: reuse the already-open transport instead of reconnecting;
        # relies on SSHClient's private `_transport` attribute — confirm it
        # still works when upgrading paramiko.
        ssh._transport = self._transport
        try:
            stdin, stdout, stderr = ssh.exec_command(command)
        except Exception as e:
            print("execute command %s error, error message is %s" % (command, e))
            return ""
        return stdout.read().decode('utf-8')

    @staticmethod
    def cal_size(*args, **kwargs):
        """Progress callback for sftp.put/get; intentionally a no-op."""
        pass

    def __del__(self):
        # BUGFIX: guard against __init__ having failed before `_transport`
        # was assigned, which previously raised AttributeError at GC time.
        transport = getattr(self, "_transport", None)
        if transport is not None:
            transport.close()
SFTPTransfer 类:封装文件夹整体上传与文件自动同步功能:

class SFTPTransfer(SSHConn):
    """Directory upload and one-way local-to-remote file sync on top of SSHConn."""

    def __init__(self, ip: str, port: int, username: str, password: str):
        super(SFTPTransfer, self).__init__(ip, port, username, password)
        self.__excludes = []       # normalized exclude patterns (prefix or "*.ext")
        self.__file_caches = {}    # cache key -> EasyDict(st_size, st_mtime, md5)
        self.__filter_caches = {}  # file path -> bool (passed the exclude filter)

    def exclude(self, excludes: List[Text]):
        """Set exclude patterns: path prefixes relative to the sync root,
        or "*.ext" suffix patterns. Backslashes are normalized to "/"."""
        self.__excludes = [v.replace("\\", "/") for v in excludes]

    def filter(self, root_dir, file_path):
        """Return True if `file_path` (under `root_dir`) should be transferred.

        Results are memoized per absolute path in `__filter_caches`.
        """
        if file_path in self.__filter_caches:
            return self.__filter_caches[file_path]

        ret = True
        sub_path = file_path[len(root_dir) + 1:].replace("\\", "/")
        for ex in self.__excludes:
            if sub_path.startswith(ex):
                ret = False
                break
            elif ex.startswith("*.") and sub_path.endswith(ex[1:]):
                # "*.pyc" style suffix pattern.
                ret = False
                break
        self.__filter_caches[file_path] = ret
        return ret

    def uploads(self, src: str, dst: str):
        """Upload a single file, or a whole directory tree rooted at `src`,
        to the remote path `dst` (directory structure is mirrored)."""
        if os.path.isdir(src):
            for src_path in self.walk_file(src):
                dst_path = path.join(dst, src_path[len(src) + 1:]).replace("\\", "/")
                print("upload:", dst_path)
                self.upload(src_path, dst_path)
        else:
            self.upload(src, dst)

    def sync(self, src: str, dst: str):
        """Poll `src` once per second forever, re-uploading files whose
        size/mtime/md5 changed. Never returns."""
        while True:
            if not os.path.isdir(src):
                if self.is_modify(src):
                    self.upload(src, dst)
            else:
                trans_cnt = 0
                file_paths = [f for f in self.walk_file(src) if self.filter(src, f)]
                for src_path in file_paths:
                    if self.is_modify(src_path):
                        dst_path = path.join(dst, src_path[len(src) + 1:]).replace("\\", "/")
                        print("upload:", dst_path)
                        self.upload(src_path, dst_path)
                        trans_cnt += 1
                if trans_cnt:
                    print("sync total file %d." % trans_cnt)

            time.sleep(1)

    def is_modify(self, file_path):
        """Return True if `file_path` changed since the previous call.

        Fast path compares cached size + mtime; only when those differ is the
        MD5 recomputed to distinguish real content changes from touch-only
        metadata changes.
        """
        key = file_path.replace("\\", "_").replace("/", "_").replace(".", "_")
        stat = os.stat(file_path)
        if key not in self.__file_caches:
            # First sighting: always report as modified.
            self.__file_caches[key] = EasyDict(st_size=stat.st_size,
                                               st_mtime=stat.st_mtime,
                                               md5=self.md5(file_path))
            return True

        info = self.__file_caches[key]
        if stat.st_mtime == info.st_mtime and stat.st_size == info.st_size:
            return False

        new_md5 = self.md5(file_path)
        # Refresh metadata on both outcomes, so a touched-but-unchanged file
        # is not re-hashed on every subsequent poll.
        info.st_mtime = stat.st_mtime
        info.st_size = stat.st_size
        if new_md5 == info.md5:
            return False

        # BUGFIX: the original wrote `info.md5 = info.md5` (a no-op), leaving
        # a stale digest in the cache and causing spurious re-uploads on any
        # later metadata-only change.
        info.md5 = new_md5
        return True

    @staticmethod
    def walk_file(dir_path: str) -> list:
        """Recursively collect every file path under `dir_path`."""
        file_paths = []
        for dirpath, dirnames, filenames in os.walk(dir_path):
            for filename in filenames:
                file_paths.append(os.path.join(dirpath, filename))
        return file_paths

    @staticmethod
    def md5(file_path, bytes=1024):
        """Return the hex MD5 digest of `file_path`, read in `bytes`-sized
        chunks. (The parameter shadows the builtin `bytes`; the name is kept
        for backward compatibility with keyword callers.)"""
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            while True:
                data = f.read(bytes)
                if not data:
                    break
                digest.update(data)
        return digest.hexdigest()

    def __del__(self):
        # Delegate transport cleanup to the base class.
        super(SFTPTransfer, self).__del__()

测试代码:

if __name__ == "__main__":
    # BUGFIX: the original called SFTPTransfer(ip, port, user, pass) —
    # `pass` is a Python keyword (SyntaxError) and ip/port/user were never
    # defined. Replace with placeholder credentials to be filled in.
    sftp = SFTPTransfer("192.168.1.100", 22, "username", "password")
    sftp.exclude(["datasets", "logs", "weights", ".idea", "*.pyc"])
    sftp.sync(r"H:\EfficientDet", "/home/EfficientDet")

 

你可能感兴趣的:(python包)