看了西瓜书的聚类算法,算法原理很容易明白,接下来就是整理成自己的理解思路,然后一步一步来实现算法,那么就来做吧。
学习向量量化(LVQ)算法步骤
输入:样本集(带标记),原型向量个数,各原型向量预设的类别标记,学习率(0,1)
输出:最终的原型向量
Step1:初始化函数 ,init_lvq()
Step1.1: 载入样本集,初始化学习率,初始化原型向量
输入: txt样本数据,原型向量个数q
输出: 样本集,原型向量
Repeat
Step2:聚类: lvq_cluster()
输入:原型向量、样本
输出:最近原型向量
Step2.1: 计算每个样本与原型向量集的距离,并找到与xi最近的原型向量
Step2.2: 根据标签,保存到样本容器
Step3: 更新原型向量,并保存原型向量对应的样本容器: update_vector();
若标签相同:p' = p + η·(x − p)
否则:p' = p − η·(x − p)
Until
Step4: 终止迭代条件: is_stop();
若p与p'的欧式距离 < 0.01,停止迭代
LVQ代码及样本下载链接
ok啦,接下来不废话上代码(Matlab发布形式)
%Algorithm: Learning Vector Quantization (LVQ)
%input : labeled sample set loaded from melon4.0_lables.txt
%output: learned prototype vectors (re-plotted every iteration)
function Main()
clc;clear;close
%--- 1.0 initialization ---
melon = load('melon4.0_lables.txt');
num_proto = 5;                       % number of prototype vectors q
learn_rate = 0.1;                    % learning rate, in (0,1)
[proto_vecs, samples] = init_lvq(melon, num_proto);
n_samples = size(samples, 1);
% one (x,y) column pair per prototype; rows hold assigned samples
cluster_box = zeros(n_samples, 2*num_proto);
show_now(proto_vecs, samples, cluster_box);
max_loop = 100;
for loop = 1:max_loop
    prev_vecs = proto_vecs;          % snapshot for the convergence test
    for k = 1:n_samples
        winner = lvq_cluster(samples(k,1:2), proto_vecs);
        [proto_vecs, cluster_box] = update_vector(k, winner, samples, proto_vecs, learn_rate, cluster_box);
    end
    if is_stop(proto_vecs, prev_vecs)
        show_now(proto_vecs, samples, cluster_box);
        fprintf('迭代次数为:%d\n',loop);
        break;
    end
    show_now(proto_vecs, samples, cluster_box);
    cluster_box = zeros(n_samples, 2*num_proto);  % clear assignments for next pass
end
end
%1.0
% init_lvq - load the sample set and pick the initial prototype vectors.
%
% Input : melon_data - raw matrix; column 1 is a sample index, columns 2:end
%                      are [density, sugar_rate, label]
%         q          - number of prototype vectors; this implementation seeds
%                      from 5 hand-picked samples, so q must equal 5
% Output: learn_vector - q-by-3 initial prototypes (rows copied from samples)
%         sample_data  - samples with the leading index column stripped
%
% Fix: the original silently ignored q and indexed rows 5/12/18/23/29 with
% no bounds check; now q and the dataset size are validated explicitly.
function [learn_vector,sample_data]= init_lvq(melon_data,q)
seed_rows = [5 12 18 23 29];   % hand-picked seeds, one per desired cluster
if q ~= numel(seed_rows)
    error('init_lvq:badQ', 'q=%d is not supported; expected %d.', q, numel(seed_rows));
end
if size(melon_data,1) < max(seed_rows)
    error('init_lvq:tooFewSamples', 'need at least %d samples.', max(seed_rows));
end
sample_data = melon_data(:,2:end);
learn_vector = sample_data(seed_rows,:);
end
%2.0
% lvq_cluster - index of the prototype closest (Euclidean) to one sample.
% Input : sample_data  - 1x2 feature row [density, sugar_rate]
%         learn_vector - q-by-3 prototypes, coordinates in columns 1:2
% Output: close_LV_ind - row index of the nearest prototype
function close_LV_ind = lvq_cluster(sample_data,learn_vector)
coords = learn_vector(:,1:2);
diffs = coords - repmat(sample_data, size(coords,1), 1);
[~, close_LV_ind] = min(sqrt(sum(diffs.^2, 2)));
end
% update_vector - apply the LVQ rule to the prototype nearest to sample i,
% then record the sample in the display container of its closest
% same-label prototype.
%
% Inputs : i                - row index of the current sample
%          close_LV_ind     - index of the nearest prototype (from lvq_cluster)
%          sample_data      - n-by-3 samples [x, y, label]
%          learn_vector     - q-by-3 prototypes [x, y, label]
%          learn_rate       - step size in (0,1)
%          sample_calss_box - n-by-2q container; columns (2j-1, 2j) hold the
%                             (x,y) of samples currently assigned to prototype j
% Outputs: cur_vector - updated prototypes
%          sample_box - updated container
function [cur_vector,sample_box]= update_vector(i,close_LV_ind,sample_data,learn_vector,learn_rate,sample_calss_box)
% attract when labels match: p' = p + eta*(x - p)
if sample_data(i,end)==learn_vector(close_LV_ind,end)
learn_vector(close_LV_ind,1:2) = learn_vector(close_LV_ind,1:2) +...
learn_rate*(sample_data(i,1:2)-learn_vector(close_LV_ind,1:2));
else
% repel when labels differ: p' = p - eta*(x - p)
learn_vector(close_LV_ind,1:2) = learn_vector(close_LV_ind,1:2) -...
learn_rate*(sample_data(i,1:2)-learn_vector(close_LV_ind,1:2));
end
% For display: among prototypes sharing the sample's label, pick the closest
% one and append the sample into its column pair of the container.
same_label_ind = find(learn_vector(:,end)==sample_data(i,end));
if ~isempty(same_label_ind)
[~,min_BDI ]= min( pdist2(learn_vector(same_label_ind,1:2), sample_data(i,1:2)));%min_box_dist_ind
% first free row in that prototype's column pair; a cell counts as free when
% its y-column is 0. NOTE(review): this breaks if a sample's y coordinate is
% exactly 0 or the column pair fills up — confirm the data never hits these.
zero_ind = find(sample_calss_box(:,(same_label_ind(min_BDI)*2))==0);
sample_calss_box(zero_ind(1),(same_label_ind(min_BDI)*2-1):(same_label_ind(min_BDI)*2)) = sample_data(i,1:2);
end
cur_vector = learn_vector;
sample_box = sample_calss_box;
end
% is_stop - convergence test: stop when every prototype has moved less than
% tol (Euclidean distance) since the previous iteration.
%
% Bug fix: the original used pdist(learn_vector-temp_vector), which measures
% pairwise distances BETWEEN the per-prototype movement vectors rather than
% the movements themselves — so a uniform shift of all prototypes (however
% large) was wrongly reported as convergence.
%
% Inputs : learn_vector - current prototypes (coordinates in leading columns)
%          temp_vector  - prototypes from the previous iteration (same size)
%          tol          - optional movement threshold; defaults to 1e-8 as in
%                         the original code (the write-up mentions 0.01)
% Output : isStop - 1 when converged, 0 otherwise
function isStop = is_stop(learn_vector,temp_vector,tol)
if nargin < 3
    tol = 0.00000001;
end
movement = sqrt(sum((learn_vector - temp_vector).^2, 2));  % per-prototype shift
isStop = double(all(movement <= tol));
end
% show_now - redraw the scatter plot: raw samples split by class, the five
% prototype vectors, and the samples currently assigned to each prototype.
% NOTE(review): the row ranges 1:8 / 22:end vs 9:21 and the prototype/colour
% pairing (1,4,5 blue; 2,3 red) are hard-coded for the melon4.0 dataset
% layout — confirm against the data file before reusing with other data.
function show_now(learn_vector,sample_data,sample_calss_box)
close
% raw samples by class (hard-coded row ranges)
label1 = [sample_data(1:8,:);sample_data(22:end,:)];
label2 = sample_data(9:21,:);
plot(label1(:,1),label1(:,2),'+b');hold on
plot(label2(:,1),label2(:,2),'+r'); % class 2
% prototype vectors (pentagram markers)
plot(learn_vector(1,1),learn_vector(1,2),'pb');
plot(learn_vector(4,1),learn_vector(4,2),'pb');
plot(learn_vector(5,1),learn_vector(5,2),'pb');
plot(learn_vector(2,1),learn_vector(2,2),'pr');% class 2
plot(learn_vector(3,1),learn_vector(3,2),'pr');
% sample container: one column pair (x,y) per prototype, distinct markers
plot(sample_calss_box(:,1),sample_calss_box(:,2),'ro');
plot(sample_calss_box(:,3),sample_calss_box(:,4),'bo');
plot(sample_calss_box(:,5),sample_calss_box(:,6),'go');
plot(sample_calss_box(:,7),sample_calss_box(:,8),'rs');
plot(sample_calss_box(:,9),sample_calss_box(:,10),'bs');
xlabel('density');ylabel('sugar rate');
end
迭代次数为:41