【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】
网上关于python orm介绍的文章很多,但是我觉得廖雪峰这个orm是介绍的最好的。下面,我就给出自己对orm的理解。之前廖雪峰给出的code,大家可以通过地址下载的到。
一般我们使用orm都是这么写代码的,
class User(Model):
id = IntegerField('uid')
name = StringField('username')
email = StringField('email')
password = StringField('password')
u = User(id=12345, name='Michael', email='[email protected]', password='my-pwd')
u.save()
这其中,User定义了数据类型,创建的时候直接用User实例化就可以了。这里的save表明了我们希望对User进行的操作,实际调用的时候包括了create、search、update和delete这些操作。
从代码中可以看出,User本身没有定义save函数,那么它只有从父类中继承而来,我们看看save函数在不在Model类中。下面来分析一下Model类,
class Model(dict):
__metaclass__ = ModelMetaclass
def __init__(self, **kw):
super(Model, self).__init__(**kw)
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def save(self):
fields = []
params = []
args = []
for k, v in self.__mappings__.iteritems():
fields.append(v.name)
params.append('?')
args.append(getattr(self, k, None))
sql = 'insert into %s (%s) values (%s)' % (self.__table__, ','.join(fields), ','.join(params))
print('SQL: %s' % sql)
print('ARGS: %s' % str(args))
Model类果然没有令我们失望,save函数就在里面。从代码内容,我们大体看到不管是fileds、params,还是args都是根据__mapping__数据来获取到的。那么奇怪了,__mapping__又是哪里来的?这个时候__metaclass__就派上用场了。
根据python的定义,如果类中出现了自定义__metaclass__,那么编译器会默认调用对应class下的__new__函数。这个时候,对我们来说,就需要看看ModelMetaclass里面做了些什么,
class ModelMetaclass(type):
def __new__(self, name, bases, attrs):
if name=='Model':
return type.__new__(self, name, bases, attrs)
print('Found model: %s' % name)
mappings = dict()
for k, v in attrs.iteritems():
if isinstance(v, Field):
print('Found mapping: %s ==> %s' % (k, v))
mappings[k] = v
for k in mappings.iterkeys():
attrs.pop(k)
attrs['__mappings__'] = mappings
attrs['__table__'] = name
return type.__new__(self, name, bases, attrs)
从ModelMetaclass这个唯一的__new__函数来看,所有的变量都被函数进行了处理。首先,函数判断name是否为Model,是则不做处理,反之进行改写。其次,函数定义了mappings这个变量,它搜集了attrs中的Field变量,并且将这些变量从attrs踢出去。最后函数将修改后的attrs保存,重新进行type处理。
数据类型的定义比较简单。基本上就是按照父类加子类的办法进行的。唯一需要记住的,就是User类中的变量是类变量,不是实例变量,这一点要考虑清楚。
class Field(object):
def __init__(self, name, column_type):
self.name = name
self.column_type = column_type
def __str__(self):
return '<%s:%s>' % (self.__class__.__name__, self.name)
class StringField(Field):
def __init__(self, name):
super(StringField, self).__init__(name, 'varchar(100)')
class IntegerField(Field):
def __init__(self, name):
super(IntegerField, self).__init__(name, 'bigint')
最后给出完整代码,大家可以直接copy上面的地址下载。这里也给出一份,大家可以直接拿来进行调试练习。代码版权归原作者廖雪峰,请大家保留原作者的文件签名。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
' Simple ORM using metaclass '
__author__ = 'Michael Liao'
class Field(object):
def __init__(self, name, column_type):
self.name = name
self.column_type = column_type
def __str__(self):
return '<%s:%s>' % (self.__class__.__name__, self.name)
class StringField(Field):
def __init__(self, name):
super(StringField, self).__init__(name, 'varchar(100)')
class IntegerField(Field):
def __init__(self, name):
super(IntegerField, self).__init__(name, 'bigint')
class ModelMetaclass(type):
def __new__(self, name, bases, attrs):
if name=='Model':
return type.__new__(self, name, bases, attrs)
print('Found model: %s' % name)
mappings = dict()
for k, v in attrs.iteritems():
if isinstance(v, Field):
print('Found mapping: %s ==> %s' % (k, v))
mappings[k] = v
for k in mappings.iterkeys():
attrs.pop(k)
attrs['__mappings__'] = mappings
attrs['__table__'] = name
return type.__new__(self, name, bases, attrs)
class Model(dict):
__metaclass__ = ModelMetaclass
def __init__(self, **kw):
super(Model, self).__init__(**kw)
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def save(self):
fields = []
params = []
args = []
for k, v in self.__mappings__.iteritems():
fields.append(v.name)
params.append('?')
args.append(getattr(self, k, None))
sql = 'insert into %s (%s) values (%s)' % (self.__table__, ','.join(fields), ','.join(params))
print('SQL: %s' % sql)
print('ARGS: %s' % str(args))
# testing code:
class User(Model):
id = IntegerField('uid')
name = StringField('username')
email = StringField('email')
password = StringField('password')
u = User(id=12345, name='Michael', email='[email protected]', password='my-pwd')
u.save()