mnist手写数字识别tensorflow实现和测试

        本文使用tensorflow实现了mnist手写数字识别。对用惯了c++/c的我来说,python属实用着不舒服,很多其他人看着方便的东西对我来说有点折磨,再加上python第三方库了解的太少,只能凑合着把代码憋出来了。本代码分两部分,一部分是使用c++/c easyx制作的测试平台,一部分是tensorflow网络本体。

        写测试平台就是为了验证神经网络的功能,所以做了一个类似画板的窗口,用鼠标在窗口写数字,然后写进文件中。从pycharm中再打开文件取出数据,喂给神经网络,获得结果。

        整个测试程序的思路是把280*280像素的画布分成28*28个正方形,对应mnist数据集28*28像素的图片,每个正方形有10*10个像素。以正方形为单位绘图,这样可以使整体更加可视,当鼠标停在一个正方形区域,就将整个正方形涂黑。

        鼠标可以像画笔一样在白布上写字画图,点击clear会清空画布,点击accept会把画布中的内容写进文件中。paint.txt中保存的是画布二进制信息;paint_test中保存的是可视画布信息,用来查看保存信息是否有误。

#pragma comment(lib,"python35.lib")
#include 
#include 
#include 
#include 
#include 

//#define test
//#define test_graph
//#define mouse
//#define debug_pos
#define debug_file
#define CLE 785
#define ACC 786

using namespace std;

int pos(int x, int y);
void send_file();
void cle_paint();
void draw(int x);

bool paint[30][30] = {FALSE};
volatile int is = 0;

#ifdef test
volatile int is = 0;
#endif

DWORD WINAPI threadFun(LPVOID lpParamter) {  //本来想写个多线程,子线程和python程序进行通讯
                                     //后来因为bug不好暂时调放弃了,整个threadFun函数都不用看
#ifdef test                                                
    while (1) {
        if (is == 1) {
            cout << "child get!" << endl;
            is = 0;
        }
    }
#endif

#ifdef error   //这部分是bug程序,可以删除

    Py_Initialize();
    if (!Py_IsInitialized()) {
        printf("initialized fail!\n");
        return 0;
    }

    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append('Dlls/')");

    PyObject* pModule = NULL;
    PyObject* pFunc = NULL;
    pModule = PyImport_ImportModule("a");
    if (pModule == NULL)
    {
        cout << "没找到" << endl;
    }

    /*
    pFunc = PyObject_GetAttrString(pModule, "add");
    PyObject* pParams = Py_BuildValue("ii", 2, 1);
    PyObject* pResult = PyObject_CallObject(pFunc, pParams);
    int a = PyLong_AsLong(pResult);
    cout << a << endl;
    */

    while (1) {
        if (is == 1) {
            
            pFunc = PyObject_GetAttrString(pModule, "init");
            PyObject* pParams = Py_BuildValue("");
            PyObject* pResult = PyObject_CallObject(pFunc, pParams);
            is = 0;

        }
    }

    Py_Finalize();

#endif

    return 0;
}

int main(void) {

    HANDLE threadA = CreateThread(NULL, 0, threadFun, NULL, 0, NULL);

#ifdef test   //多线程测试
    int sig = 0;
    while (1) {
        cout << "please input 0\\1:";
        cin >> sig;
        if (sig != 0 && sig != 1) {
            cout << "please input again!\n";
            continue;
        }
        if (sig == 1) {
            is = 1;
        }
        else {
            is = 0;
        }
    }
#endif

#ifdef test_graph   //easyx画图功能测试
    initgraph(640, 480);

    MOUSEMSG m;

    while (true) {
        m = GetMouseMsg();
        switch (m.uMsg) {
        case WM_MOUSEMOVE:
            putpixel(m.x, m.y, RED);
            break;
        case WM_LBUTTONDOWN:
            if (m.mkCtrl)
                rectangle(m.x - 10, m.y - 10, m.x + 10, m.y + 10);
            else
                rectangle(m.x - 5, m.y - 5, m.x + 5, m.y + 5);
            break;
        case WM_RBUTTONUP:
            return 0;
        }
    }
    closegraph();
#endif
 
    initgraph(280, 300);    //easyx画图初始化设置
    setbkcolor(WHITE);
    cleardevice();

    setcolor(YELLOW);
    setfillcolor(YELLOW);
    fillrectangle(0, 280, 140, 300);
    setcolor(GREEN);
    setfillcolor(GREEN);
    fillrectangle(140, 280, 280, 300);

    settextcolor(BLACK);
    RECT cle = { 0, 280, 140, 300 };
    drawtext(_T("clear"), &cle, DT_CENTER | DT_VCENTER | DT_SINGLELINE);
    RECT acc = { 140, 280, 280, 300 };
    drawtext(_T("accept"), &acc, DT_CENTER | DT_VCENTER | DT_SINGLELINE);

    MOUSEMSG m;
    bool flag = false;
    int pos_mod = 0;

    while (1) {

        m = GetMouseMsg();

#ifdef mouse   //easyx鼠标事件测试
        switch (m.uMsg) {
        case WM_LBUTTONDOWN:
            putpixel(m.x, m.y, BLACK);
            flag = true;
            break;
        case WM_MOUSEMOVE:
            if (flag == true) {
                putpixel(m.x, m.y, BLACK);
            }
            break;
        case WM_LBUTTONUP:
            flag = false;
            break;
        }
#endif

#ifdef debug_pos    //easyx鼠标事件测试
        switch (m.uMsg) {
        case WM_LBUTTONDOWN:
            pos_mod = pos(m.x, m.y);
            int yy = (pos_mod / 28) * 10;
            int xx = (pos_mod % 28) * 10;
            int xxx = xx + 10;
            int yyy = yy + 10;
            setcolor(BLACK);
            setfillcolor(BLACK);
            fillrectangle(xx, yy, xxx, yyy);
            break;
        }
#endif
        
        switch (m.uMsg) {      //画图程序
        case WM_LBUTTONDOWN:
            pos_mod = pos(m.x, m.y);
            if (pos_mod == ACC) {
                send_file();
            }
            else if (pos_mod == CLE) {
                cle_paint();
            }
            else {
                flag = true;
                draw(pos_mod);
            }
            break;
        case WM_MOUSEMOVE:
            pos_mod = pos(m.x, m.y);
            if (flag == true) {
                draw(pos_mod);
            }
            break;
        case WM_LBUTTONUP:
            flag = false;
            break;
        }
        
    }
    
    closegraph();

    return 0;

}

int pos(int x, int y) {   //获得鼠标当前所在位置

    int h = y / 10;
    int a = (h * 28) + (x / 10);
    if ((a > 784) && x <= 140) {
        return CLE;
    }
    if ((a > 784) && x > 140) {
        return ACC;
    }

}

void send_file() {     //保存画布数据并写进文件

    unsigned char b[100];
    int zz = 0;
    for (int i = 0; i < 98; i++) {
        for (int j = 0; j < 8; j++) {
            b[i] = (b[i] << 1) + paint[zz / 28][zz % 28];
            zz++;
        }
    }
    FILE* fd;
    if ((fd = fopen("D:\\a_c_obj\\c2py\\paint.txt", "wb")) == NULL) {
        printf("file open fail!\n");
        return;
    }
    fwrite(b, sizeof(unsigned char), 98, fd);
    fclose(fd);

#ifdef debug_file
    FILE* fd2;
    if ((fd2 = fopen("D:\\a_c_obj\\c2py\\paint_test.txt", "wb")) == NULL) {
        printf("file open fail!\n");
        return;
    }
    unsigned char t[4] = { '0', '1' , '\n'};
    for (int i = 0; i < 28; i++) {
        for (int j = 0; j < 28; j++) {
            if (paint[i][j]) {
                fwrite(t, sizeof(unsigned char), 1, fd2);
            }
            else {
                fwrite(t + 1, sizeof(unsigned char), 1, fd2);
            }
        }
        fwrite(t + 2, sizeof(unsigned char), 1, fd2);
    }
    fclose(fd2);
#endif

    Sleep(500);

    is = 1;

    return;

}

void cle_paint() {   //清空画布

    clearrectangle(0, 0, 280, 280);
    memset(paint, 0, sizeof(paint));

}

void draw(int pos_mod) {   //画点程序

    int yy = (pos_mod / 28) * 10;
    int xx = (pos_mod % 28) * 10;
    int xxx = xx + 10;
    int yyy = yy + 10;
    int x2 = (pos_mod / 28) + 1;
    int y2 = (pos_mod % 28) + 1;
    paint[x2][y2] = 1;
    setcolor(BLACK);
    setfillcolor(BLACK);
    fillrectangle(xx, yy, xxx, yyy);

}

        mnist手写数字识别tensorflow实现和测试_第1张图片

 图1 测试平台效果图

        第二部分是神经网络


import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

session = tf.Session()

def init():
    global session
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    batch_size = 100
    X_holder = tf.placeholder(tf.float32)
    y_holder = tf.placeholder(tf.float32)

    Weights = tf.Variable(tf.zeros([784, 10]))
    biases = tf.Variable(tf.zeros([1,10]))
    predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) + biases)
    loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)

    init = tf.global_variables_initializer()
    session.run(init)

    istrain = 0
    saver = tf.train.Saver()
    checkpoint_dir = 'D:\\untitled3\\'

    if istrain:
        for i in range(500):
            images, labels = mnist.train.next_batch(batch_size)
            session.run(train, feed_dict={X_holder:images, y_holder:labels})
            if i % 25 == 0:
                correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
                accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
                print('step:%d accuracy:%.4f' %(i, accuracy_value))
        saver.save(session, checkpoint_dir + 'model.ckpt')
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(session, ckpt.model_checkpoint_path)
        else:
            pass

    file_object = open('D:\\a_c_obj\\c2py\\paint.txt', "rb")
    try:
         all_the_text = file_object.read()
    finally:
         file_object.close()
    mask = [0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, 0b01000000, 0b10000000]
    paint = []
    #print(type(all_the_text))
    for byte in all_the_text:
        for i in range(8):
            if (byte & mask[i]) == 0:
                paint.append(0.0)
            else:
                paint.append(1.0)

    #aa = tf.constant(paint, dtype=tf.float32)
    d = np.array(paint, dtype='float32').reshape(1, 784)

    result = session.run(predict_y, feed_dict={X_holder: d})
    print(result)

    max = 0.0
    res = 0
    for k in range(10):
        if (max < result[0][k]):
            max = result[0][k]
            #print(max)
            res = k
    print(res)
    res1 = str(res)
    print(res1)
    file_object = open('D:\\a_c_obj\\c2py\\result.txt', "w")
    try:
        file_object.write(res1)
    finally:
        file_object.close()

init()


        训练时将istrain设置为1,验证时设置为0。

        运行结果:

mnist手写数字识别tensorflow实现和测试_第2张图片

        需要说明的是,由于测试工具简陋,神经网络小且简单,最后测试结果惨不忍睹,虽然在mnist数据集表现良好,但实际结果很差。后来分析了一下,觉得是我的测试工具的锅,最近在学一些抗锯齿算法,以及从各方面升级一下测试工具,应该会有较大的改进。 

你可能感兴趣的:(tensorflow,python,tensorflow)