clear,clc
%% 二分类
%训练数据20×2,20行代表20个训练样本点,第一列代表横坐标,第二列纵坐标
Train_Data =[-3 0;4 0;4 -2;3 -3;-3 -2;1 -4;-3 -4;0 1;-1 0;2 2;3 3;-2 -1;-4.5 -4;2 -1;5 -4;-2 2;-2 -3;0 2;1 -2;2 0];
%Group 20 x 1,20行代表训练数据对应点属于哪一类(1类,-1类)
Train_labels =[1 -1 -1 -1 1 -1 1 1 1 -1 -1 1 1 -1 -1 1 1 1 -1 -1]';
TestData = [3 -1;3 1;-2 1;-1 -2;2 -3;-3 -3];%测试数据
classifier = fitcsvm(Train_Data,Train_labels); %train
test_labels = predict(classifier ,TestData); % test
这里 test_labels 就是最后的分类结果啦,大家可以按照这个格式对自己的数据进行修改
因为
%% 多分类
TrainingSet=[ 1 10;2 20;3 30;4 40;5 50;6 66;3 30;4.1 42];%训练数据
TestSet=[3 34; 1 14; 2.2 25; 6.2 63];%测试数据
GroupTrain=[1;1;2;2;3;3;2;2];%训练标签
results =my_MultiSvm(TrainingSet, GroupTrain, TestSet);
disp('multi class problem');
disp(results);
results为最终的分类结果,上述中有用到 my_MultiSvm.m() 函数,以下是my_MultiSvm.m函数的全部内容
function [y_predict,models] = my_MultiSvm(X_train, y_train, X_test)
% multi svm
% one vs all 模型
% Input:
% X_train: n*m矩阵 n为训练集样本数 m为特征数
% y_train: n*1向量 为训练集label,支持任意多种类
% X_test: n*m矩阵 n为测试集样本数 m为特征数
% Output:
% y_predict: n*1向量 测试集的预测结果
%
% Copyright(c) lihaoyang 2020
%
y_labels = unique(y_train);
n_class = size(y_labels, 1);
models = cell(n_class, 1);
% 训练n个模型
for i = 1:n_class
class_i_place = find(y_train == y_labels(i));
svm_train_x = X_train(class_i_place,:);
sample_num = numel(class_i_place);
class_others = find(y_train ~= y_labels(i));
randp = randperm(numel(class_others));
svm_train_minus = randp(1:sample_num)';
svm_train_x = [svm_train_x; X_train(svm_train_minus,:)];
svm_train_y = [ones(sample_num, 1); -1*ones(sample_num, 1)];
disp(['生成模型:', num2str(i)])
models{i} = fitcsvm(svm_train_x, svm_train_y);
end
test_num = size(X_test, 1);
y_predict = zeros(test_num, 1);
% 对每条数据,n个模型分别进行预测,选择label为1且概率最大的一个作为预测类别
for i = 1:test_num
if mod(i, 100) == 0
disp(['预测个数:', num2str(i)])
end
bagging = zeros(n_class, 1);
for j = 1:n_class
model = models{j};
[label, rat] = predict(model, X_test(i,:));
bagging(j) = bagging(j) + rat(2);
end
[maxn, maxp] = max(bagging);
y_predict(i) = y_labels(maxp);
end
end
以下代码是调用matlab工具箱libsvm的一种方法
TrainingSet=[ 1 10;2 20;3 30;4 40;5 50;6 66;3 30;4.1 42];%训练数据
TestSet=[3 34; 1 14; 2.2 25; 6.2 63];%测试数据
GroupTrain=[1;1;2;2;3;3;2;2];%训练标签
GroupTest=[1;2;1;3];%测试标签
%svm分类
model = svmtrain(GroupTrain,TrainingSet);
% SVM网络预测
[predict_label] = svmpredict(GroupTest,TestSet,model);
之所以放到最后,是因为需要在matlab安装libsvm的工具箱,具体方法可参看此链接在Matlab中安装LibSVM工具箱
下载libsvm也可以百度网盘:百度网盘libsvm
提取码:25ft