本文使用tensorflow实现了mnist手写数字识别。对用惯了c++/c的我来说,python属实用着不舒服,很多其他人看着方便的东西对我来说有点折磨,再加上python第三方库了解的太少,只能凑合着把代码憋出来了。本代码分两部分,一部分是使用c++/c easyx制作的测试平台,一部分是tensorflow网络本体。
写测试平台就是为了验证神经网络的功能,所以做了一个类似画板的窗口,用鼠标在窗口写数字,然后写进文件中。从pycharm中再打开文件取出数据,喂给神经网络,获得结果。
整个测试程序的思路是把280*280像素的画布分成28*28个正方形,对应mnist数据集28*28像素的图片,每个正方形有10*10个像素。以正方形为单位绘图,这样可以让画出的内容更加直观:按住鼠标左键时,鼠标经过哪个正方形区域,就将整个正方形涂黑。
鼠标可以像画笔一样在白布上写字画图,点击clear会清空画布,点击accept会把画布中的内容写进文件中。paint.txt中保存的是画布的二进制信息;paint_test.txt中保存的是可视化的画布信息,用来查看保存的信息是否有误。
#pragma comment(lib,"python35.lib")
#include
#include
#include
#include
#include
// Optional build switches -- uncomment one to compile the matching section.
//#define test        // multithreading experiment: main() toggles `is` from console input
//#define test_graph  // EasyX drawing-function test
//#define mouse       // EasyX mouse-event test (free-hand pixels)
//#define debug_pos   // EasyX mouse-event test (fills the clicked cell)
#define debug_file    // send_file() additionally writes a text dump to paint_test.txt

// Special values returned by pos() for the two buttons below the canvas.
// Ordinary canvas cells are numbered row * 28 + column.
#define CLE 785       // "clear" button
#define ACC 786       // "accept" button

using namespace std;

int pos(int x, int y);
void send_file();
void cle_paint();
void draw(int pos_mod);

// Canvas state, one flag per cell (true = painted).
// The array is declared 30x30 although the drawing grid is 28x28.
bool paint[30][30] = { false };

// Hand-off flag: send_file() sets it to 1 once the canvas has been written,
// threadFun() polls it and resets it to 0.
// Defined exactly once: the former second definition under `#ifdef test`
// was a redefinition error as soon as `test` was enabled.
// NOTE(review): `volatile` is not a synchronisation primitive in standard C++;
// consider std::atomic<int> if the worker thread is ever revived.
volatile int is = 0;
// Worker-thread entry point handed to CreateThread() in main().
//
// The original intent was for this child thread to talk to the Python
// program; that was shelved because of bugs, so with neither `test` nor
// `error` defined the function simply returns 0.
//
//   test  : polls the shared flag `is` and prints a message when it becomes 1.
//   error : embeds CPython, imports module "a" and calls its init() each time
//           `is` becomes 1 (experimental path).
//
// lpParamter is unused. The return value is always 0.
DWORD WINAPI threadFun(LPVOID lpParamter) {
    (void)lpParamter;
#ifdef test
    while (1) {
        if (is == 1) {
            cout << "child get!" << endl;
            is = 0;
        }
        Sleep(1);  // yield instead of spinning a core at 100%
    }
#endif
#ifdef error  // experimental Python embedding
    Py_Initialize();
    if (!Py_IsInitialized()) {
        printf("initialized fail!\n");
        return 0;
    }
    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append('Dlls/')");
    PyObject* pModule = PyImport_ImportModule("a");
    if (pModule == NULL) {
        cout << "没找到" << endl;
        if (PyErr_Occurred()) {
            PyErr_Print();
        }
        // Without the module there is nothing to call; continuing would pass
        // NULL to PyObject_GetAttrString().
        Py_Finalize();
        return 0;
    }
    /*
    pFunc = PyObject_GetAttrString(pModule, "add");
    PyObject* pParams = Py_BuildValue("ii", 2, 1);
    PyObject* pResult = PyObject_CallObject(pFunc, pParams);
    int a = PyLong_AsLong(pResult);
    cout << a << endl;
    */
    while (1) {
        if (is == 1) {
            PyObject* pFunc = PyObject_GetAttrString(pModule, "init");
            if (pFunc != NULL && PyCallable_Check(pFunc)) {
                // The args argument must be a tuple or NULL; NULL means
                // "no arguments". Passing None makes the call fail.
                PyObject* pResult = PyObject_CallObject(pFunc, NULL);
                Py_XDECREF(pResult);
            }
            if (PyErr_Occurred()) {
                PyErr_Print();
            }
            Py_XDECREF(pFunc);  // release the per-iteration reference
            is = 0;
        }
        Sleep(10);  // yield instead of spinning a core at 100%
    }
    // Not reached while the loop above has no exit; kept so that adding an
    // exit condition later shuts the interpreter down correctly.
    Py_DECREF(pModule);
    Py_Finalize();
#endif
    return 0;
}
int main(void) {
HANDLE threadA = CreateThread(NULL, 0, threadFun, NULL, 0, NULL);
#ifdef test //多线程测试
int sig = 0;
while (1) {
cout << "please input 0\\1:";
cin >> sig;
if (sig != 0 && sig != 1) {
cout << "please input again!\n";
continue;
}
if (sig == 1) {
is = 1;
}
else {
is = 0;
}
}
#endif
#ifdef test_graph //easyx画图功能测试
initgraph(640, 480);
MOUSEMSG m;
while (true) {
m = GetMouseMsg();
switch (m.uMsg) {
case WM_MOUSEMOVE:
putpixel(m.x, m.y, RED);
break;
case WM_LBUTTONDOWN:
if (m.mkCtrl)
rectangle(m.x - 10, m.y - 10, m.x + 10, m.y + 10);
else
rectangle(m.x - 5, m.y - 5, m.x + 5, m.y + 5);
break;
case WM_RBUTTONUP:
return 0;
}
}
closegraph();
#endif
initgraph(280, 300); //easyx画图初始化设置
setbkcolor(WHITE);
cleardevice();
setcolor(YELLOW);
setfillcolor(YELLOW);
fillrectangle(0, 280, 140, 300);
setcolor(GREEN);
setfillcolor(GREEN);
fillrectangle(140, 280, 280, 300);
settextcolor(BLACK);
RECT cle = { 0, 280, 140, 300 };
drawtext(_T("clear"), &cle, DT_CENTER | DT_VCENTER | DT_SINGLELINE);
RECT acc = { 140, 280, 280, 300 };
drawtext(_T("accept"), &acc, DT_CENTER | DT_VCENTER | DT_SINGLELINE);
MOUSEMSG m;
bool flag = false;
int pos_mod = 0;
while (1) {
m = GetMouseMsg();
#ifdef mouse //easyx鼠标事件测试
switch (m.uMsg) {
case WM_LBUTTONDOWN:
putpixel(m.x, m.y, BLACK);
flag = true;
break;
case WM_MOUSEMOVE:
if (flag == true) {
putpixel(m.x, m.y, BLACK);
}
break;
case WM_LBUTTONUP:
flag = false;
break;
}
#endif
#ifdef debug_pos //easyx鼠标事件测试
switch (m.uMsg) {
case WM_LBUTTONDOWN:
pos_mod = pos(m.x, m.y);
int yy = (pos_mod / 28) * 10;
int xx = (pos_mod % 28) * 10;
int xxx = xx + 10;
int yyy = yy + 10;
setcolor(BLACK);
setfillcolor(BLACK);
fillrectangle(xx, yy, xxx, yyy);
break;
}
#endif
switch (m.uMsg) { //画图程序
case WM_LBUTTONDOWN:
pos_mod = pos(m.x, m.y);
if (pos_mod == ACC) {
send_file();
}
else if (pos_mod == CLE) {
cle_paint();
}
else {
flag = true;
draw(pos_mod);
}
break;
case WM_MOUSEMOVE:
pos_mod = pos(m.x, m.y);
if (flag == true) {
draw(pos_mod);
}
break;
case WM_LBUTTONUP:
flag = false;
break;
}
}
closegraph();
return 0;
}
// Maps a mouse position (window pixels) to a logical position.
//
// Returns
//   row * 28 + column (0..783) for a point on the 280x280 canvas, where each
//                              cell is 10x10 pixels;
//   CLE                        for a point in the button bar (y >= 280) with x <= 140;
//   ACC                        for a point in the button bar with x > 140.
//
// Every path returns a value. Previously the canvas case fell off the end of
// the function (undefined behaviour), and the bar position with x in 0..9 and
// y in 280..289 was not recognised as a button.
int pos(int x, int y) {
    int row = y / 10;
    if (row >= 28) {
        return (x <= 140) ? CLE : ACC;
    }
    if (row < 0) {
        row = 0;
    }
    // Clamp the column so an out-of-window x cannot spill into another row.
    int col = x / 10;
    if (col < 0) {
        col = 0;
    }
    else if (col > 27) {
        col = 27;
    }
    return row * 28 + col;
}
// Serialises the canvas and signals that a new image is ready.
//
// paint.txt receives 98 bytes: cells paint[r][c] for r, c in 0..27, row-major,
// eight cells per byte, first cell in the most significant bit.
// With `debug_file` defined, paint_test.txt additionally receives 28 text
// lines of 28 characters each for visual inspection.
// After writing, the function sleeps 500 ms and sets `is` to 1.
// If paint.txt cannot be opened, nothing is written and `is` is left untouched.
void send_file() {
    const int CELLS = 28 * 28;
    const int BYTES = CELLS / 8;

    // Zero-initialised: the packing loop shifts into each byte and must not
    // read indeterminate values.
    unsigned char b[BYTES] = { 0 };
    for (int zz = 0; zz < CELLS; zz++) {
        const int bit = paint[zz / 28][zz % 28] ? 1 : 0;
        b[zz / 8] = static_cast<unsigned char>((b[zz / 8] << 1) | bit);
    }

    FILE* fd;
    if ((fd = fopen("D:\\a_c_obj\\c2py\\paint.txt", "wb")) == NULL) {
        printf("file open fail!\n");
        return;
    }
    const size_t written = fwrite(b, sizeof(unsigned char), BYTES, fd);
    const int close_status = fclose(fd);
    if (written != static_cast<size_t>(BYTES) || close_status != 0) {
        printf("file write fail!\n");
        return;
    }
#ifdef debug_file
    // The dump is a debugging aid only: a failure here is reported but must
    // not prevent the ready signal below, because paint.txt is already on disk.
    FILE* fd2 = fopen("D:\\a_c_obj\\c2py\\paint_test.txt", "wb");
    if (fd2 == NULL) {
        printf("file open fail!\n");
    }
    else {
        // NOTE(review): painted cells are dumped as '0' and empty cells as
        // '1', the opposite of the bit values in paint.txt -- kept as is;
        // confirm whether this is intended.
        for (int i = 0; i < 28; i++) {
            for (int j = 0; j < 28; j++) {
                fputc(paint[i][j] ? '0' : '1', fd2);
            }
            fputc('\n', fd2);
        }
        fclose(fd2);
    }
#endif
    Sleep(500);
    is = 1;
    return;
}
// Clears the canvas: resets every cell flag in `paint` and erases the
// drawing area of the window.
void cle_paint() {
    for (auto& row : paint) {
        for (bool& cell : row) {
            cell = false;
        }
    }
    clearrectangle(0, 0, 280, 280);
}
// Paints one canvas cell.
//
// pos_mod is a cell index in the form row * 28 + column (0..783). Any other
// value -- such as the CLE/ACC button codes -- is ignored, so neither the
// button bar nor memory outside the 28x28 grid is touched.
//
// The cell is recorded at paint[row][column], the same indexing send_file()
// uses when it reads the canvas back. It used to be stored at
// [row + 1][column + 1], which shifted the exported image by one cell and
// dropped its last row and column.
void draw(int pos_mod) {
    if (pos_mod < 0 || pos_mod >= 28 * 28) {
        return;
    }
    const int row = pos_mod / 28;
    const int col = pos_mod % 28;
    paint[row][col] = true;

    const int left = col * 10;
    const int top = row * 10;
    setcolor(BLACK);
    setfillcolor(BLACK);
    fillrectangle(left, top, left + 10, top + 10);
}
图1 测试平台效果图
第二部分是神经网络
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
session = tf.Session()


def init():
    """Build the softmax-regression MNIST model, train or restore it, then
    classify the canvas stored in paint.txt and write the digit to result.txt.

    Set ``istrain`` to 1 to train for 500 mini-batches and save a checkpoint;
    leave it at 0 to restore the latest checkpoint from ``checkpoint_dir``.

    paint.txt must hold exactly 98 bytes (784 bits). The bits are read most
    significant bit first, which is the order in which the C++ test tool
    packs them.

    Raises:
        ValueError: if paint.txt does not contain exactly 98 bytes.
    """
    # NOTE(review): the source lost its indentation; the block structure below
    # (checkpoint saved once after the training loop, prediction executed in
    # both modes) is a reconstruction -- confirm against the original script.
    global session
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    batch_size = 100
    X_holder = tf.placeholder(tf.float32)
    y_holder = tf.placeholder(tf.float32)
    Weights = tf.Variable(tf.zeros([784, 10]))
    biases = tf.Variable(tf.zeros([1, 10]))
    predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) + biases)
    loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    # Named init_op so the local does not shadow this function's own name.
    init_op = tf.global_variables_initializer()
    session.run(init_op)

    istrain = 0
    saver = tf.train.Saver()
    checkpoint_dir = 'D:\\untitled3\\'
    if istrain:
        # Build the evaluation ops once; creating them inside the loop adds
        # new nodes to the graph on every evaluation.
        correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        for i in range(500):
            images, labels = mnist.train.next_batch(batch_size)
            session.run(train, feed_dict={X_holder: images, y_holder: labels})
            if i % 25 == 0:
                accuracy_value = session.run(
                    accuracy,
                    feed_dict={X_holder: mnist.test.images, y_holder: mnist.test.labels})
                print('step:%d accuracy:%.4f' % (i, accuracy_value))
        saver.save(session, checkpoint_dir + 'model.ckpt')
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(session, ckpt.model_checkpoint_path)
        else:
            # Keep going (best effort), but say so: with all-zero weights
            # every class gets the same score.
            print('warning: no checkpoint found in %s, using untrained weights'
                  % checkpoint_dir)

    with open('D:\\a_c_obj\\c2py\\paint.txt', "rb") as file_object:
        all_the_text = file_object.read()
    if len(all_the_text) != 98:
        raise ValueError('paint.txt must contain 98 bytes, got %d' % len(all_the_text))

    # unpackbits emits the most significant bit of each byte first, matching
    # the writer; reading least significant bit first mirrored every group
    # of 8 pixels.
    bits = np.unpackbits(np.frombuffer(all_the_text, dtype=np.uint8))
    d = bits.astype('float32').reshape(1, 784)

    result = session.run(predict_y, feed_dict={X_holder: d})
    print(result)
    # Index of the highest score; avoids shadowing the builtin max().
    res = int(np.argmax(result[0]))
    print(res)
    res1 = str(res)
    print(res1)
    with open('D:\\a_c_obj\\c2py\\result.txt', "w") as file_object:
        file_object.write(res1)


init()
训练时将istrain设置为1,验证时设置为0。
运行结果:
需要说明的是,由于测试工具简陋,神经网络小且简单,最后的测试结果惨不忍睹:虽然网络在mnist数据集上表现良好,但对手写输入的实际识别结果很差。后来分析了一下,觉得是我的测试工具的锅。最近在学一些抗锯齿算法,打算从各方面升级一下测试工具,应该会有较大的改进。