一旦决定使用异步,则系统每一层都必须是异步,“开弓没有回头箭”。
创建一个全局的连接池,每个HTTP请求都可以从连接池中直接获取数据库连接。使用连接池的好处是不必频繁地打开和关闭数据库连接,而是能复用就尽量复用。
连接池由全局变量__pool存储,缺省情况下将编码设置为utf8,自动提交事务
要执行INSERT、UPDATE、DELETE语句,可以定义一个通用的execute()函数
设计ORM需要从上层调用者角度来设计。我们先考虑如何定义一个User对象,然后把数据库表users和它关联起来。
from orm import Model, StringField, IntegerField
class User(Model):
__table__ = 'users'
id = IntegerField(primary_key=True)
name = StringField()
廖雪峰Python-ORM
廖雪峰Python-SQLite
MySQL
aiomysql
深刻理解Python中的元类(metaclass)
lThings to Know About Python Super
!/usr/bin/env python3
-*- coding: utf-8 -*-
__author__ = 'Summous'
import asyncio, logging
import aiomysql
async def create_pool(loop, **kw):
'''
创建数据库链接池
:param loop:事件循环处理程序
:param kw:数据库配置参数集合
:return:无
缺省情况下将编码设置为utf8,自动提交事务
'''
logging.info('创建数据库链接池...')
# 创建全局变量
global __pool
# 初始化链接池参数
__pool = await aiomysql.create_pool(
host=kw.get('host', 'localhost'),
port=kw.get('port', 3306),
user=kw['user'],
password=kw['password'],
db=kw['db'],
charset=kw.get('charset', 'utf8'),
autocommit=kw.get('autocommit', True),
maxsize=kw.get('maxsize', 10),
minsize=kw.get('minsize', 1),
loop=loop
)
async def select(sql, args, size=None):
'''
数据库查询函数
:param sql: sql语句
:param args: sql语句中的参数
:param size: 要查询的数量
:return: 查询结果
'''
# logging.log(sql, args)
global __pool
async with __pool.get() as conn:
# 创建一个结果为字典的游标
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行sql语句,将sql语句中的'?'替换成'%s'
await cur.execute(sql.replace('?','%s'),args or ())
# 如果指定了数量,就返回指定数量的记录,如果没有,就返回所有记录
if size:
rs = await cur.fetchmany(size)
else:
rs = await cur.fetchall()
logging.info('返回的记录数: %s' % len(rs))
return rs #返回的结果集
async def execute(sql, args, autocommit=True):
'''
Insert、Update、Delete操作的公共执行函数
:param sql:sql语句
:param args:sql参数
:param autocommit:自动提交事务
:return:
'''
# logging.log(sql,args)
async with __pool.get() as conn:
if not autocommit:
await conn.begin()
try:
# 创建一个结果为字典的游标
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行sql语句
await cur.execute(sql.replace('?','%s'),args or ())
# 获取操作的记录数
affected = cur.rowcount
if not autocommit:
await conn.commit()
except BaseException as e:
if not autocommit:
await conn.rollback() #数据回滚
raise
logging.info('返回的记录数: %s' % len(rs))
return affected #返回的结果数
class ModelMetaclass(type):
def __new__(cls, name, bases, attrs):
'''
创建模型与表映射的基类
:param name:类名
:param bases:父类
:param attrs:类的属性列表
:return:模型元类
'''
# 排除Model类本身
if name == 'Model':
return type.__new__(cls, name, bases, attrs)
# 获取表名,如果没有表名则将类名作为表名
tableName = attrs.get('__table__',None) or name
logging.info('模型: %s (表名: %s)' % (name, tableName))
# 获取所有的类属性和主键名:
mappings = dict() # 存储属性名和字段信息的映射关系
fields = [] # 存储所有非主键的属性
primaryKey = None # 存储主键属性
for k,v in attrs.items(): # 遍历attrs(类的所有属性),k为属性名,v为该属性对应的字段信息
if isinstance(v,Field): # 如果v是自己定义的字段类型
logging.info('映射关系:%s ==> %s' % (k,v))
mappings[k] = v # 存储映射关系
if v.primary_key: # 如果该属性是主键
if primaryKey: # 如果primaryKey已经保存了主键,说明主键已经找到了,所以主键重复
raise RuntimeError('主键重复: 在%s中的%s' % (name,k))
primaryKey = k
else: # 如果不是主键,存储到fields中去
fields.append(k)
if not primaryKey: # 如果遍历了所有属性都没有找到主键,则主键没定义
raise RuntimeError('主键未定义:%s',name)
for k in mappings.keys(): # 清空attrs
attrs.pop(k)
# 将fields中属性名以`属性名`的方式装饰起来
escaped_fields = list(map(lambda f: '`%s`' % f, fields))
# 重新设置attrs,类的属性和方法都放在fields,主键属性放在primary_key
attrs['__mappings__'] = mappings # 保存属性和字段信息的映射关系
attrs['__table__'] = tableName # 保存表名
attrs['__primary_key__'] = primaryKey # 主键属性名
attrs['__fields__'] = fields # 除主键外的属性名
# 构造默认的SELECT, INSERT, UPDATE和DELETE语句:
attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (
tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (
tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
return type.__new__(cls, name, bases, attrs)
def create_args_string(num):
'''
用来计算需要拼接多少个占位符
:param num:
:return:
'''
L = []
for n in range(num):
L.append('?')
return ', '.join(L)
class Field(object):
def __init__(self, name, column_type, primary_key, default):
self.name = name
self.column_type = column_type
self.primary_key = primary_key
self.default = default
def __str__(self):
return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
class StringField(Field):
def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
super().__init__(name, ddl, primary_key, default)
class BooleanField(Field):
def __init__(self, name=None, default=False):
super().__init__(name, 'boolean', False, default)
class IntegerField(Field):
def __init__(self, name=None, primary_key=False, default=0):
super().__init__(name, 'bigint', primary_key, default)
class FloatField(Field):
def __init__(self, name=None, primary_key=False, default=0.0):
super().__init__(name, 'real', primary_key, default)
class TextField(Field):
def __init__(self, name=None, default=None):
super().__init__(name, 'text', False, default)
class Model(dict, metaclass=ModelMetaclass):
def __init__(self, **kw):
super(Model, self).__init__(**kw)
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model'对象没有属性'%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def getValue(self, key):
return getattr(self, key, None)
def getValueOrDefault(self, key):
value = getattr(self, key, None)
if value is None: # 如果没有找到value
field = self.__mappings__[key] # 从mappings映射集合中找
value = field.default() if callable(field.default) else field.default
logging.debug('使用默认值 %s:%s' % (key, str(value)))
setattr(self, key, value)
return value
@classmethod
async def findAll(cls, where=None, args=None, **kw):
'''
通过where查找多条记录对象
:param where:where查询条件
:param args:sql参数
:param kw:查询条件列表
:return:多条记录集合
'''
sql = [cls.__select__]
# 如果where查询条件存在
if where:
sql.append('where') # 添加where关键字
sql.append(where) # 拼接where查询条件
if args is None:
args = []
orderBy = kw.get('orderBy',None) # 获取kw里面的orderby查询条件
if orderBy: # 如果存在orderby
sql.append('orderBy') # 拼接orderBy字符串
sql.append(orderBy) # 拼接orderBy查询条件
limit = kw.get('limit',None) # 获取limit查询条件
if limit is not None:
sql.append('limit')
if isinstance(limit,int): # 如果limit是int类型
sql.append('?') # sql拼接一个占位符
args.append(limit) # 将limit添加进参数列表,之所以添加参数列表之后再进行整合是为了防止sql注入
elif isinstance(limit,tuple) and len(limit) == 2: # 如果limit是一个tuple类型并且长度是2
sql.append('?,?') # sql语句拼接两个占位符
args.extend(limit) # 将limit添加进参数列表
else:
raise ValueError('limit参数无效:%s' % str(limit))
rs = await select(''.join(sql),args) # 将args参数列表注入sql语句之后,传递给select函数进行查询并返回查询结果
return [cls(**r) for r in rs]
@classmethod
async def findNumber(cls, selectField, where = None, args = None):
'''
查询某个字段的数量
:param selectField: 要查询的字段
:param where: where查询条件
:param args: 参数列表
:return: 数量
'''
sql = ['select count(%s) _num_ from `%s`' % (selectField, cls.__table__)]
if where:
sql.append('where')
sql.append(where)
rs = await select(''.join(sql), args, 1)
return rs[0]['_num_']
@classmethod
async def findById(cls, pk):
'''
通过id查询
:param pk:id
:return: 一条记录
'''
rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
if len(rs) == 0:
return None
return cls(**rs[0])
@classmethod
async def findByColumn(cls, f, cl):
'''
通过指定字段查询
:param f: 要查询的字段
:param cl: 查询字段所对应的值
:return: 一条记录
'''
fi = None
for field in cls.__fields__: # 遍历属性列表看有没有这个属性
if f == field: # 找到了就赋值给fi然后退出循环
fi = field
break
if fi is None:
raise AttributeError('在%s中没有找到该字段:' % cls.__table__)
rs = await select('%s where `%s`=?' % (cls.__select__, fi), [cl], 1)
if len(rs) == 0:
return None
return cls(**rs[0])
async def save(self):
# 将__fields__保存的除主键外的所有属性一次传递到getValueOrDefault函数中获取值
args = list(map(self.getValueOrDefault, self.__fields__))
# 获取主键值
args.append(self.getValueOrDefault(self.__primary_key__))
# 执行insertsql语句
rows = await execute(self.__insert__, args)
if rows != 1:
logging.warning('插入记录失败:受影响的行: %s' % rows)
async def update(self):
args = list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows = await execute(self.__update__,args)
if rows != 1:
logging.warning('更新记录失败:受影响的行:%s' % rows)
async def delete(self):
args = [self.getValue(self.__primary_key__)]
rows = await execute(self.__delete__,args)
if rows != 1:
logging.warning('删除记录失败:受影响的行: %s' % rows)
参考了很多资料和代码,感觉还是没有完全吃透,回头继续研究。