下面是个人对 easy.py 中源码的理解,其中的错误和不足恳请各位大神们不吝赐教,谢谢!
easy.py 按照以下顺序进行 svm 分类器的训练和分类:
1)缩放训练数据
2)参数择优:(C,g)
3)训练svm分类器
4)缩放测试数据
5)分类
#!/usr/bin/env python
# easy.py 要求至少有一个传入的参数:训练数据文本,否则给出提示并退出
# sys.argv[0] 是可执行程序名
# sys.argv[1]...sys.argv[n] 是传入的参数
is_win32 = (sys.platform == 'win32')
# 非windows平台
if not is_win32:# windows平台
else:
# example for windows
# 指定工具 svmscale、svmtrain、svmpredict、gnuplot、grid_py 的路径
svmscale_exe = r"..\windows\svm-scale.exe"svmpredict_exe = r"..\windows\svm-predict.exe"
# 需要根据自己的安装路径进行修改
gnuplot_exe = r"c:\Program Files (x86)\gnuplot\bin\pgnuplot.exe"# 判断上述5个工具是否存在
train_pathname = sys.argv[1]
# 判断训练数据文本是否存在
assert os.path.exists(train_pathname),"training file not found"
# 解析出训练数据文本的文件名
# 例如 os.path.split('d:\svm\libsvm-3.1\tools\train.txt') 结果:('d:\svm\libsvm-3.1\tools', 'train.txt')
file_name = os.path.split(train_pathname)[1]
# 缩放后的训练数据文本
scaled_file = file_name + ".scale"
# 训练出来的模型(记录模型参数,以备下一次直接预测)
model_file = file_name + ".model"
# 记录训练数据中x、y的范围,以备后面缩放测试数据使用
# 若用户给出了测试数据文本
if len(sys.argv) > 2:
# 解析出测试数据文本的文件名并判断文件是否存在
assert os.path.exists(test_pathname),"testing file not found"
# 缩放后的测试数据文本
scaled_test_file = file_name + ".scale"
# svm分类器给出的预测结果
predict_test_file = file_name + ".predict"# 关键步骤1:缩放训练数据
# 这里是 python 脚本调用exe程序的方法,不用搞得特别明白,关心cmd这句命令就行:
# svm-scale.exe -s "range_file" "train_pathname" > "scaled_file"
# 意义:使用程序 svm-scale.exe 缩放训练数据 train_pathname ,并把缩放结果保存在scaled_file中;
# -srange_file 表示将训练数据中x、y的范围保存在文件range_file中
cmd = '{0} -s "{1}" "{2}" > "{3}"'.format(svmscale_exe, range_file, train_pathname, scaled_file)# 关键步骤2:参数择优:(C,g)
cmd = '{0} -svmtrain "{1}" -gnuplot "{2}" "{3}"'.format(grid_py, svmtrain_exe, gnuplot_exe, scaled_file)# 参数择优的过程会将每个(C,g)对及准确率记录在文件 f 中,f 的最后一行记录了最优的参数(C,g)和相应的准确率
line = ''# 关键步骤3:训练svm分类器
# 使用上面计算出来的最优参数c和g,“-c value_c”指定 c 值,“-g value_g”指定 g 值
# scaled_file 指定训练数据(缩放后),model_file 记录训练后的模型参数
cmd = '{0} -c {1} -g {2} "{3}" "{4}"'.format(svmtrain_exe,c,g,scaled_file,model_file)print('Output model: {0}'.format(model_file))
# 若用户给出了测试数据
if len(sys.argv) > 2:
# 关键步骤4:缩放测试数据
# -r "range_file"表示之前训练数据中x、y的范围,测试数据将按照这个范围进行缩放,即保证训练数据和测试数据的缩放的标准统一
print('Output prediction: {0}'.format(predict_test_file))
上述的5个关键步骤中还有很多可选的参数有待研究,以后会陆续对这些内容进行总结。