环境:
Anaconda3(64 位,Python 3.6),安装 mxnet 1.3.1;opencv_python-3.4.5.20-cp36-cp36m-win_amd64.whl 为可选项(下面的预测代码用 Pillow 读取图片、用 matplotlib 显示图片,二者 Anaconda 已自带)。
训练源码:
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 16:30:15 2019
@author: houwenbin
"""
import numpy as np
import mxnet as mx
import logging
# Let MXNet's training progress messages (emitted through the root logger)
# reach the console.
logging.getLogger().setLevel(logging.DEBUG)

batch_size = 100

# MNIST as a dict holding the 'train_data'/'train_label' and
# 'test_data'/'test_label' arrays used below.
mnist = mx.test_utils.get_mnist()

# Shuffled mini-batches for training; the validation iterator keeps file order.
train_iter = mx.io.NDArrayIter(mnist['train_data'],
                               mnist['train_label'],
                               batch_size=batch_size,
                               shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'],
                             mnist['test_label'],
                             batch_size=batch_size)
# LeNet-style network:
#   (5x5 conv -> tanh -> 2x2/2 max-pool) x 2 -> FC 500 -> tanh -> FC 10 -> softmax
# Operators are created in exactly the original order, so the layer/parameter
# names MXNet generates automatically stay the same.
data = mx.sym.var('data')

# Stage 1: 20 filters of size 5x5.
conv1 = mx.sym.Convolution(data=data,
                           kernel=(5, 5),
                           num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1,
                       pool_type="max",
                       kernel=(2, 2),
                       stride=(2, 2))

# Stage 2: 50 filters of size 5x5.
conv2 = mx.sym.Convolution(data=pool1,
                           kernel=(5, 5),
                           num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2,
                       pool_type="max",
                       kernel=(2, 2),
                       stride=(2, 2))

# Classifier head: flatten the feature maps, a 500-unit hidden layer, then
# 10 output units (one per digit class).
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)

# Softmax output layer with cross-entropy loss during training.
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')
# Number of training epochs; also used as the checkpoint epoch number so the
# two can never disagree.
num_epoch = 10

# Wrap the network in a trainable Module running on the CPU.
# (Pass context=mx.gpu(0) instead to train on the first GPU.)
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

# Train with plain SGD, reporting accuracy on the validation iterator and
# logging throughput every 100 batches.
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate': 0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=num_epoch)

# Save the symbol and the trained parameters under the prefix "lenet" for the
# prediction script, which reloads them with
# mx.model.load_checkpoint('lenet', 10). Optimizer state is not needed for
# inference, so it is not saved.
lenet_model.save_checkpoint("lenet", num_epoch, save_optimizer_states=False)
预测源码:
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 19 20:17:26 2019
@author: houwenbin
"""
import time
import mxnet as mx
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
#
# Checkpoint prefix and epoch passed to mx.model.load_checkpoint; these must
# match the save_checkpoint("lenet", 10, ...) call in the training script.
prefix = 'lenet'
iteration = 10

# Hand-drawn test image to classify.
img_name = './digit_8.jpg'

# Output index -> class label: the ten digits 0..9.
synsets = list(range(10))
# MNIST-style preprocessing (not ImageNet): the network is bound to a
# 1x1x28x28 float32 input with values scaled into [0, 1].
def load_image(img_name, show=True):
    """Load an image file and turn it into a 1x1x28x28 LeNet input.

    The image is converted to 8-bit grayscale (PIL mode 'L',
    L = 0.299R + 0.587G + 0.114B), downsampled to 28x28, divided by 255 and
    reshaped to (1, 1, 28, 28).

    Args:
        img_name: path of the image file to read.
        show: if True (the default, matching the previous behaviour), print
            the 28x28 shape and display the downsampled image with
            matplotlib. plt.show() may block until the window is closed, so
            pass False for non-interactive use.

    Returns:
        A float32 numpy array of shape (1, 1, 28, 28), or None if the file
        cannot be opened or decoded.
    """
    try:
        # Image.open() reads lazily and keeps the file open; the context
        # manager closes it. convert()/resize() return new, fully loaded
        # images, so the result stays usable after the file is closed.
        with Image.open(img_name) as im:
            # Use an explicit resampling filter. resize()'s default differs
            # between Pillow versions (NEAREST before 7.0, BICUBIC after), and
            # nearest-neighbour downsampling (e.g. 128x128 -> 28x28) can drop
            # thin pen strokes.
            gray = im.convert('L').resize((28, 28), Image.LANCZOS)
    except OSError:
        # Image.open() raises (never returns None) for missing or undecodable
        # files: FileNotFoundError / PIL.UnidentifiedImageError, both OSError
        # subclasses. Map these to the documented None return.
        return None

    img = np.asarray(gray, dtype=np.float32)
    if show:
        print(img.shape)
        # cmap='gray': without it matplotlib renders a single-channel array
        # with a false-colour map.
        plt.imshow(img, cmap='gray')
        plt.show()

    # Add batch and channel axes and scale [0, 255] -> [0, 1].
    # NOTE(review): presumably matches the scaling of mx.test_utils.get_mnist()
    # used for training — confirm if the training data source changes.
    return img.reshape(1, 1, 28, 28) / 255
import sys
from collections import namedtuple

# Minimal stand-in for mx.io.DataBatch: Module.forward() only needs `.data`.
Batch = namedtuple('Batch', ['data'])

time0 = time.time()
# Load the symbol and parameters saved by the training script.
sym, arg, aux = mx.model.load_checkpoint(prefix, iteration)
# Rebuild an inference-only module; no labels are fed at prediction time.
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
# Allocate memory for a single 1x1x28x28 input.
mod.bind(for_training=False, data_shapes=[('data', (1, 1, 28, 28))],
         label_shapes=mod._label_shapes)
# allow_missing=True tolerates names absent from the checkpoint; presumably
# needed for the SoftmaxOutput label input, which is not fed as data when
# label_names=None — confirm before tightening.
mod.set_params(arg, aux, allow_missing=True)
# MXNet runs NDArray work asynchronously; wait so the timing includes it.
mx.nd.waitall()
time1 = time.time()
print("模型加载和重建时间:{0}".format(time1 - time0))

# Load and preprocess the test image.
img = load_image(img_name)
if img is None:
    # sys.exit is always available (the bare exit() builtin is injected by the
    # site module) and reports a non-zero status for this failure.
    sys.exit("failed to load image: {0}".format(img_name))
print(img.shape)

time0 = time.time()
# Simple inference on the 1x1x28x28 input.
mod.forward(Batch([mx.nd.array(img)]))
outputs = mod.get_outputs()
# forward() only enqueues work on MXNet's asynchronous engine; asnumpy()
# blocks until the result is ready, so read it before stopping the clock.
prob = outputs[0].asnumpy()
time1 = time.time()
print("前向预测时间:{0}".format(time1 - time0))

# Print the raw outputs and the top-5 predictions.
print(outputs)
print("-------result-------", prob, prob.shape)
prob = np.squeeze(prob)
print("-------squeeze result-------", prob, prob.shape)
print("-------sorted prob--------", np.sort(prob))  # ascending
ascending = np.argsort(prob)
print("-------arg sorted prob--------", ascending)
a = ascending[::-1]  # class indices, most to least confident
print("------top sorted index-------", a, a.shape)
for i in a[:5]:
    print('probability=%f, class=%s' % (prob[i], synsets[i]))
数据准备:
用画图工具新建一张 128x128 的图片并填充为黑色背景,再用橡皮擦(默认擦成白色)写出待识别的数字,保存为 digit_8.jpg 即可。这样得到的是黑底白字的图片,与 MNIST 训练数据的黑底白字一致;预测脚本会将其转为灰度图、缩放到 28x28,并归一化到 [0, 1]。
运行结果(注意:下面对 digit_8.jpg 的 Top-1 预测是 4 而不是 8。手绘数字在笔画粗细、位置和缩放方式上与 MNIST 样本不同,可能导致误判):
in[1]:runfile('C:/Users/houwenbin/Documents/PythonProject/test_mnist.py', wdir='C:/Users/houwenbin/Documents/PythonProject')
模型加载和重建时间:0.0060160160064697266
(1, 1, 28, 28)
前向预测时间:0.0010042190551757812
[
[[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]]
-------result------- [[3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05]] (1, 10)
-------squeeze result------- [3.0556594e-06 1.3175709e-06 4.1811345e-06 1.1044953e-08 9.9990916e-01
4.0004899e-10 3.0342795e-05 1.8727254e-05 2.3288235e-06 3.0720061e-05] (10,)
-------sorted prob-------- [4.0004899e-10 1.1044953e-08 1.3175709e-06 2.3288235e-06 3.0556594e-06
4.1811345e-06 1.8727254e-05 3.0342795e-05 3.0720061e-05 9.9990916e-01]
-------arg sorted prob-------- [5 3 1 8 0 2 7 6 9 4]
------top sorted index------- [4 9 6 7 2 0 8 1 3 5] (10,)
probability=0.999909, class=4
probability=0.000031, class=9
probability=0.000030, class=6
probability=0.000019, class=7
probability=0.000004, class=2
in[2]: