这个例实在是太经典了,整个网上都用这个例子做decorator的经典范例,因为太经典了,所以,我这篇文章也不能免俗。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
fromfunctoolsimportwraps
defmemo(fn):
cache={}
miss=object()
@wraps(fn)
defwrapper(*args):
result=cache.get(args, miss)
ifresultismiss:
result=fn(*args)
cache[args]=result
returnresult
returnwrapper
@memo
deffib(n):
ifn <2:
returnn
returnfib(n-1)+fib(n-2)
|
上面这个例子中,是一个斐波拉契数例的递归算法。我们知道,这个递归是相当没有效率的,因为会重复调用。比如:我们要计算fib(5),于是其分解成fib(4) + fib(3),而fib(4)分解成fib(3)+fib(2),fib(3)又分解成fib(2)+fib(1)…… 你可看到,基本上来说,fib(3), fib(2), fib(1)在整个递归过程中被调用了两次。
而我们用decorator,在调用函数前查询一下缓存,如果没有才调用了,有了就从缓存中返回值。一下子,这个递归从二叉树式的递归成了线性的递归。
这个例子没什么高深的,就是实用一些。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
importcProfile, pstats, StringIO
defprofiler(func):
defwrapper(*args,**kwargs):
datafn=func.__name__+".profile"# Name the data file
prof=cProfile.Profile()
retval=prof.runcall(func,*args,**kwargs)
#prof.dump_stats(datafn)
s=StringIO.StringIO()
sortby='cumulative'
ps=pstats.Stats(prof, stream=s).sort_stats(sortby)
ps.print_stats()
prints.getvalue()
returnretval
returnwrapper
|
下面这个示例展示了通过URL的路由来调用相关注册的函数示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
classMyApp():
def__init__(self):
self.func_map={}
defregister(self, name):
deffunc_wrapper(func):
self.func_map[name]=func
returnfunc
returnfunc_wrapper
defcall_method(self, name=None):
func=self.func_map.get(name,None)
iffuncisNone:
raiseException("No function registered against - "+str(name))
returnfunc()
app=MyApp()
@app.register('/')
defmain_page_func():
return"This is the main page."
@app.register('/next_page')
defnext_page_func():
return"This is the next page."
printapp.call_method('/')
printapp.call_method('/next_page')
|
注意:
1)上面这个示例中,用类的实例来做decorator。
2)decorator类中没有__call__(),但是wrapper返回了原函数。所以,原函数没有发生任何变化。
下面这个示例演示了一个logger的decorator,这个decorator输出了函数名,参数,返回值,和运行时间。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
fromfunctoolsimportwraps
deflogger(fn):
@wraps(fn)
defwrapper(*args,**kwargs):
ts=time.time()
result=fn(*args,**kwargs)
te=time.time()
print"function = {0}".format(fn.__name__)
print" arguments = {0} {1}".format(args, kwargs)
print" return = {0}".format(result)
print" time = %.6f sec"%(te-ts)
returnresult
returnwrapper
@logger
defmultipy(x, y):
returnx*y
@logger
defsum_num(n):
s=0
foriinxrange(n+1):
s+=i
returns
printmultipy(2,10)
printsum_num(100)
printsum_num(10000000)
|
上面那个打日志还是有点粗糙,让我们看一个更好一点的(带log level参数的):
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
importinspect
defget_line_number():
returninspect.currentframe().f_back.f_back.f_lineno
deflogger(loglevel):
deflog_decorator(fn):
@wraps(fn)
defwrapper(*args,**kwargs):
ts=time.time()
result=fn(*args,**kwargs)
te=time.time()
print"function = "+fn.__name__,
print" arguments = {0} {1}".format(args, kwargs)
print" return = {0}".format(result)
print" time = %.6f sec"%(te-ts)
if(loglevel=='debug'):
print" called_from_line : "+str(get_line_number())
returnresult
returnwrapper
returnlog_decorator
|
但是,上面这个带log level参数的有两具不好的地方,
1) loglevel不是debug的时候,还是要计算函数调用的时间。
2) 不同level的要写在一起,不易读。
我们再接着改进:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
importinspect
defadvance_logger(loglevel):
defget_line_number():
returninspect.currentframe().f_back.f_back.f_lineno
def_basic_log(fn, result,*args,**kwargs):
print"function = "+fn.__name__,
print" arguments = {0} {1}".format(args, kwargs)
print" return = {0}".format(result)
definfo_log_decorator(fn):
@wraps(fn)
defwrapper(*args,**kwargs):
result=fn(*args,**kwargs)
_basic_log(fn, result, args, kwargs)
returnwrapper
defdebug_log_decorator(fn):
@wraps(fn)
defwrapper(*args,**kwargs):
ts=time.time()
result=fn(*args,**kwargs)
te=time.time()
_basic_log(fn, result, args, kwargs)
print" time = %.6f sec"%(te-ts)
print" called_from_line : "+str(get_line_number())
returnwrapper
ifloglevelis"debug":
returndebug_log_decorator
else:
returninfo_log_decorator
|
你可以看到两点,
1)我们分了两个log level,一个是info的,一个是debug的,然后我们在外尾根据不同的参数返回不同的decorator。
2)我们把info和debug中的相同的代码抽到了一个叫_basic_log的函数里,DRY原则。
下面这个decorator是我在工作中用到的代码,我简化了一下,把DB连接池的代码去掉了,这样能简单点,方便阅读。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
|
importumysql
fromfunctoolsimportwraps
classConfiguraion:
def__init__(self, env):
ifenv=="Prod":
self.host ="coolshell.cn"
self.port =3306
self.db ="coolshell"
self.user ="coolshell"
self.passwd ="fuckgfw"
elifenv=="Test":
self.host ='localhost'
self.port =3300
self.user ='coolshell'
self.db ='coolshell'
self.passwd='fuckgfw'
defmysql(sql):
_conf=Configuraion(env="Prod")
defon_sql_error(err):
printerr
sys.exit(-1)
defhandle_sql_result(rs):
ifrs.rows >0:
fieldnames=[f[0]forfinrs.fields]
return[dict(zip(fieldnames, r))forrinrs.rows]
else:
return[]
defdecorator(fn):
@wraps(fn)
defwrapper(*args,**kwargs):
mysqlconn=umysql.Connection()
mysqlconn.settimeout(5)
mysqlconn.connect(_conf.host, _conf.port, _conf.user, \
_conf.passwd, _conf.db,True,'utf8')
try:
rs=mysqlconn.query(sql, {})
exceptumysql.Error as e:
on_sql_error(e)
data=handle_sql_result(rs)
kwargs["data"]=data
result=fn(*args,**kwargs)
mysqlconn.close()
returnresult
returnwrapper
returndecorator
@mysql(sql="select * from coolshell")
defget_coolshell(data):
... ...
... ..
|
下面量个非常简单的异步执行的decorator,注意,异步处理并不简单,下面只是一个示例。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
fromthreadingimportThread
fromfunctoolsimportwraps
defasync(func):
@wraps(func)
defasync_func(*args,**kwargs):
func_hl=Thread(target=func, args=args, kwargs=kwargs)
func_hl.start()
returnfunc_hl
returnasync_func
if__name__=='__main__':
fromtimeimportsleep
@async
defprint_somedata():
print'starting print_somedata'
sleep(2)
print'print_somedata: 2 sec passed'
sleep(2)
print'print_somedata: 2 sec passed'
sleep(2)
print'finished print_somedata'
defmain():
print_somedata()
print'back in main'
print_somedata()
print'back in main'
main()
|
# -*- coding: UTF-8 -*- ''' @summary: 验证器 该模块提供了一个装饰器用于验证参数是否合法,使用方法为: from validator import validParam, nullOk, multiType @validParam(i=int) def foo(i): return i+1 编写验证器: 1. 仅验证类型: @validParam(type, ...) 例如: 检查第一个位置的参数是否为int类型: @validParam(int) 检查名为x的参数是否为int类型: @validParam(x=int) 验证多个参数: @validParam(int, int) 指定参数名验证: @validParam(int, s=str) 针对*和**参数编写的验证器将验证这些参数实际包含的每个元素: @validParam(varargs=int) def foo(*varargs): pass @validParam(kws=int) def foo7(s, **kws): pass 2. 带有条件的验证: @validParam((type, condition), ...) 其中,condition是一个表达式字符串,使用x引用待验证的对象; 根据bool(表达式的值)判断是否通过验证,若计算表达式时抛出异常,视为失败。 例如: 验证一个10到20之间的整数: @validParam(i=(int, '10<x<20')) 验证一个长度小于20的字符串: @validParam(s=(str, 'len(x)<20')) 验证一个年龄小于20的学生: @validParam(stu=(Student, 'x.age<20')) 另外,如果类型是字符串,condition还可以使用斜杠开头和结尾表示正则表达式匹配。 验证一个由数字组成的字符串: @validParam(s=(str, '/^\d*$/')) 3. 以上验证方式默认为当值是None时验证失败。如果None是合法的参数,可以使用nullOk()。 nullOk()接受一个验证条件作为参数。 例如: @validParam(i=nullOk(int)) @validParam(i=nullOk((int, '10<x<20'))) 也可以简写为: @validParam(i=nullOk(int, '10<x<20')) 4. 如果参数有多个合法的类型,可以使用multiType()。 multiType()可接受多个参数,每个参数都是一个验证条件。 例如: @validParam(s=multiType(int, str)) @validParam(s=multiType((int, 'x>20'), nullOk(str, '/^\d+$/'))) 5. 如果有更复杂的验证需求,还可以编写一个函数作为验证函数传入。 这个函数接收待验证的对象作为参数,根据bool(返回值)判断是否通过验证,抛出异常视为失败。 例如: def validFunction(x): return isinstance(x, int) and x>0 @validParam(i=validFunction) def foo(i): pass 这个验证函数等价于: @validParam(i=(int, 'x>0')) def foo(i): pass @author: HUXI @since: 2011-3-22 @change: ''' import inspect import re class ValidateException(Exception): pass def validParam(*varargs, **keywords): '''验证参数的装饰器。''' varargs = map(_toStardardCondition, varargs) keywords = dict((k, _toStardardCondition(keywords[k])) for k in keywords) def generator(func): args, varargname, kwname = inspect.getargspec(func)[:3] dctValidator = _getcallargs(args, varargname, kwname, varargs, keywords) def wrapper(*callvarargs, **callkeywords): dctCallArgs = _getcallargs(args, varargname, kwname, callvarargs, callkeywords) k, item = None, None try: for k in dctValidator: if k == varargname: for item in dctCallArgs[k]: assert dctValidator[k](item) elif k == kwname: for item in dctCallArgs[k].values(): assert dctValidator[k](item) else: item = dctCallArgs[k] assert dctValidator[k](item) except: raise ValidateException,\ ('%s() parameter validation fails, param: %s, value: %s(%s)' % (func.func_name, k, item, item.__class__.__name__)) return func(*callvarargs, **callkeywords) wrapper = _wrapps(wrapper, func) return wrapper return generator def _toStardardCondition(condition): '''将各种格式的检查条件转换为检查函数''' if inspect.isclass(condition): return lambda x: isinstance(x, condition) if isinstance(condition, (tuple, list)): cls, condition = condition[:2] if condition is None: return _toStardardCondition(cls) if cls in (str, unicode) and condition[0] == condition[-1] == '/': return lambda x: (isinstance(x, cls) and re.match(condition[1:-1], x) is not None) return lambda x: isinstance(x, cls) and eval(condition) return condition def nullOk(cls, condition=None): '''这个函数指定的检查条件可以接受None值''' return lambda x: x is None or _toStardardCondition((cls, condition))(x) def multiType(*conditions): '''这个函数指定的检查条件只需要有一个通过''' lstValidator = map(_toStardardCondition, conditions) def validate(x): for v in lstValidator: if v(x): return True return validate def _getcallargs(args, varargname, kwname, varargs, keywords): '''获取调用时的各参数名-值的字典''' dctArgs = {} varargs = tuple(varargs) keywords = dict(keywords) argcount = len(args) varcount = len(varargs) callvarargs = None if argcount <= varcount: for n, argname in enumerate(args): dctArgs[argname] = varargs[n] callvarargs = varargs[-(varcount-argcount):] else: for n, var in enumerate(varargs): dctArgs[args[n]] = var for argname in args[-(argcount-varcount):]: if argname in keywords: dctArgs[argname] = keywords.pop(argname) callvarargs = () if varargname is not None: dctArgs[varargname] = callvarargs if kwname is not None: dctArgs[kwname] = keywords dctArgs.update(keywords) return dctArgs def _wrapps(wrapper, wrapped): '''复制元数据''' for attr in ('__module__', '__name__', '__doc__'): setattr(wrapper, attr, getattr(wrapped, attr)) for attr in ('__dict__',): getattr(wrapper, attr).update(getattr(wrapped, attr, {})) return wrapper #=============================================================================== # 测试 #=============================================================================== def _unittest(func, *cases): for case in cases: _functest(func, *case) def _functest(func, isCkPass, *args, **kws): if isCkPass: func(*args, **kws) else: try: func(*args, **kws) assert False except ValidateException: pass def _test1_simple(): #检查第一个位置的参数是否为int类型: @validParam(int) def foo1(i): pass _unittest(foo1, (True, 1), (False, 's'), (False, None)) #检查名为x的参数是否为int类型: @validParam(x=int) def foo2(s, x): pass _unittest(foo2, (True, 1, 2), (False, 's', 's')) #验证多个参数: @validParam(int, int) def foo3(s, x): pass _unittest(foo3, (True, 1, 2), (False, 's', 2)) #指定参数名验证: @validParam(int, s=str) def foo4(i, s): pass _unittest(foo4, (True, 1, 'a'), (False, 's', 1)) #针对*和**参数编写的验证器将验证这些参数包含的每个元素: @validParam(varargs=int) def foo5(*varargs): pass _unittest(foo5, (True, 1, 2, 3, 4, 5), (False, 'a', 1)) @validParam(kws=int) def foo6(**kws): pass _functest(foo6, True, a=1, b=2) _functest(foo6, False, a='a', b=2) @validParam(kws=int) def foo7(s, **kws): pass _functest(foo7, True, s='a', a=1, b=2) def _test2_condition(): #验证一个10到20之间的整数: @validParam(i=(int, '10<x<20')) def foo1(x, i): pass _unittest(foo1, (True, 1, 11), (False, 1, 'a'), (False, 1, 1)) #验证一个长度小于20的字符串: @validParam(s=(str, 'len(x)<20')) def foo2(a, s): pass _unittest(foo2, (True, 1, 'a'), (False, 1, 1), (False, 1, 'a'*20)) #验证一个年龄小于20的学生: class Student(object): def __init__(self, age): self.age=age @validParam(stu=(Student, 'x.age<20')) def foo3(stu): pass _unittest(foo3, (True, Student(18)), (False, 1), (False, Student(20))) #验证一个由数字组成的字符串: @validParam(s=(str, r'/^\d*$/')) def foo4(s): pass _unittest(foo4, (True, '1234'), (False, 1), (False, 'a1234')) def _test3_nullok(): @validParam(i=nullOk(int)) def foo1(i): pass _unittest(foo1, (True, 1), (False, 'a'), (True, None)) @validParam(i=nullOk(int, '10<x<20')) def foo2(i): pass _unittest(foo2, (True, 11), (False, 'a'), (True, None), (False, 1)) def _test4_multitype(): @validParam(s=multiType(int, str)) def foo1(s): pass _unittest(foo1, (True, 1), (True, 'a'), (False, None), (False, 1.1)) @validParam(s=multiType((int, 'x>20'), nullOk(str, '/^\d+$/'))) def foo2(s): pass _unittest(foo2, (False, 1), (False, 'a'), (True, None), (False, 1.1), (True, 21), (True, '21')) def _main(): d = globals() from types import FunctionType print for f in d: if f.startswith('_test'): f = d[f] if isinstance(f, FunctionType): f() if __name__ == '__main__': _main()