例子：手写数字识别（使用 MNIST 数据集，网上有很多资源可以下载）
1)主程序:
% Main script: k-nearest-neighbour classification of MNIST digits.
% Loads the first TrainNum training images and the first TestNum test
% images (both clamped to what the files actually contain), classifies
% the test images with knn and prints the accuracy in percent.
clear all;
close all;
global ImageRow ImageCol TrainNum TestNum;
ImageRow=28;
ImageCol=28;
TrainNum=2000;   % requested number of training samples
TestNum=900;     % requested number of test samples
k=8;             % number of neighbours used in the majority vote

TrainData=LoadMNISTImages('train-images.idx3-ubyte');
% Never index past the end of the data; the global is updated as well so
% that any code reading TrainNum sees the size actually used.
TrainNum=min(TrainNum,size(TrainData,2));
TrainData=TrainData(:,1:TrainNum);
TrainLabel=LoadMNISTLabels('train-labels.idx1-ubyte');
TrainLabel=TrainLabel(1:TrainNum);

TestData=LoadMNISTImages('t10k-images.idx3-ubyte');
TestNum=min(TestNum,size(TestData,2));
TestData=TestData(:,1:TestNum);
TestLabel=LoadMNISTLabels('t10k-labels.idx1-ubyte');
TestLabel=TestLabel(1:TestNum);

PredictLabel=knn(TrainData,TrainLabel,TestData,k);
accuracy=sum(PredictLabel==TestLabel)/TestNum;
disp(['准确率是:',num2str(accuracy*100),'%']);
2)knn.m
function PredictLabel=knn(dataX,LabelX,dataY,k)
%KNN Classify each column of dataY by majority vote of its k nearest
%training columns in dataX, using Euclidean distance.
%   dataX        - training samples, one sample per column (d x nTrain)
%   LabelX       - labels of the training samples (nTrain elements)
%   dataY        - samples to classify, one sample per column (d x nTest)
%   k            - number of neighbours; clamped to nTrain
%   PredictLabel - nTest x 1 vector of predicted labels
%
%   Sizes are taken from the inputs rather than from global variables,
%   so the function works independently of the calling script.
numTrain = size(dataX, 2);
numTest = size(dataY, 2);
k = min(k, numTrain);   % asking for more neighbours than samples would index out of range
PredictLabel = zeros(numTest, 1);
for i = 1:numTest
    % Squared Euclidean distance to every training sample. sqrt is
    % monotonic, so it is not needed to rank the neighbours.
    sqDist = sum(bsxfun(@minus, dataX, dataY(:, i)).^2, 1);
    [~, order] = sort(sqDist, 'ascend');   % nearest sample first
    % Majority vote among the k nearest; on a tie mode() returns the
    % smallest label.
    PredictLabel(i) = mode(LabelX(order(1:k)));
end
end
3)LoadMNISTImages.m
function images = LoadMNISTImages(filename)
%LoadMNISTImages Read an MNIST image file (IDX format, magic number 2051).
%   images = LoadMNISTImages(filename) returns a double matrix of size
%   (numRows*numCols) x numImages with pixel values rescaled to [0,1].
%   Each column is one image, with its pixels in column-major order, so
%   reshape(images(:,j), numRows, numCols) recovers image j.
%
%   Header fields are read as big-endian int32 values. The file is
%   closed on every exit path, including failed assertions.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
closeFile = onCleanup(@() fclose(fp)); % closes fp when the function exits, normally or by error

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);
numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

images = fread(fp, inf, 'unsigned char');
% Fail with a clear message instead of an opaque reshape error on a
% truncated or oversized file.
assert(numel(images) == numRows * numCols * numImages, ...
    ['Unexpected number of pixels in ', filename]);
% Pixels are stored row by row, so read each image as numCols x numRows
% and swap the first two dimensions.
images = reshape(images, numCols, numRows, numImages);
images = permute(images, [2 1 3]);

% Reshape to #pixels x #examples
images = reshape(images, numRows * numCols, numImages);
% Convert to double and rescale to [0,1]
images = double(images) / 255;
end
4)LoadMNISTLabels.m
function labels = LoadMNISTLabels(filename)
%LoadMNISTLabels Read an MNIST label file (IDX format, magic number 2049).
%   labels = LoadMNISTLabels(filename) returns a numLabels x 1 column
%   vector of the label bytes stored in the file (as doubles).
%
%   The function name matches the file name LoadMNISTLabels.m used by the
%   caller; MATLAB warns when a function file declares a different name.
%   The file is closed on every exit path, including failed assertions.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
closeFile = onCleanup(@() fclose(fp)); % closes fp when the function exits, normally or by error

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);
numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');
labels = fread(fp, inf, 'unsigned char');
% The header count must agree with the number of label bytes present.
assert(size(labels,1) == numLabels, 'Mismatch in label count');
end