下面直接给出 KNN 算法的 MATLAB 实现:
% KNN classification of the MNIST test set.
% Each test image is compared against every training image using the L1
% distance (sum of absolute pixel differences). The predicted label is the
% most frequent label among the K nearest training images; ties resolve to
% the smallest label.
trainImages = loadMNISTImages('train-images.idx3-ubyte');
trainLabels = loadMNISTLabels('train-labels.idx1-ubyte');
N = 784;         % expected number of pixels per image
K = 100;         % number of neighbours; can be any other positive value
verbose = true;  % print predicted and true label for every test image
testImages = loadMNISTImages('t10k-images.idx3-ubyte');
testLabels = loadMNISTLabels('t10k-labels.idx1-ubyte');

% Images are stored one per column, so count columns. length() returns the
% largest dimension, which is wrong when there are fewer images than pixels.
trainLength = size(trainImages, 2);
testLength = size(testImages, 2);
assert(size(trainImages, 1) == N && size(testImages, 1) == N, ...
    'Expected %d pixels per image', N);
assert(numel(trainLabels) == trainLength, 'Training image/label count mismatch');
assert(numel(testLabels) == testLength, 'Test image/label count mismatch');
K = min(K, trainLength);  % cannot use more neighbours than training samples

testResults = zeros(1, testLength);
t0 = tic;
for i = 1:testLength
    % L1 distance from test image i to every training image (1 x trainLength).
    % bsxfun broadcasts the column instead of materialising a repmat copy.
    dist = sum(abs(bsxfun(@minus, trainImages, testImages(:, i))), 1);
    [~, ind] = sort(dist);
    compLabel = trainLabels(ind(1:K));
    % mode() returns the smallest value among equally frequent labels.
    testResults(i) = mode(compLabel);
    if verbose
        disp(testResults(i));
        disp(testLabels(i));
    end
end

% Compute the error on the test set. The counter is not called "error" so
% the built-in error() function is not shadowed. Both sides are flattened
% to columns so a row/column mismatch cannot broadcast into a matrix.
numErrors = sum(testResults(:) ~= testLabels(:));
% Print out the classification error on the test set
errorRate = numErrors / testLength
% toc(t0) reads the timer started above; calling tic again here would
% restart it and return a timer id, not a time.
elapsedTime = toc(t0);
fprintf('Elapsed time is %f seconds.\n', elapsedTime);
其中训练数据 60000 条,测试数据 10000 条。
运行时间慢的原因分析:
每个测试样本都要与全部 60000 个训练样本逐一计算距离(暴力搜索)。同时,程序没有进行主成分分析(PCA)降维,而是用全部 784 个维度进行比较。这一点是可以改进的地方 :)
附上其他辅助代码(MNIST 数据文件的读取函数):
function images = loadMNISTImages(filename)
%loadMNISTImages Read an MNIST IDX3 image file.
%   images = loadMNISTImages(filename) returns a (numRows*numCols) x numImages
%   double matrix with one image per column and pixel values rescaled to
%   [0,1]. Each column holds the pixels of a numRows x numCols image in
%   MATLAB column-major order.
%
%   The header is read as big-endian int32 values: the magic number (must be
%   2051), then image count, row count and column count. It is followed by
%   one unsigned byte per pixel, stored row by row for each image.

fp = fopen(filename, 'rb');
assert(fp ~= -1, 'Could not open %s', filename);
% Close the file on every exit path, including failed assertions below.
closeFile = onCleanup(@() fclose(fp));

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
% isequal also rejects an empty read (file shorter than 4 bytes).
assert(isequal(magic, 2051), 'Bad magic number in %s', filename);

dims = fread(fp, 3, 'int32', 0, 'ieee-be');
assert(numel(dims) == 3, 'Truncated header in %s', filename);
numImages = dims(1);
numRows = dims(2);
numCols = dims(3);

images = fread(fp, inf, 'unsigned char');
assert(numel(images) == numImages * numRows * numCols, ...
    'Pixel count does not match header in %s', filename);

% Bytes are row-major per image: reshape to cols x rows x images, then swap
% the first two dimensions to obtain rows x cols x images.
images = reshape(images, numCols, numRows, numImages);
images = permute(images, [2 1 3]);

% Reshape to #pixels x #examples
images = reshape(images, numRows * numCols, numImages);
% Convert to double and rescale to [0,1]
images = double(images) / 255;
end
function labels = loadMNISTLabels(filename)
%loadMNISTLabels Read an MNIST IDX1 label file.
%   labels = loadMNISTLabels(filename) returns a numLabels x 1 double column
%   vector with one label per image, in file order.
%
%   The header is read as big-endian int32 values: the magic number (must be
%   2049) and the label count. It is followed by one unsigned byte per label.

fp = fopen(filename, 'rb');
assert(fp ~= -1, 'Could not open %s', filename);
% Close the file on every exit path, including failed assertions below.
closeFile = onCleanup(@() fclose(fp));

magic = fread(fp, 1, 'int32', 0, 'ieee-be');
% isequal also rejects an empty read (file shorter than 4 bytes).
assert(isequal(magic, 2049), 'Bad magic number in %s', filename);

numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');
labels = fread(fp, inf, 'unsigned char');
assert(isequal(numel(labels), numLabels), 'Mismatch in label count');
end