libsvm 学习笔记(四)--- grid.py 关键代码详解

在所有可取的值对下,grid.py 都训练出一个分类器,并计算该分类器的准确率,将准确率最高的分类器所对应的作为最优的参数。这是通过交叉验证来实现的:将训练集分为 N 折,其中的 N-1 折用来训练分类器,剩下的 1 折用来计算分类准确率,从而计算 N 次实验的平均准确率作为给定值对下的分类器的准确率。


----------------------------------------------------------------------------------------------------------------------------------------

全局设置:


# svmtrain and gnuplot executable


is_win32 = (sys.platform == 'win32')
if not is_win32:
       svmtrain_exe = "../svm-train"
       gnuplot_exe = "/usr/bin/gnuplot"
else:
       # example for windows
       svmtrain_exe = r"..\windows\svm-train.exe"

       # windows平台需要修改程序 gnuplot 的路径
       gnuplot_exe = r"c:\Program Files (x86)\gnuplot\bin\pgnuplot.exe"


# global parameters and their default values

# 全局变量,同时设置默认的参数

# 5-折交叉验证
fold = 5

# 参数 c和g 的范围 begin~end 以及步长 step
c_begin, c_end, c_step = -5,  15, 2
g_begin, g_end, g_step =  3, -15, -2

dataset_pathname:训练数据路径,dataset_title:训练数据文件名,pass_through_string:用户给定的 svm 分类器的参数
global dataset_pathname, dataset_title, pass_through_string
out_filename:文件名,该文件记录交叉验证过程中计算得到的每组数据(c,g,rate等),png_filename:图片的文件名,将 out_filename 中的数据画成二维图形

global out_filename, png_filename


# experimental

# 设置 telnet 和 ssh 的节点名称--- grid.py 的网格搜索可以并行化
telnet_workers = []
ssh_workers = []

# 本地节点数量
nr_local_worker = 1


----------------------------------------------------------------------------------------------------------------------------------------

处理命令行传入的参数:



# process command line options, set global parameters
def process_options(argv=sys.argv):

    # 将这些变量设为全局的,免去函数间的传参
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string

    gnuplot 是画图的句柄,要随时将得到的每组数据(c,g,rate等)画在图上
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename
    

    # 命令提示
    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] 
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    # 命令使用错误提示
    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    # 变量赋值
    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]
    out_filename = '{0}.out'.format(dataset_title)
    png_filename = '{0}.png'.format(dataset_title)

    # 传给 svmtrain_exe 的参数列表
    pass_through_options = []


    i = 1
    while i < len(argv) - 1:

        # 用户给定的 关于c 的3个参数:范围 begin~end 以及步长 step
        if argv[i] == "-log2c":

            i = i + 1
            (c_begin,c_end,c_step) = map(float,argv[i].split(","))

        # 用户给定的 关于g 的3个参数:范围 begin~end 以及步长 step
        elif argv[i] == "-log2g":
            i = i + 1
            (g_begin,g_end,g_step) = map(float,argv[i].split(","))

        # 用户给定的交叉验证的折数
        elif argv[i] == "-v":
            i = i + 1
            fold = argv[i]

        # 参数传递方法有误,给出提示
        elif argv[i] in ('-c','-g'):
            print("Option -c and -g are renamed.")
            print(usage)
            sys.exit(1)

        # 用户给定的 svmtrain_exe 的路径(使用svmtrain_exe 来计算某个对下的准确率
        elif argv[i] == '-svmtrain':
            i = i + 1
            svmtrain_exe = argv[i]

        # 用户给定的 gnuplot_exe 的路径
        elif argv[i] == '-gnuplot':
            i = i + 1
            gnuplot_exe = argv[i]

        # 用户指定记录结果的文件名
        elif argv[i] == '-out':
            i = i + 1
            out_filename = argv[i]

        # 用户指定记录结果的图片文件名
        elif argv[i] == '-png':
            i = i + 1
            png_filename = argv[i]

        # 用户给定的 svmtrain_exe 的参数
        else:
            pass_through_options.append(argv[i])
        i = i + 1


    # svmtrain_exe 以字符串的形式接收参数
    pass_through_string = " ".join(pass_through_options)

    # 检查程序svmtrain_exe 、gnuplot_exe  和训练数据文件dataset_pathname 是否存在
    assert os.path.exists(svmtrain_exe),"svm-train executable not found"    
    assert os.path.exists(gnuplot_exe),"gnuplot executable not found"
    assert os.path.exists(dataset_pathname),"dataset not found"

    # 初始化画图句柄 gnuplot(python 连接、调用 gnuplot_exe 程序 
    gnuplot = Popen(gnuplot_exe,stdin = PIPE).stdin


----------------------------------------------------------------------------------------------------------------------------------------

产生需要计算的(c,g)对:


函数 calculate_jobs 调用函数range_f 和 permute_sequence,将 c 和 g 的每一对可取的值组合起来,例如,根据默认值,c 可取 [-5, -3, -1, 1, 3, 5, 7, 9, 11, 13, 15] ,g 可取 [3, 1, -1, -3, -5, -7, -9, -11, -13, -15],则产生 110个(c,g)对:

(5, -7)
(-1, -7)
(5, -1) (-1, -1)
(11, -7) (11, -1)
(5, -13) (-1, -13) (11, -13)
(-3, -7) (-3, -1) (-3, -13)
(5, 1) (-1, 1) (11, 1) (-3, 1)
(9, -7) (9, -1) (9, -13) (9, 1)
(5, -11) (-1, -11) (11, -11) (-3, -11) (9, -11)
(3, -7) (3, -1) (3, -13) (3, 1) (3, -11)
(5, -5) (-1, -5) (11, -5) (-3, -5) (9, -5) (3, -5)
(15, -7) (15, -1) (15, -13) (15, 1) (15, -11) (15, -5)
(5, -15) (-1, -15) (11, -15) (-3, -15) (9, -15) (3, -15) (15, -15)
(-5, -7) (-5, -1) (-5, -13) (-5, 1) (-5, -11) (-5, -5) (-5, -15)
(5, 3) (-1, 3) (11, 3) (-3, 3) (9, 3) (3, 3) (15, 3) (-5, 3)
(7, -7) (7, -1) (7, -13) (7, 1) (7, -11) (7, -5) (7, -15) (7, 3)
(5, -9) (-1, -9) (11, -9) (-3, -9) (9, -9) (3, -9) (15, -9) (-5, -9) (7, -9)
(1, -7) (1, -1) (1, -13) (1, 1) (1, -11) (1, -5) (1, -15) (1, 3) (1, -9)
(5, -3) (-1, -3) (11, -3) (-3, -3) (9, -3) (3, -3) (15, -3) (-5, -3) (7, -3) (1, -3)
(13, -7) (13, -1) (13, -13) (13, 1) (13, -11) (13, -5) (13, -15) (13, 3) (13, -9) (13, -3)


grid.py 计算这 110 个值对下的分类准确率,从而选择准确率最高的值作为最优参数。


def range_f(begin,end,step):
    # like range, but works on non-integer too
    seq = []
    while True:
        if step > 0 and begin > end: break
        if step < 0 and begin < end: break
        seq.append(begin)
        begin = begin + step
    return seq


def permute_sequence(seq):
    n = len(seq)
    if n <= 1: return seq


    mid = int(n/2)
    left = permute_sequence(seq[:mid])
    right = permute_sequence(seq[mid+1:])


    ret = [seq[mid]]
    while left or right:
        if left: ret.append(left.pop(0))
        if right: ret.append(right.pop(0))


    return ret




def calculate_jobs():
    c_seq = permute_sequence(range_f(c_begin,c_end,c_step))
    g_seq = permute_sequence(range_f(g_begin,g_end,g_step))
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i = 0
    j = 0
    jobs = []


    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution
            line = []
            for k in range(0,j):
                line.append((c_seq[i],g_seq[k]))
            i = i + 1
            jobs.append(line)
        else:
            # increase g resolution
            line = []
            for k in range(0,i):
                line.append((c_seq[k],g_seq[j]))
            j = j + 1
            jobs.append(line)
    return jobs


----------------------------------------------------------------------------------------------------------------------------------------

根据给定的(c,g)对,执行交叉验证,计算准确率rate:



class Worker(Thread):

    # 线程初始化
    def __init__(self,name,job_queue,result_queue):
        Thread.__init__(self)

        # 工作节点名称
        self.name = name

        # 分配给当前工作节点的任务
        self.job_queue = job_queue

        当前工作节点的计算结果表
        self.result_queue = result_queue


    # 启动线程( main函数中的LocalWorker(...).start()就是调用的这个函数 
    def run(self):
        while True:

            # 从自己的工作队列中取出一个任务
            (cexp,gexp) = self.job_queue.get()

            # 工作节点被停止了 
            if cexp is WorkerStopToken:
                self.job_queue.put((cexp,gexp))
                # print('worker {0} stop.'.format(self.name))
                break
            try:

                # 调用 run_one 函数,计算准确率
                rate = self.run_one(2.0**cexp,2.0**gexp)

                # 计算不成功
                if rate is None: raise RuntimeError("get no rate")
            except:
                # we failed, let others do that and we just quit
            
                traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
                
                self.job_queue.put((cexp,gexp))
                print('worker {0} quit.'.format(self.name))
                break

            # 计算成功,将得到的结果放入结果表 result_queue 中
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))


class LocalWorker(Worker):
    def run_one(self,c,g):

         # 调用程序 svmtrain_exe 的命令
        cmdline = '{0} -c {1} -g {2} -v {3} {4} {5}'.format \
          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)

        # 调用程序 svmtrain_exe
        result = Popen(cmdline,shell=True,stdout=PIPE).stdout

        # 处理计算结果 line,字符串 line 中的字段以空格分离,最后一个字段为准确率,函数返回这个数值
        for line in result.readlines():
            if str(line).find("Cross") != -1:
                return float(line.split()[-1][0:-1])


----------------------------------------------------------------------------------------------------------------------------------------

主函数:



def main():


    # set parameters

    # 处理命令行参数
    process_options()


    # put jobs in queue

    # 计算任务(所有 c 和 g 的组合)
    jobs = calculate_jobs()

    # 任务队列和结果队列
    job_queue = Queue.Queue(0)
    result_queue = Queue.Queue(0)

    # 将任务放入任务队列中
    for line in jobs:
        for (c,g) in line:
            job_queue.put((c,g))

   # 意思大概为:将队列转换为栈,原因是:如果某个节点计算出现问题,那么可以再将未计算成功的任务(对放入栈中,下一个取出的还是该任务,如果是“先进先出”的队列,就无法保证未计算成功的任务立即再被执行
    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. It we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end
 
    job_queue._put = job_queue.queue.appendleft




    # fire telnet workers

    # 如果有 telnet 节点
    if telnet_workers:
        nr_telnet_worker = len(telnet_workers)
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            TelnetWorker(host,job_queue,result_queue,
                     host,username,password).start()


    # fire ssh workers


    # 如果有 ssh节点
    if ssh_workers:
        for host in ssh_workers:
            SSHWorker(host,job_queue,result_queue,host).start()


    # fire local workers


    # 启动本地节点线程
    for i in range(nr_local_worker):
        LocalWorker('local',job_queue,result_queue).start()


    # gather results


    # 记录处理过的计算结果
    done_jobs = {}

    # 结果文件
    result_file = open(out_filename, 'w')

    # 传递给 gnuplot_exe  程序的数据
    db = []

    # 最佳准确率以及最佳准确率下的c和g值
    best_rate = -1
    best_c1,best_g1 = None,None


    # 遍历务栈
    for line in jobs:
        for (c,g) in line:

            # 若(c, g) 的结果还未处理
            while (c, g) not in done_jobs:

                # 取出一个计算结果
                (worker,c1,g1,rate) = result_queue.get()

                # 处理计算结果
                done_jobs[(c1,g1)] = rate

                # 写入结果文件
                result_file.write('{0} {1} {2}\n'.format(c1,g1,rate))
                result_file.flush()

                # 若 准确率更高 或者 相同准确率下具有更小的c值,则对最佳rate和相应的c、g进行更新
                if (rate > best_rate) or (rate==best_rate and g1==best_g1 and c1                     best_rate = rate

                    # log2(best_c)=best_c1, log2(best_g)=best_g1
                    best_c1,best_g1=c1,g1
                    best_c = 2.0**c1
                    best_g = 2.0**g1

                # 打印计算结果
                print("[{0}] {1} {2} {3} (best c={4}, g={5}, rate={6})".format \
            (worker,c1,g1,rate, best_c, best_g, best_rate))

            # 记录计算结果
            db.append((c,g,done_jobs[(c,g)]))

        # 画图
        redraw(db,[best_c1, best_g1, best_rate])
        redraw(db,[best_c1, best_g1, best_rate],True)

    # 停止节点的工作
    job_queue.put((WorkerStopToken,None))

    # 打印最佳准确率下的c、g值以及最佳准确率
    print("{0} {1} {2}".format(best_c, best_g, best_rate))

你可能感兴趣的:(机器学习&数据挖掘算法)