这里重点讲一下 _ _all_ _ = ['find_parameters'] :
_all__ = ['find_parameters']
是 Python 中用于定义模块级别的变量 __all__
的语法, __all__
是一个包含模块中应该被公开(即可以通过 from module import *
导入)的变量名的列表
__all__
是一个约定俗成的变量名,用于指定在使用 from module import *
语句时,应该导入哪些变量名。这样可以控制模块的命名空间,避免不必要的变量污染。
['find_parameters']
是一个包含在 __all__
中的列表,其中包含了模块中应该被导入的变量名。在这个例子中,只有一个变量名 find_parameters
被包含在 __all__
中。
通过这个设置,当其他模块使用 from module import *
导入这个模块时,只有 find_parameters
这个变量名会被导入,其他未在 __all__
中指定的变量不会被导入。这是一种良好的编程实践,因为它可以提供更清晰的模块接口,避免不必要的命名冲突和变量污染。
构造函数接收两个参数:dataset_pathname 和 options
根据操作系统设置svm-train.exe和gnuplot.exe 的路径,这个要根据自己系统的实际按照情况 来进行路径的设置。
默认参数的设置以及解析传入参数的函数parse_options。
最后,检查 SVM 训练可执行文件路径、数据集路径和 Gnuplot 可执行文件路径的存在性。
class GridOption:
'''
构造函数 __init__:
接收两个参数 dataset_pathname 和 options
dataset_pathname 是数据集的路径
options 是一个包含其他配置选项的字典
获取当前脚本所在目录,并根据操作系统设置 svmtrain_pathname 和 gnuplot_pathname
'''
def __init__(self, dataset_pathname, options):
dirname = os.path.dirname(__file__)
# 使用 sys.platform 来检查操作系统
# 如果不是 Windows (sys.platform != 'win32'),则设置 svmtrain_pathname 为在当前脚本所在目录下的 ‘…/svm-train’,并设置 gnuplot_pathname 为 ‘/usr/bin/gnuplot’
if sys.platform != 'win32':
self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
self.gnuplot_pathname = '/usr/bin/gnuplot'
else:
# example for windows
# 如果是 Windows,则设置 svmtrain_pathname 为在当前脚本所在目录下的 r’…\windows\svm-train.exe’,并设置 gnuplot_pathname 为 r’c:\tmp\gnuplot\binary\pgnuplot.exe’
self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
# svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
# 默认参数的设置
# 设置了一系列参数的默认值,例如 fold、c_begin、c_end、c_step、g_begin、g_end、g_step 等,用于定义网格搜索的参数范围和步长
# 设置了 grid_with_c 和 grid_with_g 为 True,表示要在网格搜索中搜索 C 和 gamma 参数
self.fold = 5
self.c_begin, self.c_end, self.c_step = -5, 15, 2
self.g_begin, self.g_end, self.g_step = 3, -15, -2
self.grid_with_c, self.grid_with_g = True, True
self.dataset_pathname = dataset_pathname # 将传入的 dataset_pathname 赋值给 self.dataset_pathname
self.dataset_title = os.path.split(dataset_pathname)[1] # 提取数据集的标题部分,通过 os.path.split(dataset_pathname) 和 [1] 获取,赋值给 self.dataset_title
self.out_pathname = '{0}.out'.format(self.dataset_title) # 设置 out_pathname 为 ‘{0}.out’,其中 {0} 是数据集标题
self.png_pathname = '{0}.png'.format(self.dataset_title) # 设置 png_pathname 为 ‘{0}.png’,其中 {0} 是数据集标题
self.pass_through_string = ' ' # 设置 pass_through_string 为一个空格
self.resume_pathname = None # 设置 resume_pathname 为 None
self.parse_options(options) # 调用 parse_options 方法,该方法用于解析传入的选项,并更新类的属性值
# 定义了 parse_options 方法,该方法用于解析传入的选项列表,更新 GridOption 类的属性值
def parse_options(self, options):
# options 是传入的选项,可以是字符串,也可以是由字符串组成的列表
# 如果 options 是字符串,通过 options.split() 将其分割成列表
if type(options) == str:
options = options.split()
i = 0 # 初始化变量 i 为 0,用于迭代 options 列表
# 初始化空列表 pass_through_options,用于存储未被解析的选项
pass_through_options = []
# 使用 while 循环遍历 options 列表
# 通过检查当前选项,更新相应的 GridOption 类属性
while i < len(options):
'''
-log2c 和 -log2g:解析参数范围和步长,如果值为 'null',则相应的网格搜索标志设为 False
-v:设置交叉验证的折数
-c 和 -g:抛出错误,提示使用 -log2c 和 -log2g
-svmtrain:设置 SVM 训练可执行文件路径
-gnuplot:设置 Gnuplot 可执行文件路径,如果值为 'null',则设为 None
-out:设置输出文件路径,如果值为 'null',则设为 None
-png:设置 PNG 文件路径
-resume:设置恢复训练的文件路径,如果未提供则使用默认文件名
'''
if options[i] == '-log2c':
i = i + 1
if options[i] == 'null':
self.grid_with_c = False
else:
self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
elif options[i] == '-log2g':
i = i + 1
if options[i] == 'null':
self.grid_with_g = False
else:
self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
elif options[i] == '-v':
i = i + 1
self.fold = options[i]
elif options[i] in ('-c','-g'):
raise ValueError('Use -log2c and -log2g.')
elif options[i] == '-svmtrain':
i = i + 1
self.svmtrain_pathname = options[i]
elif options[i] == '-gnuplot':
i = i + 1
if options[i] == 'null':
self.gnuplot_pathname = None
else:
self.gnuplot_pathname = options[i]
elif options[i] == '-out':
i = i + 1
if options[i] == 'null':
self.out_pathname = None
else:
self.out_pathname = options[i]
elif options[i] == '-png':
i = i + 1
self.png_pathname = options[i]
elif options[i] == '-resume':
if i == (len(options)-1) or options[i+1].startswith('-'):
self.resume_pathname = self.dataset_title + '.out'
else:
i = i + 1
self.resume_pathname = options[i]
else:
pass_through_options.append(options[i]) # 未识别的选项将被添加到 pass_through_options 列表中
i = i + 1
# 使用 ' '.join(pass_through_options) 将未识别的选项组合成一个字符串,更新 pass_through_string 属性
self.pass_through_string = ' '.join(pass_through_options)
# 检查 SVM 训练可执行文件路径、数据集路径和 Gnuplot 可执行文件路径的存在性
if not os.path.exists(self.svmtrain_pathname):
raise IOError('svm-train executable not found')
if not os.path.exists(self.dataset_pathname):
raise IOError('dataset not found')
if self.resume_pathname and not os.path.exists(self.resume_pathname):
raise IOError('file for resumption not found') # 如果 resume_pathname 存在,检查其存在性
if not self.grid_with_c and not self.grid_with_g: # 如果同时设置了 -log2c 和 -log2g 为 False,抛出错误
raise ValueError('-log2c and -log2g should not be null simultaneously')
if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
# 如果 Gnuplot 可执行文件不存在,输出错误信息并将其设为 None
sys.stderr.write('gnuplot executable not found\n')
self.gnuplot_pathname = None
补充:“win32” 是 Windows 操作系统的平台标识符。在 Python 中,sys.platform
返回一个字符串,表示当前运行 Python 解释器的平台。对于 Windows 系统,这个字符串通常是"win32"。所以,if sys.platform != 'win32'
这个条件语句检查当前操作系统是否为 Windows 之外的其他操作系统。
def redraw(db,best_param,gnuplot,options,tofile=False):
if len(db) == 0: return
begin_level = round(max(x[2] for x in db)) - 3
step_size = 0.5
best_log2c,best_log2g,best_rate = best_param
# if newly obtained c, g, or cv values are the same,
# then stop redrawing the contour.
if all(x[0] == db[0][0] for x in db): return
if all(x[1] == db[0][1] for x in db): return
if all(x[2] == db[0][2] for x in db): return
if tofile:
gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
#gnuplot.write(b"set term postscript color solid\n")
#gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
elif sys.platform == 'win32':
gnuplot.write(b"set term windows\n")
else:
gnuplot.write( b"set term x11\n")
gnuplot.write(b"set xlabel \"log2(C)\"\n")
gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
gnuplot.write(b"set contour\n")
gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
gnuplot.write(b"unset surface\n")
gnuplot.write(b"unset ztics\n")
gnuplot.write(b"set view 0,0\n")
gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
gnuplot.write(b"unset label\n")
gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \
at screen 0.5,0.85 center\n". \
format(best_log2c, best_log2g, best_rate).encode())
gnuplot.write("set label \"C = {0} gamma = {1}\""
" at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
gnuplot.write(b"set key at screen 0.9,0.9\n")
gnuplot.write(b"splot \"-\" with lines\n")
db.sort(key = lambda x:(x[0], -x[1]))
prevc = db[0][0]
for line in db:
if prevc != line[0]:
gnuplot.write(b"\n")
prevc = line[0]
gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
gnuplot.write(b"e\n")
gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
gnuplot.flush()
该函数接受一个参数 options
,并返回两个值:jobs
和 resumed_jobs
,同时里面嵌套定义了函数 range_f 和函数 permute_sequence。
函数的主要目的是生成一系列的任务(jobs
),每个任务是一个参数组合,用于训练支持向量机(SVM)。这些参数是通过对给定的一组参数范围进行排列组合得到的。
def calculate_jobs(options):
def range_f(begin,end,step):
# like range, but works on non-integer too
seq = []
while True:
if step > 0 and begin > end: break
if step < 0 and begin < end: break
seq.append(begin)
begin = begin + step
return seq
def permute_sequence(seq):
n = len(seq)
if n <= 1: return seq
mid = int(n/2)
left = permute_sequence(seq[:mid])
right = permute_sequence(seq[mid+1:])
ret = [seq[mid]]
while left or right:
if left: ret.append(left.pop(0))
if right: ret.append(right.pop(0))
return ret
c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))
if not options.grid_with_c:
c_seq = [None]
if not options.grid_with_g:
g_seq = [None]
nr_c = float(len(c_seq))
nr_g = float(len(g_seq))
i, j = 0, 0
jobs = []
while i < nr_c or j < nr_g:
if i/nr_c < j/nr_g:
# increase C resolution
line = []
for k in range(0,j):
line.append((c_seq[i],g_seq[k]))
i = i + 1
jobs.append(line)
else:
# increase g resolution
line = []
for k in range(0,i):
line.append((c_seq[k],g_seq[j]))
j = j + 1
jobs.append(line)
resumed_jobs = {}
if options.resume_pathname is None:
return jobs, resumed_jobs
for line in open(options.resume_pathname, 'r'):
line = line.strip()
rst = re.findall(r'rate=([0-9.]+)',line)
if not rst:
continue
rate = float(rst[0])
c, g = None, None
rst = re.findall(r'log2c=([0-9.-]+)',line)
if rst:
c = float(rst[0])
rst = re.findall(r'log2g=([0-9.-]+)',line)
if rst:
g = float(rst[0])
resumed_jobs[(c,g)] = rate
return jobs, resumed_jobs
range_f函数:
range_f
函数是一个自定义的函数,类似于内置函数 range
,但可以处理非整数的步长。它生成一个序列,从 begin
开始,以 step
为步长,直到不再满足条件。permute_sequence函数:
permute_sequence
函数用于对给定序列进行排列组合。它采用分而治之的方法,将序列分成两半,然后递归地对左右两半进行排列组合,最终将结果合并。参数生成:
range_f
函数生成了两个序列 c_seq
和 g_seq
,分别表示参数 c
和 g
的可能取值。如果选项 options.grid_with_c
或 options.grid_with_g
为 False,则相应的参数序列为单一值,即 [None]
。生成任务列表:
jobs
列表中。处理恢复任务:
options.resume_pathname
,则从该路径读取已经完成的任务信息,提取出参数组合和对应的性能率,并存储在 resumed_jobs
字典中。返回结果:
jobs
和已经完成的任务信息字典 resumed_jobs
。这段代码主要用于生成一系列参数组合,以及处理从先前运行中恢复的任务信息。这类功能通常在超参数搜索和模型训练中使用,以便系统能够自动尝试多种参数组合。
通常用作信号或标记,用于通信或控制多线程或多进程的执行流程。在这里, WorkerStopToken
的目的是作为一个简单的标记,用于通知工作线程停止或表示工作线程已经停止。在实际应用中,它可能会与其他线程或进程之间的通信机制一起使用,以实现协同工作或关闭。
class WorkerStopToken:
:定义了一个新的类,类名为 WorkerStopToken
pass
:在Python中,pass
是一个占位符语句,不执行任何操作。在这里,它被用作类的主体部分,表示这个类是一个空类,没有任何成员或方法。
Worker类继承自Python中的Thread类 ,这个类表示一个工作线程,用于执行支持向量机(SVM)的训练任务,该类定义了三个函数:_ _init_ _方法、run方法、get_cmd方法
class Worker(Thread):
def __init__(self,name,job_queue,result_queue,options):
Thread.__init__(self)
self.name = name
self.job_queue = job_queue
self.result_queue = result_queue
self.options = options
__init__
方法:
name
(线程名称)、job_queue
(任务队列)、result_queue
(结果队列)、options
(选项参数)
def run(self):
while True:
(cexp,gexp) = self.job_queue.get()
if cexp is WorkerStopToken:
self.job_queue.put((cexp,gexp))
# print('worker {0} stop.'.format(self.name))
break
try:
c, g = None, None
if cexp != None:
c = 2.0**cexp
if gexp != None:
g = 2.0**gexp
rate = self.run_one(c,g)
if rate is None: raise RuntimeError('get no rate')
except:
# we failed, let others do that and we just quit
traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
self.job_queue.put((cexp,gexp))
sys.stderr.write('worker {0} quit.\n'.format(self.name))
break
else:
self.result_queue.put((self.name,cexp,gexp,rate))
run
方法:
run
方法是 Thread
类的默认方法,在启动线程时会自动调用。这里是线程的主要执行逻辑。while True
) 从任务队列 (job_queue
) 获取任务,任务是 (cexp, gexp)
,其中 cexp
和 gexp
表示对应的参数指数。WorkerStopToken
,表示线程应该停止,将任务重新放回队列,并通过 break
退出循环,结束线程。c
和 g
,然后调用 run_one
方法执行具体的 SVM 训练,并获取性能率。sys.stderr.write
输出线程终止的信息,并通过 break
退出循环,结束线程。cexp
、gexp
和性能率放入结果队列 (result_queue
)。这段代码实现了一个工作线程的逻辑,用于执行 SVM 训练任务。它通过任务队列接收参数组合,执行训练,并将结果放入结果队列。这样的多线程结构通常用于加速大规模参数搜索和训练任务
def get_cmd(self,c,g):
options=self.options
cmdline = '"' + options.svmtrain_pathname + '"'
if options.grid_with_c:
cmdline += ' -c {0} '.format(c)
if options.grid_with_g:
cmdline += ' -g {0} '.format(g)
cmdline += ' -v {0} {1} {2} '.format\
(options.fold,options.pass_through_string,options.dataset_pathname)
return cmdline
get_cmd
方法:
-c
(如果启用)、参数 -g
(如果启用)、参数 -v
、折数、透传参数和数据集路径。下面我再来详细地讲解一下get_cmd方法 :
def get_cmd(self,c,g)
:
定义了一个方法 get_cmd
,接受两个参数 c
和 g
,表示 SVM 训练 的参数
options = self.options
将类实例中的 options
属性赋给局部变量 options
,以便在后续代码中使用
cmdline = '"' + options.svmtrain_pathname + '"'
构建命令行字符串的开头部分,包含 SVM 训练器的路径。使用双引号将路径括起来,以防止 路径中包含空格时出现问题。
if options.grid_with_c:
检查选项 grid_with_c
是否为真,即是否启用了参数 c
的网格搜索
cmdline += ' -c {0} '.format(c)
:
如果启用了参数 c
的网格搜索,则将参数 c
的值添加到命令行字符串中
if options.grid_with_g:
检查选项 grid_with_g
是否为真,即是否启用了参数 g
的网格搜索
cmdline += ' -g {0} '.format(g)
:
如果启用了参数 g
的网格搜索,则将参数 g
的值添加到命令行字符串中
cmdline += ' -v {0} {1} {2} '.format(options.fold, options.pass_through_string, options.dataset_pathname)
添加 SVM 训练的其他参数,包括:
-v
:表示要进行交叉验证{0}
:使用 options.fold
指定的折数{1}
:用户传递的额外参数{2}
:数据集的路径,由 options.dataset_pathname
指定return cmdline
:
返回构建好的 SVM 训练命令行字符串
总体而言,这段代码的作用是根据给定的参数 c
和 g
以及一些配置选项,生成用于执行 SVM 训练的命令行字符串。生成的命令行包括 SVM 训练器的路径、参数 -c
(如果启用)、参数 -g
(如果启用)、参数 -v
、交叉验证的折数、额外参数和数据集的路径。
定义了一个名为 LocalWorker
的类,它继承自先前提到的 Worker
类,并重写了 run_one
方法
class LocalWorker(Worker):
def run_one(self,c,g):
cmdline = self.get_cmd(c,g)
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
for line in result.readlines():
if str(line).find('Cross') != -1:
return float(line.split()[-1][0:-1])
run_one方法
该方法接受两个参数 c
和 g
,表示 SVM 训练的参数
cmdline = self.get_cmd(c,g)
:
调用父类 Worker
的 get_cmd
方法,获取 SVM 训练的命令行字符串,并将其赋给 cmdline
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
:
使用 subprocess.Popen
创建一个新的进程,运行 SVM 训练的命令行,其中
cmdline
是要执行的命令行字符串shell=True
表示使用系统的 shell 执行命令stdout=PIPE
表示将命令的标准输出捕获到 result
变量中stderr=PIPE
表示将命令的标准错误捕获,但在这段代码中没有使用stdin=PIPE
表示标准输入连接到管道,但在这段代码中没有使用for line in result.readlines():
遍历命令的标准输出的每一行
if str(line).find('Cross') != -1:
判断当前行是否包含字符串 ‘Cross’。如果包含,说明这一行包含了交叉验证的结果信息
return float(line.split()[-1][0:-1])
如果找到包含 ‘Cross’ 的行,提取该行的最后一个单词,去掉末尾的换行符,并将其转换为浮点 数。这个值表示 SVM 训练的性能率。
总体而言,这段代码实现了在本地环境运行 SVM 训练任务的逻辑。它通过创建新的进程执行 SVM 训练命令行,并从命令的标准输出中提取包含交叉验证结果的行,最终返回性能率作为结果。
class SSHWorker(Worker):
def __init__(self,name,job_queue,result_queue,host,options):
Worker.__init__(self,name,job_queue,result_queue,options)
self.host = host
self.cwd = os.getcwd()
def run_one(self,c,g):
cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
(self.host,self.cwd,self.get_cmd(c,g))
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
for line in result.readlines():
if str(line).find('Cross') != -1:
return float(line.split()[-1][0:-1])
定义了一个名为 SSHWorker
的类,它同样继承自之前提到的 Worker
类,并进行了一些定制化。
该类定义了初始化函数和run_one函数
__init__方法
初始化方法,除了调用父类的初始化方法外,还接受一个额外的参数 host
,表示远程主机的地 址。
self.host = host
:将传入的 host
参数保存为实例变量,以便在后续代码中使用
self.cwd = os.getcwd()
:获取当前工作目录,并保存为实例变量 cwd
run_one方法
重写了 run_one
方法,该方法接受两个参数 c
和 g
,表示 SVM 训练的参数
cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"' .format (self.host, self.cwd, self.get_cmd(c,g))
:
构建了一个 SSH 命令行字符串,该命令行用于在远程主机上执行 SVM 训练任务
ssh -x -t -t
:表示使用 SSH 连接,并在远程主机上执行命令
{0}
:用传入的 host
替换占位符,表示远程主机的地址
"cd {1}; {2}"
:在远程主机上执行的命令,首先切换到当前工作目录(cwd
),然后执行通过
调 用 get_cmd
方法生成的 SVM 训练命令
result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout:
使用 subprocess.Popen
创建一个新的进程,运行 SSH 命令行
cmdline
是要执行的 SSH 命令行字符串stdout=PIPE
表示将命令的标准输出捕获到 result
变量中for line in result.readlines():
遍历命令的标准输出的每一行
if str(line).find('Cross') != -1:
判断当前行是否包含字符串 ‘Cross’。如果包含,说明这一行包含了交叉验证的结果信息
return float(line.split()[-1][0:-1])
:
如果找到包含 ‘Cross’ 的行,提取该行的最后一个单词,去掉末尾的换行符,并将其转换为浮点数。这个值表示在远程主机上运行 SVM 训练的性能率
总体而言,这段代码实现了在远程主机上通过 SSH 运行 SVM 训练任务的逻辑。它构建了相应的 SSH 命令行,执行远程任务,并从命令的标准输出中提取包含交叉验证结果的行,最终返回性能率作为结果。
class TelnetWorker(Worker):
def __init__(self,name,job_queue,result_queue,host,username,password,options):
Worker.__init__(self,name,job_queue,result_queue,options)
self.host = host
self.username = username
self.password = password
def run(self):
import telnetlib
self.tn = tn = telnetlib.Telnet(self.host)
tn.read_until('login: ')
tn.write(self.username + '\n')
tn.read_until('Password: ')
tn.write(self.password + '\n')
# XXX: how to know whether login is successful?
tn.read_until(self.username)
#
print('login ok', self.host)
tn.write('cd '+os.getcwd()+'\n')
Worker.run(self)
tn.write('exit\n')
def run_one(self,c,g):
cmdline = self.get_cmd(c,g)
result = self.tn.write(cmdline+'\n')
(idx,matchm,output) = self.tn.expect(['Cross.*\n'])
for line in output.split('\n'):
if str(line).find('Cross') != -1:
return float(line.split()[-1][0:-1])
总体而言,这段代码实现了在远程主机上通过 Telnet 运行 SVM 训练任务的逻辑。它通过 Telnet 协议连接远程主机,执行相应的命令,并从输出中提取包含交叉验证结果的行,最终返回性能率作为结果。需要注意的是,代码中对登录成功的判断逻辑可能需要进一步完善。
这段代码实现了对 SVM 模型参数的并行搜索和优化,通过多线程/进程执行不同参数组合的训练 任务,然后比较性能,最终找到最佳的参数组合。
用于参数搜索和优化的部分,具体来说,它使用了多线程/进程的方式来执行 SVM 参数的搜索工作
def find_parameters(dataset_pathname, options=''):
def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
if (rate > best_rate) or (rate==best_rate and g==best_g and c
def find_parameters(dataset_pathname, options=''):
find_parameters
的函数,用于寻找 SVM 模型的最佳参数def update_param(c, g, rate, best_c, best_g, best_rate, worker, resumed):
update_param
,用于更新最佳参数和最佳性能率
options = GridOption(dataset_pathname, options);
if options.gnuplot_pathname:
gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
else:
gnuplot = None
options = GridOption(dataset_pathname, options);
:
GridOption
类处理参数选项,GridOption
类是对参数进行解析和处理的一个自定义类if options.gnuplot_pathname:
gnuplot
路径,如果提供了,则创建一个与 gnuplot
进程进行通信的管道 # put jobs in queue
jobs,resumed_jobs = calculate_jobs(options)
job_queue = Queue(0)
result_queue = Queue(0)
for (c,g) in resumed_jobs:
result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))
for line in jobs:
for (c,g) in line:
if (c,g) not in resumed_jobs:
job_queue.put((c,g))
# hack the queue to become a stack --
# this is important when some thread
# failed and re-put a job. It we still
# use FIFO, the job will be put
# into the end of the queue, and the graph
# will only be updated in the end
job_queue._put = job_queue.queue.appendleft
jobs, resumed_jobs = calculate_jobs(options)
:
调用 calculate_jobs
函数,生成需要执行的任务列表 jobs
和已经恢复的任务列表 resumed_jobs
job_queue = Queue(0)
和 result_queue = Queue(0):
创建两个队列,job_queue
用于存放待执行的任务,result_queue
用于存放执行结果
for (c, g) in resumed_jobs:
和 for line in jobs:
job_queue._put = job_queue.queue.appendleft
:
将 job_queue
的 _put
方法指向 appendleft
方法,将队列变成一个栈,以确保重新放入的任务在队列头部
# fire telnet workers
if telnet_workers:
nr_telnet_worker = len(telnet_workers)
username = getpass.getuser()
password = getpass.getpass()
for host in telnet_workers:
worker = TelnetWorker(host,job_queue,result_queue,
host,username,password,options)
worker.start()
# fire ssh workers
if ssh_workers:
for host in ssh_workers:
worker = SSHWorker(host,job_queue,result_queue,host,options)
worker.start()
# fire local workers
for i in range(nr_local_worker):
worker = LocalWorker('local',job_queue,result_queue,options)
worker.start()
# gather results
done_jobs = {}
if options.out_pathname:
if options.resume_pathname:
result_file = open(options.out_pathname, 'a')
else:
result_file = open(options.out_pathname, 'w')
if telnet_workers:
和 if ssh_workers:
:
根据是否提供了 Telnet 或 SSH 主机列表,启动相应的 TelnetWorker 或 SSHWorker
for i in range(nr_local_worker):
启动本地工作线程,数量由 nr_local_worker
决定
done_jobs = {}: 用于存放已完成的任务及其结果
if options.out_pathname:
如果提供了输出路径,则打开一个文件用于记录结果
db = []
best_rate = -1
best_c,best_g = None,None
for (c,g) in resumed_jobs:
rate = resumed_jobs[(c,g)]
best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)
for line in jobs:
for (c,g) in line:
while (c,g) not in done_jobs:
(worker,c1,g1,rate1) = result_queue.get()
done_jobs[(c1,g1)] = rate1
if (c1,g1) not in resumed_jobs:
best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
db.append((c,g,done_jobs[(c,g)]))
if gnuplot and options.grid_with_c and options.grid_with_g:
redraw(db,[best_c, best_g, best_rate],gnuplot,options)
redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)
db = []
和 best_rate = -1
:用于存放任务执行结果的数据库和记录最佳性能率的变量
for (c, g) in resumed_jobs:
遍历已恢复的任务,更新最佳参数和最佳性能率
for line in jobs:
遍历待执行的任务
while (c, g) not in done_jobs:
循环等待任务执行完成,并将执行结果放入 done_jobs
(worker, c1, g1, rate1) = result_queue.get()
:从结果队列中获取执行结果
db.append((c, g, done_jobs[(c, g)])):
将任务执行结果加入数据库
if gnuplot and options.grid_with_c and options.grid_with_g:
如果提供了 gnuplot
路径,并且需要绘制图形,则调用 redraw
函数绘制图形
if options.out_pathname:
result_file.close()
job_queue.put((WorkerStopToken,None))
best_param, best_cg = {}, []
if best_c != None:
best_param['c'] = 2.0**best_c
best_cg += [2.0**best_c]
if best_g != None:
best_param['g'] = 2.0**best_g
best_cg += [2.0**best_g]
print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))
return best_rate, best_param
if options.out_pathname:
job_queue.put((WorkerStopToken, None))
:
best_param, best_cg = {}, []
和 print('{0} {1}'.format(' '.join(map(str, best_cg)), best_rate))
:
return best_rate, best_param
: 返回最佳性能率和最佳参数
这是一个命令行工具的入口,用于解析命令行参数并调用 find_parameters
函数进行参数搜索
if __name__ == '__main__':
def exit_with_help():
print("""\
Usage: grid.py [grid_options] [svm_options] dataset
grid_options :
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
"null" -- do not grid with c
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
"null" -- do not grid with g
-v n : n-fold cross validation (default 5)
-svmtrain pathname : set svm executable path and name
-gnuplot {pathname | "null"} :
pathname -- set gnuplot executable path and name
"null" -- do not plot
-out {pathname | "null"} : (default dataset.out)
pathname -- set output file path and name
"null" -- do not output file
-png pathname : set graphic output file path and name (default dataset.png)
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
This is experimental. Try this option only if some parameters have been checked for the SAME data.
svm_options : additional options for svm-train""")
sys.exit(1)
if len(sys.argv) < 2:
exit_with_help()
dataset_pathname = sys.argv[-1]
options = sys.argv[1:-1]
try:
find_parameters(dataset_pathname, options)
except (IOError,ValueError) as e:
sys.stderr.write(str(e) + '\n')
sys.stderr.write('Try "grid.py" for more information.\n')
sys.exit(1)
if __name__ == '__main__':
def exit_with_help():
exit_with_help
,用于打印使用帮助信息并退出程序print(
'' '' ''\
…'' '' '')和 sys.exit(1)
:
sys.exit(1)
退出程序 if len(sys.argv) < 2:
和 exit_with_help()
:
如果命令行参数数量小于 2,则调用 exit_with_help
函数打印使用帮助信息并退出程序
dataset_pathname = sys.argv[-1]
和 options = sys.argv[1:-1]:
dataset_pathname
,将除第一个参数和最后一个参数外的其他参数赋值给 options
try: ... except (IOError, ValueError) as e: ...
:
try...except
结构捕获可能发生的 IOError
和 ValueError
异常try
块中调用 find_parameters
函数,传入数据集路径和其他参数总体而言,这段代码实现了一个命令行工具的入口,用于解析命令行参数并调用 find_parameters
函数进行参数搜索。如果命令行参数不符合要求或者执行过程中出现异常,将打印使用帮助信息或错误信息,并退出程序。