grid.py 的作用:在所有可取的 (c, g) 参数组合上做交叉验证,找出交叉验证准确率最高的一组参数。
----------------------------------------------------------------------------------------------------------------------------------------
全局设置:
# Locations of the svm-train and gnuplot executables (platform dependent);
# process_options() can override both via -svmtrain / -gnuplot.
is_win32 = (sys.platform == 'win32')
if is_win32:
    # Example paths for Windows; adjust gnuplot_exe to the local install.
    svmtrain_exe = r"..\windows\svm-train.exe"
    gnuplot_exe = r"c:\Program Files (x86)\gnuplot\bin\pgnuplot.exe"
else:
    svmtrain_exe = "../svm-train"
    gnuplot_exe = "/usr/bin/gnuplot"

# Default search settings; process_options() can override these.
fold = 5                                # k for k-fold cross validation
c_begin, c_end, c_step = -5, 15, 2      # log2(c): -5 up to 15, step 2
g_begin, g_end, g_step = 3, -15, -2     # log2(g): 3 down to -15, step -2

# dataset_pathname, dataset_title and pass_through_string (options forwarded
# to svm-train), plus out_filename (one "c g rate" line per evaluated pair)
# and png_filename (plot of those results), are not defined here: they are
# created as module globals by process_options().

# experimental: host names of remote telnet / ssh workers, which let the
# grid search run in parallel across machines.
telnet_workers = []
ssh_workers = []
# Number of worker threads started on this machine.
nr_local_worker = 1
----------------------------------------------------------------------------------------------------------------------------------------
处理命令行传入的参数:
# process command line options, set global parameters
def _parse_log2_range(text):
    """Parse a "begin,end,step" option value into a tuple of three floats.

    Raises:
        ValueError: if the value is not exactly three numbers, or step is 0.
    """
    values = [float(v) for v in text.split(",")]
    if len(values) != 3:
        raise ValueError("expected begin,end,step")
    if values[2] == 0:
        # A zero step never moves from begin towards end.
        raise ValueError("step must not be zero")
    return tuple(values)


def process_options(argv=sys.argv):
    """Parse the grid.py command line and store the results in module globals.

    Sets fold, the log2 c/g search ranges, dataset_pathname/dataset_title,
    pass_through_string (unrecognised options, forwarded to svm-train),
    svmtrain_exe/gnuplot_exe, out_filename/png_filename, and ``gnuplot``,
    the stdin pipe of a newly started gnuplot process.

    The dataset is always the last argument. On a usage error the usage text
    is printed and the process exits with status 1.

    Raises:
        IOError: if the svm-train executable, the gnuplot executable or the
            dataset path does not exist.
    """
    # Module-level settings shared with the workers and main(); declaring
    # them global avoids passing them between functions.
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string
    # gnuplot: pipe to the plotting process, fed every new (c, g, rate) row.
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename

    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    def fail(message=None):
        # Report a command-line error, show the usage and exit with status 1.
        if message:
            print(message)
        print(usage)
        sys.exit(1)

    if len(argv) < 2:
        fail()

    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '{0}.out'.format(dataset_title)
    png_filename = '{0}.png'.format(dataset_title)

    # grid.py's own options; each consumes the following argument as its value.
    valued_options = ('-log2c', '-log2g', '-v', '-svmtrain', '-gnuplot',
                      '-out', '-png')
    # Every other argument is forwarded to svm-train unchanged.
    pass_through_options = []
    last = len(argv) - 1  # index of the dataset argument
    i = 1
    while i < last:
        option = argv[i]
        if option in ('-c', '-g'):
            # grid.py chooses c and g itself; the ranges are -log2c / -log2g.
            fail("Option -c and -g are renamed.")
        if option not in valued_options:
            pass_through_options.append(option)
            i += 1
            continue
        # The value must come before the dataset; otherwise the dataset name
        # would silently be consumed as this option's value.
        if i + 1 >= last:
            fail("Option {0} requires a value.".format(option))
        value = argv[i + 1]
        i += 2
        try:
            if option == '-log2c':
                c_begin, c_end, c_step = _parse_log2_range(value)
            elif option == '-log2g':
                g_begin, g_end, g_step = _parse_log2_range(value)
            elif option == '-v':
                # Number of cross-validation folds passed to svm-train -v.
                fold = int(value)
            elif option == '-svmtrain':
                svmtrain_exe = value
            elif option == '-gnuplot':
                gnuplot_exe = value
            elif option == '-out':
                out_filename = value
            else:  # '-png'
                png_filename = value
        except ValueError as err:
            fail("Invalid value for {0}: {1} ({2})".format(option, value, err))

    # svm-train receives the forwarded options as one space-separated string.
    pass_through_string = " ".join(pass_through_options)

    # Explicit checks (not assert) so they still run under python -O.
    if not os.path.exists(svmtrain_exe):
        raise IOError("svm-train executable not found")
    if not os.path.exists(gnuplot_exe):
        raise IOError("gnuplot executable not found")
    if not os.path.exists(dataset_pathname):
        raise IOError("dataset not found")

    # Start gnuplot and keep its stdin for sending plot commands.
    gnuplot = Popen(gnuplot_exe, stdin=PIPE).stdin
----------------------------------------------------------------------------------------------------------------------------------------
产生需要计算的(c,g)对:
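函数 calculate_jobs 调用函数 range_f 和 permute_sequence,将 c 和 g 的每一对可取值组合起来(注意这些值都是以 2 为底的指数,即 log2(c) 和 log2(g))。例如,按默认值,log2(c) 可取 [-5, -3, -1, 1, 3, 5, 7, 9, 11, 13, 15](共 11 个),log2(g) 可取 [3, 1, -1, -3, -5, -7, -9, -11, -13, -15](共 10 个)。两者先经 permute_sequence 重新排序(区间中点优先),再逐行组合,共产生 11 × 10 = 110 个 (c, g) 对。下面每一行是一批任务;calculate_jobs 返回的第一行是空列表,此处省略: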
(5, -7)
(-1, -7)
(5, -1) (-1, -1)
(11, -7) (11, -1)
(5, -13) (-1, -13) (11, -13)
(-3, -7) (-3, -1) (-3, -13)
(5, 1) (-1, 1) (11, 1) (-3, 1)
(9, -7) (9, -1) (9, -13) (9, 1)
(5, -11) (-1, -11) (11, -11) (-3, -11) (9, -11)
(3, -7) (3, -1) (3, -13) (3, 1) (3, -11)
(5, -5) (-1, -5) (11, -5) (-3, -5) (9, -5) (3, -5)
(15, -7) (15, -1) (15, -13) (15, 1) (15, -11) (15, -5)
(5, -15) (-1, -15) (11, -15) (-3, -15) (9, -15) (3, -15) (15, -15)
(-5, -7) (-5, -1) (-5, -13) (-5, 1) (-5, -11) (-5, -5) (-5, -15)
(5, 3) (-1, 3) (11, 3) (-3, 3) (9, 3) (3, 3) (15, 3) (-5, 3)
(7, -7) (7, -1) (7, -13) (7, 1) (7, -11) (7, -5) (7, -15) (7, 3)
(5, -9) (-1, -9) (11, -9) (-3, -9) (9, -9) (3, -9) (15, -9) (-5, -9) (7, -9)
(1, -7) (1, -1) (1, -13) (1, 1) (1, -11) (1, -5) (1, -15) (1, 3) (1, -9)
(5, -3) (-1, -3) (11, -3) (-3, -3) (9, -3) (3, -3) (15, -3) (-5, -3) (7, -3) (1, -3)
(13, -7) (13, -1) (13, -13) (13, 1) (13, -11) (13, -5) (13, -15) (13, 3) (13, -9) (13, -3)
grid.py 按上表的顺序把这 110 个 (c, g) 对放入任务队列,由各工作节点计算交叉验证准确率;主函数也按同样的顺序逐行汇总结果,每处理完一行就更新一次图形。
def range_f(begin, end, step):
    """Return [begin, begin + step, ...] while the value has not passed `end`.

    Like ``range`` but works on non-integers too, includes `end` itself when
    the sequence lands on it, and supports negative steps. Returns an empty
    list when `begin` is already past `end` in the direction of `step`.

    Raises:
        ValueError: if `step` is zero (the sequence would never terminate).
    """
    if step == 0:
        raise ValueError("range_f() step must not be zero")
    seq = []
    # Continue while begin is still on the near side of end for this direction.
    while (begin <= end) if step > 0 else (begin >= end):
        seq.append(begin)
        begin = begin + step
    return seq
def permute_sequence(seq):
    """Reorder `seq` so that the middle element comes first.

    The element at index len(seq)//2 is emitted first, followed by the
    recursively reordered left and right halves, interleaved element by
    element (left first); whatever remains of the longer half comes last.
    Sequences of length 0 or 1 are returned as-is.
    """
    size = len(seq)
    if size <= 1:
        return seq
    pivot = size // 2
    lower = permute_sequence(seq[:pivot])
    upper = permute_sequence(seq[pivot + 1:])
    merged = [seq[pivot]]
    for idx in range(max(len(lower), len(upper))):
        if idx < len(lower):
            merged.append(lower[idx])
        if idx < len(upper):
            merged.append(upper[idx])
    return merged
def calculate_jobs():
    """Group every (log2 c, log2 g) pair of the search grid into rows of jobs.

    The c and g values come from the global c_* / g_* ranges, reordered with
    permute_sequence (middle of the range first) so early rows already span
    the grid coarsely. Each row either introduces the next c value, paired
    with every g value introduced so far, or the next g value, paired with
    every c value introduced so far; the choice keeps the fractions of c and
    g values used roughly equal. The first row is always empty.

    Returns:
        A list of rows, each a list of (log2 c, log2 g) tuples; together they
        contain every combination exactly once.

    Raises:
        ValueError: if the c or g range contains no values.
    """
    c_seq = permute_sequence(range_f(c_begin, c_end, c_step))
    g_seq = permute_sequence(range_f(g_begin, g_end, g_step))
    if not c_seq or not g_seq:
        # Guard the i/nr_c and j/nr_g divisions below.
        raise ValueError("empty log2c or log2g range: check begin, end and step")
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0  # number of c values introduced so far
    j = 0  # number of g values introduced so far
    jobs = []
    while i < nr_c or j < nr_g:
        if i / nr_c < j / nr_g:
            # increase C resolution: new c value against all known g values
            jobs.append([(c_seq[i], g_seq[k]) for k in range(j)])
            i += 1
        else:
            # increase g resolution: new g value against all known c values
            jobs.append([(c_seq[k], g_seq[j]) for k in range(i)])
            j += 1
    return jobs
----------------------------------------------------------------------------------------------------------------------------------------
根据给定的(c,g)对,执行交叉验证,计算准确率rate:
class Worker(Thread):
    """Thread that evaluates (log2 c, log2 g) jobs taken from a shared queue.

    Subclasses implement ``run_one(c, g)``, which receives the actual c and g
    values (2**cexp, 2**gexp) and returns the cross-validation rate, or None
    when no rate could be obtained.
    """

    def __init__(self, name, job_queue, result_queue):
        Thread.__init__(self)
        # Worker name, reported with every result it produces.
        self.name = name
        # Shared queue of (cexp, gexp) jobs; also carries the stop token.
        self.job_queue = job_queue
        # Shared queue receiving (name, cexp, gexp, rate) tuples.
        self.result_queue = result_queue

    def run(self):
        """Process jobs until the stop token is seen or a job fails.

        Started via Thread.start() (e.g. LocalWorker(...).start() in main).
        """
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so the remaining workers see it too.
                self.job_queue.put((cexp, gexp))
                # print('worker {0} stop.'.format(self.name))
                break
            try:
                rate = self.run_one(2.0 ** cexp, 2.0 ** gexp)
                if rate is None:
                    raise RuntimeError("get no rate")
            except Exception:
                # we failed, let others do that and we just quit:
                # report why, give the job back for another worker, and retire.
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                print('worker {0} quit.'.format(self.name))
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))
class LocalWorker(Worker):
    """Worker that runs svm-train as a subprocess on this machine."""

    def run_one(self, c, g):
        """Run one svm-train cross validation and return its accuracy.

        Uses the module globals svmtrain_exe, fold, pass_through_string and
        dataset_pathname. Returns the number parsed from the first output
        line containing "Cross", or None when no such line is printed.
        """
        # Build an argument list instead of a shell string so paths that
        # contain spaces or shell metacharacters reach svm-train verbatim.
        args = [svmtrain_exe, '-c', str(c), '-g', str(g), '-v', str(fold)]
        args.extend(pass_through_string.split())
        args.append(dataset_pathname)
        # communicate() waits for the process and closes the pipe, so no
        # zombie process or open descriptor is left behind for each job.
        output = Popen(args, stdout=PIPE).communicate()[0]
        if isinstance(output, bytes):
            output = output.decode('utf-8', 'replace')
        for line in output.splitlines():
            if line.find("Cross") != -1:
                # The last token is assumed to be a percentage such as
                # "93.3333%"; drop the trailing "%" before converting.
                return float(line.split()[-1][0:-1])
        return None
----------------------------------------------------------------------------------------------------------------------------------------
主函数:
def main():
    """Run the grid search over (log2 c, log2 g) and print the best result.

    Parses the command line, queues every (c, g) job, starts the configured
    telnet / ssh / local workers, then collects results in job order. Each
    result is appended to out_filename and printed, and the plot is redrawn
    after every row of jobs. Finally prints "best_c best_g best_rate".
    """
    # set parameters
    process_options()

    # put jobs in queue: every (c, g) combination, in calculate_jobs order
    jobs = calculate_jobs()
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)
    for line in jobs:
        for (c, g) in line:
            job_queue.put((c, g))

    # hack the queue to become a stack for every put() from here on: a job
    # re-queued by a failed worker (and the stop token) is taken next rather
    # than after all remaining jobs, so the graph is not held back until the
    # end. The jobs queued above keep their original order.
    # NOTE: relies on the private Queue._put / Queue.queue (deque) internals.
    job_queue._put = job_queue.queue.appendleft

    # fire telnet workers (one username/password prompt shared by all hosts)
    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host, job_queue, result_queue,
                         host, username, password).start()

    # fire ssh workers
    for host in ssh_workers:
        SSHWorker(host, job_queue, result_queue, host).start()

    # fire local workers
    for _ in range(nr_local_worker):
        LocalWorker('local', job_queue, result_queue).start()

    # gather results
    done_jobs = {}                  # (cexp, gexp) -> rate for every result seen
    db = []                         # (cexp, gexp, rate) in job order, for plotting
    best_rate = -1
    best_c1, best_g1 = None, None   # log2 of the best c and g
    best_c, best_g = None, None     # remain None if no job ever completes
    try:
        with open(out_filename, 'w') as result_file:
            for line in jobs:
                for (c, g) in line:
                    # Results arrive in completion order; keep reading until
                    # the one for this job has been received.
                    while (c, g) not in done_jobs:
                        (worker, c1, g1, rate) = result_queue.get()
                        done_jobs[(c1, g1)] = rate
                        result_file.write('{0} {1} {2}\n'.format(c1, g1, rate))
                        result_file.flush()
                        # Prefer a higher rate; on a tie with the same g,
                        # prefer the smaller c.
                        if (rate > best_rate) or \
                                (rate == best_rate and g1 == best_g1 and c1 < best_c1):
                            best_rate = rate
                            best_c1, best_g1 = c1, g1
                            best_c = 2.0 ** c1
                            best_g = 2.0 ** g1
                        print("[{0}] {1} {2} {3} (best c={4}, g={5}, rate={6})".format(
                            worker, c1, g1, rate, best_c, best_g, best_rate))
                    db.append((c, g, done_jobs[(c, g)]))
                # Redraw on screen, then to the png file.
                redraw(db, [best_c1, best_g1, best_rate])
                redraw(db, [best_c1, best_g1, best_rate], True)
    finally:
        # Always send the stop token, even if collecting results failed, so
        # workers stop after their current job instead of draining the queue.
        job_queue.put((WorkerStopToken, None))

    print("{0} {1} {2}".format(best_c, best_g, best_rate))