【手写数字识别KNN】记录

pycharm中引入tensorflow库

1. expand_dims()函数作用

【手写数字识别KNN】记录_第1张图片

2. tf.reduce_sum()函数作用

【手写数字识别KNN】记录_第2张图片

3. tf.argmax()函数作用

4. 附上代码

# 1 load data    随机数加载
# 2 knn test train distance
# 3 knn 根据计算出来的distance,从500张trainData图片中选出k个最近的图片
# 4 parse content解析k个最近图片中的标签数字内容
# 5 由Label(one-hot)得到对应的数字
# 6 识别正确率统计

import tensorflow as tf
import numpy as np
import random
# is used to read data
from tensorflow.examples.tutorials.mnist import input_data

# Read MNIST from a local folder. Arguments: 1) data directory, 2) one_hot=True
# so that labels come back one-hot encoded (one "1", the rest "0").
mnist = input_data.read_data_sets('C:\\Users\\Administrator\\Desktop\\mnist', one_hot=True)

# Pool sizes to draw indices from (assumed to match the train/test split
# sizes returned by the loader), followed by how many samples are drawn.
trainNum, testNum = 55000, 10000
trainSize, testSize = 500, 5
# Number of nearest neighbours for KNN.
k = 4

# Draw random, non-repeating (replace=False) sample indices for each split.
trainIndex = np.random.choice(trainNum, size=trainSize, replace=False)
testIndex = np.random.choice(testNum, size=testSize, replace=False)

# Pick the sampled images together with their labels.
trainData, trainLabel = mnist.train.images[trainIndex], mnist.train.labels[trainIndex]
testData, testLabel = mnist.test.images[testIndex], mnist.test.labels[testIndex]

print('trainData.shape=', trainData.shape)
print('testLabel=', testLabel)
# TensorFlow graph inputs: image rows are 784 wide, label rows are 10 wide;
# the leading (batch) dimension is left open.
trainDataInput = tf.placeholder(shape=[None, 784], dtype=tf.float32)
trainLabelInput = tf.placeholder(shape=[None, 10], dtype=tf.float32)
testDataInput = tf.placeholder(shape=[None, 784], dtype=tf.float32)
# Declared for symmetry; none of the ops below read it.
testLabelInput = tf.placeholder(shape=[None, 10], dtype=tf.float32)

# KNN distance: sum of absolute pixel differences between every test image
# and every train image, computed through broadcasting.
# Insert a middle axis: (n_test, 784) -> (n_test, 1, 784).
f1 = tf.expand_dims(testDataInput, 1)
# Broadcast: (n_train, 784) - (n_test, 1, 784) -> (n_test, n_train, 784).
f2 = tf.subtract(trainDataInput, f1)
# Sum |difference| over the pixel axis -> (n_test, n_train) distance matrix.
# `axis` replaces the deprecated `reduction_indices` keyword.
f3 = tf.reduce_sum(tf.abs(f2), axis=2)
# Negate so that the largest values correspond to the smallest distances.
f4 = tf.negative(f3)
# top_k picks the k largest entries of f4, i.e. the k smallest distances.
# Use the configured neighbour count `k` (previously hard-coded to 10, which
# contradicted the k = 4 setting).
# f5: negated distances of the k nearest train images, f6: their row indices.
f5, f6 = tf.nn.top_k(f4, k=k)
# Look up the label rows of those neighbours -> (n_test, k, 10).
f7 = tf.gather(trainLabelInput, f6)
# Add the neighbours' label rows together -> (n_test, 10) per-class votes.
f8 = tf.reduce_sum(f7, axis=1)
# Predicted digit = class with the most votes (`axis` replaces the deprecated
# `dimension` keyword).
f9 = tf.argmax(f8, axis=1)

# Evaluate every intermediate tensor of the KNN graph on the sampled test
# images, print shapes/values for inspection, then report the accuracy.
# The batch is sized by testSize instead of a hard-coded 5 so the slices,
# the comparison and the accuracy denominator always agree.
testBatch = testData[0:testSize]
# One feed dict for all runs; feeding a placeholder that a particular fetch
# does not need is harmless.
feed = {trainDataInput: trainData, testDataInput: testBatch, trainLabelInput: trainLabel}

with tf.Session() as sess:
    p1 = sess.run(f1, feed_dict=feed)
    print('p1=', p1.shape)  # (n_test, 1, 784)
    p2 = sess.run(f2, feed_dict=feed)
    print('p2=', p2.shape)  # (n_test, n_train, 784)
    p3 = sess.run(f3, feed_dict=feed)
    print('p3=', p3.shape)  # (n_test, n_train)
    print('p3[0,0]=', p3[0, 0])  # distance between test image 0 and train image 0
    p4 = sess.run(f4, feed_dict=feed)
    print('p4=', p4.shape)
    print('p4[0,0]=', p4[0, 0])  # same distance, negated
    p5, p6 = sess.run((f5, f6), feed_dict=feed)
    # p5: (n_test, K) negated distances to the K nearest train images
    # p6: (n_test, K) row indices of those train images
    # (K is the k given to tf.nn.top_k when the graph was built)
    print('p5=', p5.shape)
    print('p6=', p6.shape)
    print('p5[0,0]=', p5[0, 0])
    print('p6[0]=', p6[0])

    p7 = sess.run(f7, feed_dict=feed)
    print('p7=', p7.shape)  # (n_test, K, 10): label rows of the nearest neighbours
    print('p7[]=', p7)

    p8 = sess.run(f8, feed_dict=feed)
    print('p8=', p8.shape)  # (n_test, 10): per-class vote counts
    print('p8[]=', p8)

    p9 = sess.run(f9, feed_dict=feed)
    print('p9=', p9.shape)  # (n_test,): predicted digit per test image
    print('p9[]=', p9)

    # Ground-truth digits recovered from the one-hot test labels.
    p10 = np.argmax(testLabel[0:testSize], axis=1)
    print('p10=', p10)

# Accuracy in percent: share of test images whose predicted digit matches
# the true digit.
j = int(np.sum(p10 == p9))
print('ac=', j / testSize * 100)

 

你可能感兴趣的:(【手写数字识别KNN】记录)