对 SQLAlchemy 进行改写（可直接拿到别的项目中使用）

  • model 层使用的是 SQLAlchemy
  • 通过继承重写 Query 查询类

第一部分：SQLAlchemy 与 Query 的改写

# 所谓的“改写”并不是直接修改 SQLAlchemy 源码，而是继承之后重写相应的方法
# 文件位置：flask_app/orm/base.py
from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery
from sqlalchemy import inspect, Column, Integer, SmallInteger, orm
from contextlib import contextmanager

from common.error import NotFoundError
from flask_app.orm import transfer

class SQLAlchemy(_SQLAlchemy):
    """Flask-SQLAlchemy subclass that adds an ``auto_commit`` context manager."""

    @contextmanager
    def auto_commit(self):
        """Commit the session if the ``with`` body succeeds; otherwise roll back.

        Usage::

            with db.auto_commit():
                obj.insert()

        :raises Exception: re-raises whatever the ``with`` body or ``commit``
            raised, after rolling the session back.
        """
        try:
            yield
            self.session.commit()
        except Exception:
            # Roll back *this* instance's session. Referencing the module-level
            # ``db`` global would roll back the wrong session (or fail) for any
            # other SQLAlchemy instance.
            self.session.rollback()
            # Bare ``raise`` re-raises the active exception unchanged.
            raise

class Query(BaseQuery):
    """Project-wide query class, installed via ``SQLAlchemy(query_class=Query)``.

    - ``filter_by`` implicitly restricts results to rows with ``status == 0``
      unless the caller passes ``status`` explicitly.
    - ``get_or_404`` / ``first_or_404`` raise the project's ``NotFoundError``
      when no row is found.
    """

    def filter_by(self, **kwargs):
        """Filter by keyword equality, defaulting ``status`` to 0.

        Every model is expected to have a ``status`` column; only rows with
        ``status == 0`` are returned unless ``status`` is given explicitly.
        """
        kwargs.setdefault('status', 0)
        return super(Query, self).filter_by(**kwargs)

    def get_or_404(self, ident, description=None):
        """Return the row whose primary key is ``ident``, or raise ``NotFoundError``.

        :param ident: primary-key value passed to ``Query.get``.
        :param description: optional error message; defaults to "数据不存在".
            Accepted for signature compatibility with ``BaseQuery.get_or_404``.
        :raises NotFoundError: when no row exists.

        NOTE(review): ``get`` does not go through ``filter_by``, so the implicit
        ``status == 0`` filter is not applied here. Confirm this is intended.
        """
        rv = self.get(ident)
        # Compare against None explicitly: a mapped instance that defines
        # __bool__/__len__ could be falsy even though the row exists.
        if rv is None:
            raise NotFoundError(msg=description or "数据不存在")
        return rv

    def first_or_404(self, description=None):
        """Return the first row of this query, or raise ``NotFoundError``.

        :param description: optional error message; defaults to "数据不存在".
            Accepted for signature compatibility with ``BaseQuery.first_or_404``.
        :raises NotFoundError: when the query yields no rows.
        """
        rv = self.first()
        if rv is None:
            raise NotFoundError(msg=description or "数据不存在")
        return rv

# Module-wide database handle. ``query_class=Query`` makes ``Model.query`` use
# the Query subclass defined in this module (status-defaulting filter_by and
# NotFoundError-raising *_or_404 helpers).
# NOTE(review): ``autocommit=True`` is a legacy SQLAlchemy 1.x session mode
# that was removed in 2.0. In that mode ``Session.commit()`` without a prior
# ``begin()`` is expected to raise "No transaction is begun", which would make
# ``db.auto_commit()`` fail. Confirm against the SQLAlchemy version in use.
db = SQLAlchemy(
    query_class=Query,
    session_options={'autocommit': True},
)

# 这样使用时，查不到数据会直接抛出 NotFoundError，调用方无需再手动判空

第二部分：Model 的改写

class BaseModel(db.Model):
    """Abstract base class for project models.

    Provides ``insert`` / ``update`` / ``delete`` / ``save`` helpers that flush
    the session and run overridable ``_before_*`` / ``_after_*`` hooks. It also
    provides dict/JSON conversion helpers that delegate to ``transfer``.
    """
    __abstract__ = True

    def insert(self):
        """Add this instance to the session and flush it.

        Before adding, fills empty ``ip`` / ``imei`` / ``mac`` columns (when the
        model has them) from the current Flask request.

        :return: self
        """
        self._before_insert()
        self.try_to_add_ip()
        self.try_to_add_device_info()
        db.session.add(self)
        db.session.flush()
        self._after_insert()
        return self

    def update(self):
        """Merge this instance's state into the session and flush it.

        NOTE(review): ``session.merge`` returns the session-bound copy, which is
        discarded here, so ``self`` may remain detached afterwards. Confirm
        that callers do not rely on ``self`` being attached.

        :return: self
        """
        self._before_update()
        db.session.merge(self)
        db.session.flush()
        self._after_update()
        return self

    def delete(self):
        """Mark this instance for deletion in the session and flush it."""
        self._before_delete()
        db.session.delete(self)
        db.session.flush()
        self._after_delete()

    # Lifecycle hooks: no-ops here, meant to be overridden by subclasses.

    def _before_insert(self):
        pass

    def _after_insert(self):
        pass

    def _before_update(self):
        pass

    def _after_update(self):
        pass

    def _before_delete(self):
        pass

    def _after_delete(self):
        pass

    @classmethod
    def load_all_data_field(cls):
        """Return the names of all columns mapped on this model's table.

        :return: list of column names. Returns an empty list when the class has
            no ``__table__`` (e.g. an abstract base), so callers can safely use
            ``name in cls.load_all_data_field()``.
        """
        if hasattr(cls, '__table__'):
            return [c.name for c in cls.__table__.columns]
        return []

    def try_to_add_ip(self):
        """Fill an empty ``ip`` column from the current Flask request.

        Uses the first address listed in ``X-Forwarded-For`` when the header is
        present, otherwise ``request.remote_addr``. Does nothing in any of these
        cases:

        - the model has no ``ip`` column;
        - the column already has a value;
        - there is no active request context (e.g. scripts or background jobs).

        NOTE: ``X-Forwarded-For`` can be set by the client unless a trusted
        proxy overwrites it, so treat the stored value as informational.
        """
        from flask import has_request_context, request

        ip_column = 'ip'
        if ip_column not in self.load_all_data_field():
            return
        if getattr(self, ip_column) or not has_request_context():
            return
        forwarded = request.headers.get('X-Forwarded-For', None)
        # The header may list several hops ("client, proxy1, proxy2"); keep
        # only the first (originating) entry rather than the whole list.
        client_ip = forwarded.split(',')[0].strip() if forwarded else None
        setattr(self, ip_column, client_ip or request.remote_addr)

    def try_to_add_device_info(self):
        """Fill empty ``imei`` / ``mac`` columns from the request query string.

        For each of those columns that the model has and that is currently
        empty, copy ``request.args[<column>]`` if it is present and non-empty.
        Does nothing outside a request context.
        """
        from flask import has_request_context, request

        # Compute the column list once instead of once per device field.
        fields = self.load_all_data_field()
        for info in ('imei', 'mac'):
            if info not in fields or getattr(self, info):
                continue
            if not has_request_context():
                return
            info_val = request.args.get(info, "")
            if info_val:
                setattr(self, info, info_val)

    def to_dict(self, without=(), include=()):
        """Convert this model to a dict via ``transfer.orm_obj2dict``.

        ``without`` / ``include`` are forwarded unchanged; presumably they name
        fields to drop or add. Confirm in ``transfer``.
        """
        return transfer.orm_obj2dict(self, without, include)

    def update_from_json(self, json_str):
        """Update this instance from a JSON string via ``transfer.json_up_orm_obj``."""
        return transfer.json_up_orm_obj(json_str, self)

    @classmethod
    def from_dict(cls, dic):
        """Build an instance of ``cls`` from a dict via ``transfer.dict2obj``."""
        return transfer.dict2obj(dic, cls)

    def __repr__(self):
        return self.to_json()

    def to_json(self, without=(), include=()):
        """Serialize this model to JSON via ``transfer.orm_obj2json``."""
        return transfer.orm_obj2json(self, without, include)

    def save(self):
        """Insert this instance if it has no id yet; otherwise update it.

        Checks ``id is not None`` instead of truthiness, so a row whose id is 0
        is updated rather than inserted again.

        :return: self (consistent with ``insert`` / ``update``)
        """
        if self.id is not None:
            return self.update()
        return self.insert()


class MixinJSONSerializer:
    """Mixin that lets a model instance be converted with ``dict(obj)``.

    ``dict(obj)`` calls ``keys()`` and then ``obj[key]`` for each key, so the
    instance converts to ``{field: value}`` for every name in ``_fields``.

    Subclasses customise the field list by overriding ``_set_fields`` and
    assigning ``self._fields`` (an explicit whitelist) and/or ``self._exclude``
    (names to drop from the full column list).
    """

    @orm.reconstructor
    def init_on_load(self):
        """Initialise the serializable field list.

        Registered with ``orm.reconstructor``, so SQLAlchemy calls it when an
        instance is loaded from the database. Instances created through the
        constructor are initialised lazily by ``keys`` / ``hide`` instead.
        """
        self._fields = []
        self._exclude = []

        self._set_fields()
        self.__prune_fields()

    def _set_fields(self):
        """Hook for subclasses to set ``self._fields`` and/or ``self._exclude``."""
        pass

    def __prune_fields(self):
        """Default ``_fields`` to all mapped columns minus ``_exclude``.

        Keeps the mapper's column order so the output field order is
        deterministic. A set difference would give an arbitrary order.
        """
        if not self._fields:
            excluded = set(self._exclude)
            columns = inspect(self.__class__).columns
            self._fields = [key for key in columns.keys() if key not in excluded]

    def _ensure_fields(self):
        """Initialise the field list if ``init_on_load`` has not run yet.

        The reconstructor only fires for instances loaded from the database;
        without this, a freshly constructed instance raises AttributeError.
        """
        if not hasattr(self, '_fields'):
            self.init_on_load()

    def hide(self, *args):
        """Remove the given field names from the serialized output.

        Names that are not currently serialized (e.g. already hidden) are
        ignored instead of raising ValueError. Returns ``self`` for chaining.
        """
        self._ensure_fields()
        for key in args:
            if key in self._fields:
                self._fields.remove(key)
        return self

    def keys(self):
        """Return the field names to serialize (used by ``dict(obj)``)."""
        self._ensure_fields()
        return self._fields

    def __getitem__(self, key):
        return getattr(self, key)
  • BaseModel 与 MixinJSONSerializer 都可以被具体的 model 类继承使用（MixinJSONSerializer 作为混入类）

你可能感兴趣的:(对SQLAlchemy进行改写。可直接拿到别的项目中使用)