Python学习 - 编写自己的ORM(2)

上一篇文章简单的实现了ORM(对象关系模型),这一篇文章主要实现简单的MySQL数据库操作。

想要操作数据库,首先要建立一个数据库连接。下面定义一个创建数据库连接的函数,得到一个连接叫做engine。

def create_engine(user,password,database,host='127.0.0.1',port=3306,**kw):

    import mysql.connector

    global engine

    if engine is not None:

        raise DBError('Engine is already initialized.')

    params = dict(user=user,password=password,database=database,host=host,port=port)

    defaults = dict(use_unicode=True,charset='utf8',collation='utf8_general_ci',autocommit=False)

    #print ('%s %s %s %s %s') % (user,password,database,host,port)

    for k,v in defaults.iteritems():

        params[k] = kw.pop(k,v)

    params.update(kw)

    params['buffered'] = True

    engine = mysql.connector.connect(**params)

    cursor = engine.cursor()

 有了连接就可以对数据库进行操作了。下面写了几个函数,可以对数据库进行查询和插入操作。

def _select(sql,first,*args):

    cursor = None

    sql = sql.replace('?','%s')

    global engine

    try:

        cursor = engine.cursor()

        cursor.execute(sql,args)

        if cursor.description:

            names = [x[0] for x in cursor.description]

        if first:

            values = cursor.fetchone()

            if not values:

                return None

            return Dict(names,values)

        return [Dict(names,x) for x in cursor.fetchall()]

    finally:

        if cursor:

            cursor.close()



def select_one(sql,*args):

    return _select(sql,True,*args)



def select(sql,*args):

    return _select(sql,False,*args)



def _update(sql,*args):

    cursor = None

    global engine

    sql = sql.replace('?','%s')

    print sql

    try:

        cursor = engine.cursor()

        cursor.execute(sql,args)

        r = cursor.rowcount

        engine.commit()

        return r

    finally:

        if cursor:

            cursor.close()



def insert(table,**kw):

    cols, args = zip(*kw.iteritems())

    sql = 'insert into %s (%s) values(%s)' % (table,','.join(['%s' % col for col in cols]),','.join(['?' for i in range(len(cols))]))

    print ('sql %s args %s' % (sql, str(args)))

    return _update(sql,*args)

到这里,基本的数据库操作已经完成了。但是,根据廖雪峰的教程,这还远远不够。

  • 如果要在一个数据库连接中实现多个操作,上面的代码效率很低,没次执行玩一条语句,就需要重新分配一个连接。
  • 在一次事务中执行多条操作也是一样效率低下。
  • 如果服务器为不同用户数据库请求都分配一个线程来建立连接,但是在进程中,连接是可供享使用的。这样问题就来了,导致数据库操作可能异常。

针对第三个问题,应该使每个连接是每个线程拥有的,其它线程不能访问,使用threading.local。首先定义一个类,来保存数据库的上下文:

class _DbCtx(threading.local):



    def __init__(self):

        self.connection = None

        self.transactions = 0



    def is_init(self):

        return not self.connection is None



    def init(self):

        self.connection = engine # 创建数据库连接

        self.transactions = 0



    def cleanup(self):

        self.connection.cleanup()

        self.connection = None



    def cursor(self):

        return self.connection.cursor()

 上面的代码有一个错误。因为Python的赋值语句只是将一个对象的引用传给一个变量,就如上面代码中 init函数中 self.connection = engine。表明self.connection和engine都指向一个数据库连接的对象。如果将self.connection给cleanup了,那么engine指向的对象也被cleanup了。下图是一个例子:

Python学习 - 编写自己的ORM(2) a是类foo实例的一个引用,执行b=a后,在执行b.clean(),此时应该只是b的v值被更改为0,但是执行a.v却发现v的值也变为0了。

下面是最后的代码,只是封装了最底层的数据库操作,代码也写的很涨,虽然是模仿廖雪峰的代码。

# -*- coding: utf-8 -*-

import time, uuid, functools, threading, logging



class Dict(dict):

    '''

    Simple dict but support access as x.y style.



    '''

    def __init__(self, names=(), values=(), **kw):

        super(Dict, self).__init__(**kw)

        for k, v in zip(names, values):

            self[k] = v



    def __getattr__(self, key):

        try:

            return self[key]

        except KeyError:

            raise AttributeError(r"'Dict' object has no attribute '%s'" % key)



    def __setattr__(self, key, value):

        self[key] = value

class DBError(Exception):

    pass

class MultiColumnsError(Exception):

    pass

engine = None

class _DbCtx(threading.local):



    def __init__(self):

        self.connection = None

        self.transactions = 0



    def is_init(self):

        return not self.connection is None



    def init(self):

        self.connection = engine

        self.transactions = 0



    def cleanup(self):

        self.connection = None

    

    def cursor(self):

        return self.connection.cursor()



def create_engine(user,password,database,host='127.0.0.1',port=3306,**kw):

    import mysql.connector

    global engine

    if engine is not None:

        raise DBError('Engine is already initialized.')

    params = dict(user=user,password=password,database=database,host=host,port=port)

    defaults = dict(use_unicode=True,charset='utf8',collation='utf8_general_ci',autocommit=False)

    #print ('%s %s %s %s %s') % (user,password,database,host,port)

    for k,v in defaults.iteritems():

        params[k] = kw.pop(k,v)

    params.update(kw)

    params['buffered'] = True

    engine = mysql.connector.connect(**params)

    print type(engine)



_db_ctx = _DbCtx()

class _ConnectionCtx(object):



    def __enter__(self):

        self.should_cleanuo = False

        if not _db_ctx.is_init():

            cursor = engine.cursor()

            _db_ctx.init()

            self.should_cleanup = True

        return self



    def __exit__(self,exctype,excvalue,traceback):

        if self.should_cleanup:

            _db_ctx.cleanup()



def with_connection(func):

    @functools.wraps(func)

    def _wrapper(*args,**kw):

        with _ConnectionCtx():

            return func(*args, **kw)

    return _wrapper



def _select(sql,first,*args):

    cursor = None

    sql = sql.replace('?','%s')

    global _db_ctx

    try:

        cursor = _db_ctx.cursor()

        cursor.execute(sql,args)

        if cursor.description:

            names = [x[0] for x in cursor.description]

        if first:

            values = cursor.fetchone()

            if not values:

                return None

            return Dict(names,values)

        return [Dict(names,x) for x in cursor.fetchall()]

    finally:

        if cursor:

            cursor.close()

@with_connection

def select_one(sql,*args):

    return _select(sql,True,*args)

@with_connection

def select_int(sql,*args):

    d = _select(sql,True,*args)

    if len(d) != 1:

        raise MultoColumnsError('Except only one column.')

    return d.values()[0]

@with_connection

def select(sql,*args):

    global engine

    print type(engine)

    return _select(sql,False,*args)

@with_connection

def _update(sql,*args):

    cursor = None

    global _db_ctx 

    sql = sql.replace('?','%s')

    print sql

    try:

        cursor = _db_ctx.cursor()

        cursor.execute(sql,args)

        r = cursor.rowcount

        engine.commit()

        return r

    finally:

        if cursor:

            cursor.close()



def insert(table,**kw):

    cols, args = zip(*kw.iteritems())

    sql = 'insert into %s (%s) values(%s)' % (table,','.join(['%s' % col for col in cols]),','.join(['?' for i in range(len(cols))]))

    print ('sql %s args %s' % (sql, str(args)))

    return _update(sql,*args)



create_engine(user='root',password='z5201314',database='test')

u1 = select_one('select * from user where id=?',1)

print 'u1'

print u1

print 'start selet()...'

u2 = select('select * from user')

for item in u2:

    print ('%s %s' % (item.name,item.id))

print 'name:%s id: %s' % (u1.name,u1.id)

 

你可能感兴趣的:(python)