本文是基于matlab平台的libsvm工具箱进行的,是羊同学的练手做,代码编写不太仔细,欢迎大家斧正。
Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。
由于本数据集十分的典型,羊同学便采用了本数据集。
我们通过以下两条命令简单的观察一下该数据集的形式:
load Iris
head Iris
得到结果如下:
ans =
8×6 table
Id SepalLengthCm SepalWidthCm PetalLengthCm PetalWidthCm Species
__ _____________ ____________ _____________ ____________ ___________
1 5.1 3.5 1.4 0.2 Iris-setosa
2 4.9 3 1.4 0.2 Iris-setosa
3 4.7 3.2 1.3 0.2 Iris-setosa
4 4.6 3.1 1.5 0.2 Iris-setosa
5 5 3.6 1.4 0.2 Iris-setosa
6 5.4 3.9 1.7 0.4 Iris-setosa
7 4.6 3.4 1.4 0.3 Iris-setosa
8 5 3.4 1.5 0.2 Iris-setosa
为了方便,我们将标签数据化处理。我们可通过unique()
函数查看类别数。
unique(Iris.Species)
可以得到结果:
ans =
3×1 categorical 数组
Iris-setosa
Iris-versicolor
Iris-virginica
所以羊同学随手把三种类别转换为了1、2、3。
%% 将类别数据化
num = length(Iris.Species);
label = [];
for i = 1:num
switch Iris.Species(i)
case 'Iris-setosa'
label(i) = 1;
case 'Iris-versicolor'
label(i) = 2;
case 'Iris-virginica '
label(i) = 3;
end
end
label = label';
data = [Iris.SepalLengthCm,Iris.SepalWidthCm,Iris.PetalLengthCm,Iris.PetalWidthCm];
羊同学决定对数据进行简单的可视化分析,对数据的情况有个整体的把握,主要从以下几个方面进行:
RHO = corr(data);
name = Iris.Properties.VariableNames(2:5);
heatmap(name,name,RHO);
colormap hot
其结果如图:
由此可发现PetalLengthCm和PetalWidthCm相关性较强,我们可以删除一组,也可以用PCA提取主成分消除独立性。不过由于羊同学很懒,就直接不做了,有兴趣的同学可以试试看,效果会不会变好。
boxplot(data,'orientation','horizontal','labels',name);
one2six = 1:4;
comb = combntns(one2six,2);
index_1 = find(label==1);
index_2 = find(label==2);
index_3 = find(label==3);
figure
hold on
for i = 1:6
subplot(2,3,i)
scatter(data(index_1,comb(i,1)),index_1,data(comb(i,2)),'fill','r');
hold on
scatter(data(index_2,comb(i,1)),index_2,data(comb(i,2)),'fill','g');
hold on
scatter(data(index_3,comb(i,1)),index_3,data(comb(i,2)),'fill','b');
title([name{comb(i,1)},' and ',name{comb(i,2)}]);
legend('Iris-setosa','Iris-versicolor','Iris-virginica','location','best');
end
hold off
羊同学随手划分了一下,由于样本数不多感兴趣的同学也可以采用k-fold crossValidation方法。
train_data = [data(1:40,:);data(51:90,:);data(101:140,:)];
train_label = [label(1:40,:);label(51:90,:);label(101:140,:)];
test_data = [data(1:40,:);data(51:90,:);data(101:140,:)];
test_label = [label(1:40,:);label(51:90,:);label(101:140,:)];
使用mapminmax函数时,注意它的归一化方法。在归一化前记得转置。
[mtrain,ntrain] = size(train_data);
[mtest,ntest] = size(test_data);
dataset = [train_data;test_data]; % mapminmax为MATLAB自带的归一化函数
[dataset_scale,ps] = mapminmax(dataset',0,1); %归一化要先转至
dataset_scale = dataset_scale';
train_data = dataset_scale(1:mtrain,:);
test_data = dataset_scale( (mtrain+1):(mtrain+mtest),: );
羊同学先不调整任何参数,直接使用默认参数进行训练,看看效果:
%% SVM网络训练
model = svmtrain(train_label, train_data,'-c 2 -g 1');
%% SVM网络预测
[predict_label, accuracy,desc_value] = svmpredict(test_label, test_data, model); % desc_value !!
注意:libsvm工具箱由于版本不同会有一定的不同,新版本的预测函数需要添加dec_value
,不然会运算报错。
optimization finished, #iter = 30
nu = 0.410476
obj = -48.386381, rho = 0.160529
nSV = 34, nBSV = 31
Total nSV = 43
Accuracy = 96.6667% (116/120) (classification)
由此我们可以看到正确率达到了96.6667%,有种勉勉强强的感觉呢。
figure;
hold on;
plot(test_label,'o');
plot(predict_label,'r*');
xlabel('测试集样本','FontSize',12);
ylabel('类别标签','FontSize',12);
legend('实际测试集分类','预测测试集分类');
title('测试集的实际分类和预测分类图','FontSize',12);
grid on;
结果如图:
通过图可以看出只有一个测试样本是被错分的。就这样简单的libsvm模型就完成了。当然还可以对其进行许多方面的优化,最重要的一点就是对于libsvm参数模型的优化。
Options:可用的选项即表示的涵义如下
option -v 随机地将数据剖分为n部分并计算交互检验准确度和均方根误差。以上这些参数设置可以按照SVM的类型和核函数所支持的参数进行任意组合,如果设置的参数在函数或SVM类型中没有也不会产生影响,程序不会接受该参数;如果应有的参数设置不正确,参数将采用默认值。
training_set_file是要进行训练的数据集;model_file是训练结束后产生的模型文件,文件中包括支持向量样本数、支持向量样本以及lagrange系数等必须的参数;该参数如果不设置将采用默认的文件名,也可以设置成自己惯用的文件名。
优化后的结果:
optimization finished, #iter = 40
nu = 0.518337
obj = -2741.351977, rho = 0.228981
nSV = 43, nBSV = 40
Total nSV = 52
Accuracy = 97.5% (117/120) (classification)
打印测试集分类准确率
Accuracy = 97.5% (117/120)
可视化作图:
模型得到了一定程度的优化。