SQLAlchemy 使用封装实例

[图 1：SQLAlchemy 使用封装实例]

类封装

database.py

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import json
import logging
from datetime import datetime

from core.utils import classlock, parse_bool
from core.config import (
    MYSQL_HOST,
    MYSQL_PORT,
    MYSQL_USER,
    MYSQL_PASS,
    MYSQL_DATABASE,
    MYSQL_TIMEOUT
)

from sqlalchemy import create_engine, Column, desc, not_, func
from sqlalchemy import Integer, String, Boolean, DateTime, Text, Enum     # Text存储大不固定长的字符串
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.exc import SQLAlchemyError

# Schema version this code expects; compared against the row stored in the
# ``version`` table when the tables are created.
SCHEMA_VERSION = "1.0.0"

# Logger named "log" (the demo script obtains the same named logger).
log = logging.getLogger("log")

# Declarative base class that every ORM model in this module inherits from.
Base = declarative_base()


class User(Base):
    """ORM model mapped to the ``user`` table.

    Columns: an autoincrement-style integer primary key, a file size, several
    hash digests (md5, crc32, sha1, sha256, sha512, optional ssdeep), a boolean
    ``memory`` flag and a ``start_time`` timestamp.
    """
    __tablename__ = "user"

    id = Column(Integer(), primary_key=True)
    file_size = Column(Integer(), nullable=False)      # nullable=False -> NOT NULL column
    md5 = Column(String(32), nullable=False)
    crc32 = Column(String(8), nullable=False)
    sha1 = Column(String(40), nullable=False)
    sha256 = Column(String(64), nullable=False)
    sha512 = Column(String(128), nullable=False)
    memory = Column(Boolean, nullable=False, default=False)
    ssdeep = Column(String(255), nullable=True)
    # Python-side default: datetime.now is called at INSERT time (naive datetime).
    start_time = Column(DateTime(timezone=False), nullable=True, default=datetime.now)

    def __repr__(self):
        """Return a debug representation containing the id and sha256."""
        # The previous format string was empty, so every User repr'd as "".
        return "<User('{0}','{1}')>".format(self.id, self.sha256)

    def to_dict(self):
        """Convert the row to a dict.

        @return: dict mapping each column name to the attribute's current value.
        """
        return {column.name: getattr(self, column.name) for column in self.__table__.columns}

    def to_json(self):
        """Convert the row to a JSON string.

        ``datetime`` values (such as ``start_time``) are emitted as ISO-8601
        strings, because ``json.dumps`` cannot encode ``datetime`` objects and
        previously raised TypeError for any row with a start_time.
        @return: JSON data
        """
        data = {
            key: value.isoformat() if isinstance(value, datetime) else value
            for key, value in self.to_dict().items()
        }
        return json.dumps(data)

class Version(Base):
    """Table that records which schema version the database was created with.

    Database._create_tables writes a single row here when the table is empty
    and later compares it against SCHEMA_VERSION.
    """

    __tablename__ = "version"

    # The version string itself is the primary key.
    version_num = Column(String(32), primary_key=True, nullable=False)

class Database(object):
    """Analysis-queue database.

    Creates the tables for this module's models and provides helpers to add,
    query, update and delete ``User`` rows.
    """

    def __init__(self, schema_check=True, echo=False):
        """
        @param schema_check: enable or disable the database schema version check.
        @param echo: echo SQL queries.
        """
        self._lock = None  # NOTE(review): presumably used by @classlock — confirm
        self.schema_check = schema_check
        self.echo = echo

    def connect(self, schema_check=None, dsn=None, create=True):
        """Connect to the database backend.

        @param schema_check: when not None, overrides the value given to __init__.
        @param dsn: SQLAlchemy connection string; defaults to a MySQL DSN built
                    from core.config.
        @param create: create the tables and check the schema version.
        """
        if schema_check is not None:
            self.schema_check = schema_check

        if not dsn:
            # NOTE(review): credentials are interpolated unescaped; a password
            # containing '@', ':' or '/' would break this URL.
            dsn = "mysql://{0}:{1}@{2}:{3}/{4}".format(MYSQL_USER, MYSQL_PASS, MYSQL_HOST, MYSQL_PORT, MYSQL_DATABASE)
            # dsn = "mysql://{username}:{password}@{hostname}:{port}/{database}"

        self._connect_database(dsn)

        # SQL logging is off by default; pass echo=True to __init__ for debugging.
        self.engine.echo = self.echo

        # Connection timeout.
        # NOTE(review): pool_timeout is normally passed to create_engine();
        # setting it on an existing Engine does not appear to reconfigure the
        # pool. Kept for compatibility — confirm the intended effect.
        self.engine.pool_timeout = MYSQL_TIMEOUT

        # Session factory bound to this engine.
        self.Session = sessionmaker(bind=self.engine)

        if create:
            self._create_tables()

    def _create_tables(self):
        """Create all tables and record or verify the schema version.

        @raise RuntimeError: if the tables cannot be created or the schema
                             version row cannot be committed.
        Calls sys.exit(1) when the stored schema version differs from
        SCHEMA_VERSION and schema checking is enabled.
        """
        try:
            Base.metadata.create_all(self.engine)
        except SQLAlchemyError as e:
            # Previously a plain str was raised, which itself fails with TypeError.
            raise RuntimeError("无法创建或连接到数据库: %s" % e)

        # Schema versioning. Always close the session, even if a query raises.
        tmp_session = self.Session()
        try:
            if not tmp_session.query(Version).count():
                # Empty version table: record the current schema version.
                tmp_session.add(Version(version_num=SCHEMA_VERSION))
                try:
                    tmp_session.commit()
                except SQLAlchemyError as e:
                    # Roll back before raising (this rollback used to be unreachable).
                    tmp_session.rollback()
                    raise RuntimeError("无法设置架构版本: %s" % e)
                return
            last = tmp_session.query(Version).first()
        finally:
            tmp_session.close()

        # Check that the stored version is the expected one.
        if last.version_num != SCHEMA_VERSION and self.schema_check:
            log.warning(
                "数据库架构版本不匹配:找到 %s,应为 %s.",
                last.version_num, SCHEMA_VERSION
            )
            log.error(
                "(可选)进行备份,然后通过运行migrate应用最新的数据库迁移。"
            )
            sys.exit(1)

    def __del__(self):
        """Dispose of the connection pool, if an engine was ever created."""
        # __init__ does not create an engine; connect() may never have run
        # (or may have failed), so don't raise AttributeError during teardown.
        engine = getattr(self, "engine", None)
        if engine is not None:
            engine.dispose()

    def _connect_database(self, connection_string):
        """Create ``self.engine`` for the given connection string.

        Exits the process (sys.exit(-1)) if the database driver cannot be imported.
        @param connection_string: SQLAlchemy connection string for the database.
        """
        try:
            if connection_string.startswith("sqlite"):
                # check_same_thread=False disables sqlite's same-thread check so
                # the engine can be used from several threads.
                self.engine = create_engine(connection_string, connect_args={"check_same_thread": False})
            elif connection_string.startswith("postgres"):
                # Disable SSL mode to avoid some errors when using sqlalchemy with multiprocessing.
                # See: http://www.postgresql.org/docs/9.0/static/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
                # TODO: check whether this is still relevant, especially if multiprocessing is no longer used.
                self.engine = create_engine(connection_string, connect_args={"sslmode": "disable"})
            else:
                self.engine = create_engine(connection_string)
        except ImportError as e:
            # The missing module name is the last word of the message,
            # e.g. "No module named 'MySQLdb'".
            words = str(e).split()
            lib = words[-1].strip("'") if words else ""

            # Exactly one message per failure; previously the "unknown driver"
            # message was also logged for the known MySQL/PostgreSQL drivers.
            if lib == "MySQLdb":
                log.error(
                    "缺少MySQL数据库驱动程序(在Linux上使用 `pip install mysql-python` 安装,或在Windows上使用 `pip install mysqlclient`)"
                )
            elif lib == "psycopg2":
                log.error(
                    "缺少PostgreSQL数据库驱动程序 (使用 `pip install psycopg2`)"
                )
            else:
                log.error("缺少未知的数据库驱动程序,无法导入 %s", lib)
            sys.exit(-1)

    @classlock
    def add_user(self, file_size, md5, crc32, sha1, sha256, sha512, memory, ssdeep=None):
        """Insert a new User row.

        @param memory: flag passed through parse_bool; values it rejects with
                       ValueError are stored as False.
        @return: True on success, False if the commit failed.
        """
        try:
            memory = parse_bool(memory)
        except ValueError:
            memory = False

        user = User(
            file_size=file_size,
            md5=md5,
            crc32=crc32,
            sha1=sha1,
            sha256=sha256,
            sha512=sha512,
            memory=memory,
            ssdeep=ssdeep,
        )

        session = self.Session()
        try:
            session.add(user)
            session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            log.error("数据库添加 user 错误: {0}".format(e))
            return False
        finally:
            session.close()
        return True

    @classlock
    def select_user(self, id=None):
        """Return User rows, optionally filtered by primary key.

        @param id: when not None, only the row with this id is returned
                   (previously id=0 was treated as "no filter").
        @return: list of User objects (detached once the session is closed);
                 [] on database error.
        """
        session = self.Session()
        try:
            search = session.query(User)
            if id is not None:
                search = search.filter_by(id=id)

            # Ordering examples:
            # search = search.order_by(User.id)
            # search = search.order_by(desc(User.id))  # descending

            return search.all()
        except SQLAlchemyError as e:
            log.error("数据库查看所有 user 错误: {0}".format(e))
            return []
        finally:
            session.close()

    @classlock
    def update_user(self, id, new_file_size):
        """Set ``file_size`` on the User with the given id.

        @return: True on success; False if no such row exists or the commit failed.
        """
        session = self.Session()
        try:
            user = session.query(User).filter(User.id == id).first()
            if user is None:
                # Previously this raised an uncaught AttributeError on None.
                return False
            user.file_size = new_file_size
            # Bulk alternative:
            # session.query(User).filter(User.id == id).update({"file_size": new_file_size})
            session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            log.error("数据库更新 user 错误: {0}".format(e))
            return False
        finally:
            session.close()
        return True

    @classlock
    def delete_user(self, id):
        """Delete the User with the given id, if it exists.

        @return: True if the row was deleted or did not exist; False on database error.
        """
        session = self.Session()
        try:
            user = session.query(User).filter(User.id == id).first()
            if user:
                session.delete(user)
                session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            log.error("数据库删除 user 错误: {0}".format(e))
            return False
        finally:
            session.close()
        return True

调用运行

merge.py

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import time
import logging

from core.database import Database

log = logging.getLogger("log")


class Merge(object):
    """Demo driver that connects to the database and calls the CRUD helpers.

    All of the work happens in the constructor.
    """

    def __init__(self):
        database = Database()
        database.connect()

        # Insert one sample row; the string "off" is passed as the memory flag.
        if not database.add_user(32, "11", "22", "33", "44", "55", "off"):
            print("添加错误")

        # Print every row currently in the table.
        print(database.select_user())

        # More examples (disabled):
        # if not database.update_user(1, 50):
        #     print("更新错误")
        #
        # if not database.delete_user(2):
        #     print("删除错误")


if __name__ == "__main__":
    # Executing this file directly runs the demo in Merge.__init__.
    merge = Merge()

你可能感兴趣的:(Sql,python,sql)