在控制变量实验中,我们通常需要固定住一个或几个参数,并遍历一个区间将参数代入实验中。
假设一个叫func的函数有三个参数,第一个参数固定,第二、三个参数是需要控制变量的,那么在参数离散取值的情况下,罗列出第二、第三个参数所有的情况,就需要用排列组合将不同参数代入func进而得到实验结果。本人实现了这样一个类,现在开源给大家使用。希望对大家有帮助
# coding:utf8
import time
import copy
import pandas as pd
class Experimenter:
def __init__(self, verbose, log_filename='explog_{}.log'):
self.statistic = []
self.verbose = verbose
self.log_filename = log_filename
def statistic_to_csv(self):
k = [i for i in self.statistic[0]]
data = []
for record in self.statistic:
sub_data = [record[i] for i in k]
data.append(sub_data)
df = pd.DataFrame(data)
df.columns = k
df.to_csv(self.log_filename.format(time.strftime('%Y-%m-%d %H-%M-%S', time.gmtime(time.time()+8*60*60))))
@staticmethod
def arg_factory(args, scope):
def inc(idx):
scope_arg_index[idx] += 1
if scope_arg_index[idx] >= scope_arg_length[idx]:
scope_arg_index[idx] = 0
if idx != scope_arg_num-1:
inc(idx+1)
return 0
else:
return 0
else:
return 0
scope_arg_length = [len(scope[i]) for i in scope]
scope_arg_index = [0 for _ in scope]
scope_arg_num = len(scope)
while True:
_tmp_args = copy.deepcopy(args)
_idx = 0
for ar in scope:
_tmp_args[ar] = scope[ar][scope_arg_index[_idx]]
_idx += 1
yield _tmp_args
FLAG = False
for _i in range(scope_arg_num):
if scope_arg_length[_i] - 1 > scope_arg_index[_i]:
FLAG = True
if not FLAG:
break
inc(0)
def run(self, execute_func, args={
}, args_scope={
}): # args中,需要尝试的args,留空,放入args_scope中
for k in args:
if k is None:
assert k in args_scope, '{} 参数未被传入!!!'.format(k)
get_next_args = lambda x : x.next()
arg_fac = self.arg_factory(args, args_scope)
while True:
try:
current_args = get_next_args(arg_fac)
if self.verbose:
print(current_args)
res = execute_func(**current_args)
assert isinstance(res, dict), 'func result must be dict!'
for k in current_args:
if k in res:
res['_'+k+'_'] = current_args[k]
else:
res[k] = current_args[k]
self.statistic.append(res)
if self.verbose:
print('-' * 150)
print('-' * 150)
print('完成一个任务。')
except:
if self.verbose:
print('所有参数被执行完毕,任务结束~')
break
self.statistic_to_csv()
以上是类,下面介绍用法:
第一步:实例化(需指定verbose、log_filepath)
第二步:传入做实验的用的函数(注意函数不要带括号)、实验函数的参数dict(所有关键字参数都要包含在键中,固定参数直接传入值,需改变参数写None)、实验的变量dict(用键值对形式,关键字参数为键,需要更改的值放在一个列表中)
下面用一个实例来说明,很简单,一看就会了。
假设要改变func的b、c参数的值进行实验:
def func(a, b, c):
return {
'a+b':a+b, 'a+c': a+c}
exp = Experimenter(verbose=False)
exp.run(execute_func=func, args={
'a': 1, 'b':None, 'c':None}, args_scope={
'b':[1, 2, 3], 'c':[4, 5, 6]})
运行后,即可在同目录下自动生成的csv文件中查看结果:
a+c,a+b,c,b,a
0,5,2,4,1,1
1,6,2,5,1,1
2,7,2,6,1,1
3,5,3,4,2,1
4,6,3,5,2,1
5,7,3,6,2,1
6,5,4,4,3,1
7,6,4,5,3,1
8,7,4,6,3,1
注:实验函数的返回值必须是dict,空dict也可以。
希望能帮到大家~