libsvm 学习笔记(二)--- easy.py 脚本详解

下面是个人对 easy.py 中源码的理解,其中的错误和不足恳请各位大神们不吝赐教,谢谢!


easy.py 按照以下顺序进行 svm 分类器的训练和分类

1)缩放训练数据

2)参数择优:(C,g)

3)训练svm分类器

4)缩放测试数据

5)分类



#!/usr/bin/env python



import sys
import os
from subprocess import *

# easy.py 要求至少有一个传入的参数:训练数据文本,否则给出提示并退出

# sys.argv[0] 是可执行程序名

sys.argv[1]...sys.argv[n] 是传入的参数

if len(sys.argv) <= 1:
    print('Usage: {0} training_file [testing_file]'.format(sys.argv[0]))
    raise SystemExit


# svm, grid, and gnuplot executable files

# 判断运行环境是不是windows平台

is_win32 = (sys.platform == 'win32')

# 非windows平台

if not is_win32:
    svmscale_exe = "../svm-scale"
    svmtrain_exe = "../svm-train"
    svmpredict_exe = "../svm-predict"
    grid_py = "./grid.py"
    gnuplot_exe = "/usr/bin/gnuplot"

# windows平台

else:

        # example for windows

    # 指定工具 svmscale、svmtrain、svmpredict、gnuplot、grid_py 的路径

    svmscale_exe = r"..\windows\svm-scale.exe"
    svmtrain_exe = r"..\windows\svm-train.exe"

    svmpredict_exe = r"..\windows\svm-predict.exe"

    # 需要根据自己的安装路径进行修改

    gnuplot_exe = r"c:\Program Files (x86)\gnuplot\bin\pgnuplot.exe"
    grid_py = r".\grid.py"


# 判断上述5个工具是否存在

assert os.path.exists(svmscale_exe),"svm-scale executable not found"
assert os.path.exists(svmtrain_exe),"svm-train executable not found"
assert os.path.exists(svmpredict_exe),"svm-predict executable not found"
assert os.path.exists(gnuplot_exe),"gnuplot executable not found"
assert os.path.exists(grid_py),"grid.py not found"

# 训练数据文本的路径

train_pathname = sys.argv[1]

# 判断训练数据文本是否存在

assert os.path.exists(train_pathname),"training file not found"

# 解析出训练数据文本的文件名

# 例如 os.path.split('d:\svm\libsvm-3.1\tools\train.txt') 结果:('d:\svm\libsvm-3.1\tools', 'train.txt')

file_name = os.path.split(train_pathname)[1]

# 缩放后的训练数据文本

scaled_file = file_name + ".scale"

# 训练出来的模型(记录模型参数,以备下一次直接预测)

model_file = file_name + ".model"

# 记录训练数据中x、y的范围,以备后面缩放测试数据使用

range_file = file_name + ".range"


# 若用户给出了测试数据文本

if len(sys.argv) > 2:

   # 解析出测试数据文本的文件名并判断文件是否存在

    test_pathname = sys.argv[2]
    file_name = os.path.split(test_pathname)[1]

    assert os.path.exists(test_pathname),"testing file not found"

    缩放后的测试数据文本

    scaled_test_file = file_name + ".scale"

    # svm分类器给出的预测结果

    predict_test_file = file_name + ".predict"

# 关键步骤1:缩放训练数据

# 这里是 python 脚本调用exe程序的方法,不用搞得特别明白,关心cmd这句命令就行:

# svm-scale.exe -s "range_file" "train_pathname" > "scaled_file" 

# 意义:使用程序 svm-scale.exe 缩放训练数据 train_pathname ,并把缩放结果保存在scaled_file中;

-srange_file 表示将训练数据中x、y的范围保存在文件range_file中

cmd = '{0} -s "{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, train_pathname, scaled_file)
print('Scaling training data...')
Popen(cmd, shell = True, stdout = PIPE).communicate()


关键步骤2:参数择优:(C,g)

cmd = '{0} -svmtrain "{1}" -gnuplot "{2}" "{3}"'.format(grid_py, svmtrain_exe, gnuplot_exe, scaled_file)
print('Cross validation...')
f = Popen(cmd, shell = True, stdout = PIPE).stdout


# 参数择优的过程会将每个C,g)对及准确率记录在文件 f 中,f 的最后一行记录了最优的参数(C,g)和相应的准确率

line = ''
while True:
    last_line = line
    line = f.readline()
    if not line: break
c,g,rate = map(float,last_line.split())


# 打印最优的参数(C,g)和相应的准确率
print('Best c={0}, g={1} CV rate={2}'.format(c,g,rate))


关键步骤3:训练svm分类器

# 使用上面计算出来的最优参数c和g,“-c value_c”指定 c 值,-g value_g”指定 g 值

# scaled_file 指定训练数据(缩放后),model_file 记录训练后的模型参数

cmd = '{0} -c {1} -g {2} "{3}" "{4}"'.format(svmtrain_exe,c,g,scaled_file,model_file)
print('Training...')
Popen(cmd, shell = True, stdout = PIPE).communicate()


# 输出模型参数

print('Output model: {0}'.format(model_file))


# 若用户给出了测试数据

if len(sys.argv) > 2:

    # 关键步骤4:缩放测试数据

    # -r "range_file"表示之前训练数据中x、y的范围,测试数据将按照这个范围进行缩放,即保证训练数据和测试数据的缩放的标准统一

    cmd = '{0} -r "{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, test_pathname, scaled_test_file)
    print('Scaling testing data...')
    Popen(cmd, shell = True, stdout = PIPE).communicate()

     # 关键步骤5:使用工具 svmpredict_exe 和模型 model_file 来为测试数据 scaled_test_file 进行分类,结果保存在 predict_test_file 中
    cmd = '{0} "{1}" "{2}" "{3}"'.format(svmpredict_exe, scaled_test_file, model_file, predict_test_file)
    print('Testing...')
    Popen(cmd, shell = True).communicate()

   # 输出分类结果

    print('Output prediction: {0}'.format(predict_test_file))



上述的5个关键步骤中还有很多可选的参数有待研究,以后会陆续对这些内容进行总结。

你可能感兴趣的:(机器学习&数据挖掘算法)