Python教程链接:Day 3 - 编写ORM
ORM
ORM(Object Relational Mapping)即对象关系映射,通过代码描述程序中对象和数据库对应的元数据,将对象持久化到数据库中。
创建连接池
创建链接池的作用是每个HTTP来请求数据库的时候,都能从连接池直接过去数据库的连接,而不需要每次都打开关闭数据库。
# 通过关键字参数**kw接受连接数据库需要的对应参数来创建连接池
async def create_pool(loop, **kw):
logging.info('create database connection pool...')
global __pool
__pool = await aiomysql.create_pool(
host=kw.get('host', 'localhost'),
port=kw.get('port', 3306),
user=kw['user'],
password=kw['password'],
db=kw['db'],
# 这里是utf8,而不是utf-8,别写错了
charset=kw.get('charset', 'utf8'),
autocommit=kw.get('autocommit', True),
maxsize=kw.get('maxsize', 10),
minsize=kw.get('minsize', 1),
loop=loop
)
对应的数据库操作语句的封装
Select
Select
语句与Update
,Insert
,Delete
语句分开进行封装,因为他返回的是所查询的数据库对应记录,而Update
,Insert
,Delete
语句返回的是所影响的行数。
# 传入SQL语句,参数,大小可选
async def select(sql, args, size=None):
log(sql)
global __pool # 获取全局的连接池__pool
async with __pool.get() as conn: # 打开连接池
async with conn.cursor(aiomysql.DictCursor) as cur: # 创建游标,DictCursor的作用是使查询返回的结果为字典格式
await cur.execute(sql.replace('?', '%s'), args or ()) # 执行SQL语句,将SQL语句的'?'占位符替换成MySQL的'%s'占位符
if size: # 如果有传入size,则返回对应个数的结果,size为None则返回全部
rs = await cur.fetchmany(size)
else:
rs = await cur.fetchall()
logging.info('rows returned: %s' % len(rs))
return rs
Update、Insert、Delete
这三种SQL语句的执行所需要的参数都相同,并且只返回一个整数表示所影响的行数。
# 传入SQL语句,参数,默认自动提交事务
async def execute(sql, args, autocommit=True):
log(sql)
async with __pool.get() as coon: # 打开连接池
if not autocommit: # 如果autocommit为False,conn.begin()开始事务
await coon.begin()
try: # 无论是否自动提交事务,都执行try中的代码
async with coon.cursor(aiomysql.DictCursor) as cur: # 创建游标,DictCursor的作用是使查询返回的结果为字典格式
await cur.execute(sql.replace('?', '%s'), args)
affected = cur.rowcount # 通过rowcount得到SQL语句影响的行数
if not autocommit:
await coon.commit()
except BaseException as e: # 处理出错情况
if not autocommit:
await coon.roolback() # 回滚操作
raise
return affected
Field基类及其子类
定义一个Field类型和其子类对应数据库中不同的类型,String
、Integer
、Boolean
、Float
、Text
等。
class Field(object):
def __init__(self, name, column_type, primary_key, default):
self.name = name # 字段名
self.column_type = column_type # 列类型
self.primary_key = primary_key # 主键
self.default = default # 默认值
def __str__(self):
return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
class StringField(Field):
def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
super().__init__(name, ddl, primary_key, default)
class BooleanField(Field):
def __init__(self, name=None, default=False):
super().__init__(name, 'boolean', False, default)
class IntegerField(Field):
def __init__(self, name=None, primary_key=False, default=0):
super().__init__(name, 'bigint', primary_key, default)
class FloatField(Field):
def __init__(self, name=None, primary_key=False, default=0.0):
super().__init__(name, 'real', primary_key, default)
class TextField(Field):
def __init__(self, name=None, default=None):
super().__init__(name, 'text', False, default)
元类
通过元类ModelMetaclass可以将具体的子类映射信息读取出来。
class ModelMetaclass(type):
# 调用__init__方法前会调用__new__方法
def __new__(cls, name, bases, attrs): # cls:当前准备创建的类的对象,name:类的名称,bases:类继承的父类集合,attrs:类的方法集合
# 排除Model类本身
if name == 'Model':
return type.__new__(cls, name, bases, attrs)
# 获取table名称,如果未设置,tableName就是类的名字
tableName = attrs.get('__table__', None) or name
logging.info('found model: %s (table: %s)' % (name, tableName))
# 获取所有的Field和主键名
mappings = dict()
fields = []
primaryKey = None
# key是列名,value是field的子类
for k, v in attrs.items():
if isinstance(v, Field):
logging.info(' found mapping:%s ==> %s' % (k, v))
mappings[k] = v
if v.primary_key:
# 找到主键
if primaryKey:
raise RuntimeError('Duplicate primary key for field: %s' % k)
primaryKey = k
else:
fields.append(k)
if not primaryKey:
raise RuntimeError('Primary key not found.')
# 删除类属性
for k in mappings.keys():
attrs.pop(k)
# 保存除主键外的属性名为``(运算出字符串)列表形式
escaped_fields = list(map(lambda f: '`%s`' % f, fields))
attrs['__mappings__'] = mappings # 保存属性和列的映射关系
attrs['__table__'] = tableName
attrs['__primary_key__'] = primaryKey # 主键属性名
attrs['__fields__'] = fields # 除主键外的属性名
# 构造默认的SELECT, INSERT, UPDATE和DELETE语句
# 反引号和repr()函数的功能一致
attrs['__select__'] = 'select `%s`, %s from `%s`' % \
(primaryKey, ', '.join(escaped_fields), tableName)
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % \
(tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
return type.__new__(cls, name, bases, attrs)
基类Model
继承自Model的类,会自动通过ModelMetaclass扫描映射关系,并存储到自身的类属性中__table__
、__mappings__
等。
class Model(dict, metaclass=ModelMetaclass):
def __init__(self, **kw):
super(Model, self).__init__(**kw)
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def getValue(self, key):
# 返回对象的属性,如果没有对应属性,则会调用__getattr__
return getattr(self, key, None)
def getValueOrDefault(self, key):
value = getattr(self, key, None)
if value is None:
field = self.__mappings__[key]
if field.default is not None:
value = field.default() if callable(field.default) else field.default
logging.debug('using default value for %s: %s' % (key, str(value)))
# 将默认值设置进行
setattr(self, key, value)
return value
# 类方法第一个参数为cls,而实例方法第一个参数为self
@classmethod
async def findAll(cls, where=None, args=None, **kw):
"""find object by where clause"""
sql = [cls.__select__]
if where:
sql.append('where')
sql.append(where)
if args is None:
args = []
orderBy = kw.get('orderBy', None)
if orderBy:
sql.append('order by')
sql.append(orderBy)
limit = kw.get('limit', None)
if limit is not None:
sql.append('limit')
if isinstance(limit, int):
sql.append('?')
args.append(limit)
elif isinstance(limit, tuple) and len(limit) == 2:
sql.append('?', '?')
# extend接受一个iterable参数
args.extend(limit)
else:
raise ValueError('Invalid limit value: %s' % str(limit))
rs = await select(' '.join(sql), args)
return [cls(**r) for r in rs]
@classmethod
async def findNumber(cls, selectField, where=None, args=None):
"""find number by select and where"""
# 将列名重命名为_num_
sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
if where:
sql.append('where')
sql.append(where)
# 限制结果数为1
rs = await select(' '.join(sql), args, 1)
if len(rs) == 0:
return None
return rs[0]['_num_']
@classmethod
async def find(cls, pk):
"""find object by primarykey"""
rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
if len(rs) == 0:
return None
return cls(**rs[0])
async def save(self):
# 获取所有value
args = list(map(self.getValueOrDefault, self.__fields__))
args.append(self.getValueOrDefault(self.__primary_key__))
rows = await execute(self.__insert__, args)
if rows != 1:
logging.warning('failed to insert record: affected rows: %s' % rows)
async def update(self):
args = list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows = await execute(self.__update__, args)
if rows != 1:
logging.warning('faild to update by primary key: affected rows: %s' % rows)
async def remove(self):
args = [self.getValue(self.__primary_key__)]
rows = await execute(self.__delete__, args)
if rows != 1:
logging.warning('faild to remove by primary key: affected rows: %s' % rows)
注意点
Python采用的是代码缩进来区分代码块的,而不是其他语言中的{}。所以写代码的时候要格外注意,刚开始用Python经常会因为这个问题导致程序不按预期执行。
博主因为粗心写错了一个for循环内的return缩进,调试了很久才醒悟。