Tensorflow.js运行Python下训练的CNN模型(创新实训日志一)

我们的项目计划利用Python来训练模型,然后在浏览器中去调用训练好的模型,因为Python环境下读取数据、GPU加速等都比较容易实现,所以就需要解决一下Python训练好的模型在移植到js环境下的问题。幸运的是有现成的工具可以使用。

之前写过一个简单的demo来尝试将python下训练的模型放到浏览器环境中去运行,但时间比较久远,工具出现了版本的升级,以前的方法可能会出现一些问题,此外当时只是一个非常简单的NN模型,而我们使用的是一个较为复杂的CNN模型。所以本次的主要任务是探究训练好的CNN模型在浏览器上的移植情况,最终需要实现浏览器来成功运行CNN模型,这里为了简单起见我采用了MNIST数据集来进行训练和预测。

一、安装0.8.5版本的tensorflowjs

我试验了1.0.1和0.5.6等几个版本,最后使用这个版本成功导出了可以在浏览器中运行的模型。

sudo pip install tensorflowjs==0.8.5

顺便说一下我的tensorflow版本是1.8.0。

二、使用python实现CNN模型

这里按照书上的示例实现了一个作用在MNIST上的CNN代码,代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import tensorflow as tf
mnist=input_data.read_data_sets("MNIST",one_hot=True)
trX,trY,teX,teY=mnist.train.images,mnist.train.labels,mnist.test.images,mnist.test.labels
trX=trX.reshape(-1,28,28,1)
teX=teX.reshape(-1,28,28,1)
X=tf.placeholder("float",[None,28,28,1])	#稍后用于输入Tensor的placeholder
Y=tf.placeholder("float",[None,10])
def init_weights(shape):
	return tf.Variable(tf.random_normal(shape,stddev=0.01))
w=init_weights([3,3,1,32])
w2=init_weights([3,3,32,64])
w3=init_weights([3,3,64,128])
w4=init_weights([128*4*4,625])	
w_o=init_weights([625,10])
def model(X,w,w2,w3,w4,w_o,p_keep_conv=1.0,p_keep_hidden=1.0):
	lla=tf.nn.relu(tf.nn.conv2d(X,w,strides=[1,1,1,1],padding='SAME'))
	l1=tf.nn.max_pool(lla,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
	l1=tf.nn.dropout(l1,p_keep_conv)
	l2a=tf.nn.relu(tf.nn.conv2d(l1,w2,strides=[1,1,1,1],padding='SAME'))
	l2=tf.nn.max_pool(l2a,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
	l2=tf.nn.dropout(l2,p_keep_conv)
	l3a=tf.nn.relu(tf.nn.conv2d(l2,w3,strides=[1,1,1,1],padding='SAME'))
	l3=tf.nn.max_pool(l3a,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
	l3=tf.reshape(l3,[-1,w4.get_shape().as_list()[0]])
	l3=tf.nn.dropout(l3,p_keep_conv)
	l4=tf.nn.relu(tf.matmul(l3,w4))
	l4=tf.nn.dropout(l4,p_keep_hidden)
	pyx=tf.matmul(l4,w_o)
	return pyx	
py_x=model(X,w,w2,w3,w4,w_o)	
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x,labels=Y))
train_op=tf.train.RMSPropOptimizer(0.001,0.9).minimize(cost)
predict_op=tf.argmax(py_x,1,name="model")	#预测操作的模型名称
batch_size=128
test_size=256
with tf.Session() as sess:
	tf.global_variables_initializer().run()
	for i in range(3):	
		training_batch=zip(range(0,len(trX),batch_size),range(batch_size,len(trX)+1,batch_size))
		for start,end in training_batch:
			sess.run(train_op,feed_dict={X:trX[start:end],Y:trY[start:end]})
		test_indices=np.arange(len(teX))	
		np.random.shuffle(test_indices)
		test_indices=test_indices[0:test_size]
		print(i,np.mean(np.argmax(teY[test_indices],axis=1)==sess.run(predict_op,feed_dict={X:teX[test_indices]})))
		print(sess.run(predict_op,feed_dict={X:teX[[50]]}))
		print(teY[50])
	tf.saved_model.simple_save(sess, "./saved_model",inputs={"x": X, }, outputs={"model": predict_op, })	#保存模型,第一个参数是会话对象,第二个参数是要输出的文件夹名,第三个参数描述了输入,第四个参数描述了输出

为了节省篇幅删去了大部分代码注释,因为重点不在CNN的实现上,此时有注释的地方为转换模型是需要重点注意的地方。

最后的tf.saved_model是实际的保存操作,它会将计算图和我们训练的权重一并保存起来,需要注意的是第三个参数的value一定是输入对应的placeholder,在本代码中对应的就是X,第四个参数则对应计算输出的操作,在本代码中对应着predict_op,它的名称必须与第四个参数的key一致,value则是操作本身。

执行这段代码之后在目录下会出现训练好的模型文件夹"saved_model":

Tensorflow.js运行Python下训练的CNN模型(创新实训日志一)_第1张图片

三、导出model.json

当我在浏览器中运行之前的实验导出的模型(https://blog.csdn.net/zekdot/article/details/82913636) 的时候出现了报错,根据报错的提示可以知道,现在推荐使用json而不是pb格式的导出模型了:

1

查找了很多资料之后我发现原来这个问题只需要在导出的时候增加一个参数就能解决了。。命令如下:

tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names 'model' \
--saved_model_tags=serve \
--output_json model.json \
./saved_model \
./web_model

执行之后可以发现当前目录下多出来一个web_model:

Tensorflow.js运行Python下训练的CNN模型(创新实训日志一)_第2张图片

这也是成功转换之后的文件夹,这里面有如下几个文件,也就是可以直接利用js调用的模型:

Tensorflow.js运行Python下训练的CNN模型(创新实训日志一)_第3张图片

四、在浏览器中运行导出的模型

首先编写对应的html文件,结果都在终端中进行输出。

这里为了简便,我直接取了一个处理好的MNIST数据数组来进行测试,其对应的标签是5,否则就需要大量的js代码来对图像进行若干处理才能进行识别。

完整代码如下:

<!doctype html>
<html lang="en">
<head>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"> </script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>
</head>
<body>
  <script>
const MODEL_URL = './model.json'	//模型文件名
async function fun(){	//预测函数
    const model=await tf.loadGraphModel(MODEL_URL);	//加载图模型
    const cs = tf.tensor([[[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.03529412],[0.3019608],[0.19607845],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.00784314],[0.24705884],[0.4666667],[0.67058825],[0.9490197],[0.95294124],[0.79215693],[0.38431376],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.19215688],[0.08235294],[0.0],[0.09411766],[0.4666667],[0.8000001],[0.9960785],[1.],[0.89019614],[0.5019608],[0.41960788],[0.02745098],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.11764707],[0.73333335],[0.9960785],[0.909804],[0.8431373],[0.9176471],[0.9960785],[0.9450981],[0.62352943],[0.2901961],[0.07058824],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.07058824],[0.85098046],[0.9960785],[0.9960785],[0.9960785],[0.9843138],[0.86666673],[0.41176474],[0.13333334],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.07058824],[0.7803922],[0.9960785],[0.9960785],[0.97647065],[0.63529414],[0.25882354],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.73333335],[0.9960785],[0.9960785],[0.7686275],[0.12156864],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.22352943],[0.94117653],[0.9960785],[0.6392157],[0.00784314],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.8705883],[0.9960785],[0.6313726],[0.01960784],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.13333334],[0.9490197],[0.9725491],[0.29803923],[0.13333334],[0.41960788],[0.5254902],[0.854902],[0.69411767],[0.3647059],[0.10196079],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.45882356],[0.9960785],[0.9725491],[0.9450981],[0.95294124],[0.9960785],[0.9215687],[0.8745099],[0.94117653],[0.9960785],[0.92549026],[0.40000004],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.77647066],[0.9960785],[0.9960785],[0.78823537],[0.5647059],[0.227451],[0.08627451],[0.0],[0.12156864],[0.5568628],[0.9803922],[0.91372555],[0.40784317],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.07058824],[0.5921569],[0.33333334],[0.05882353],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.5176471],[0.9960785],[0.6862745],[0.00392157],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.42352945],[0.9960785],[0.9960785],[0.03529412],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.12941177],[0.8196079],[0.9960785],[0.6901961],[0.00392157],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.2509804],[0.8941177],[0.9960785],[0.91372555],[0.1764706],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.05490196],[0.38431376],[0.7137255],[0.9803922],[0.9960785],[0.9058824],[0.30980393],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.10980393],[0.10588236],[0.10588236],[0.4039216],[0.45882356],[0.5921569],[0.8705883],[0.9960785],[0.9960785],[0.94117653],[0.5568628],[0.10980393],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.9490197],[0.97647065],[0.9803922],[0.9921569],[0.9686275],[0.9607844],[0.9450981],[0.72156864],[0.3921569],[0.11764707],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.18039216],[0.19215688],[0.25490198],[0.13333334],[0.07843138],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]],[[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0],[0.0]]]])	//测试数据tensor,标签为5
	cs.print()	//打印测试数据tensor
	model.predict(cs).print()	//进行预测并输出预测结果
}
fun()	//调用函数
  </script>
</body>
</html>

编写好代码之后,将之前web_model中的所有文件复制到编写好的html文件同级目录下,如图:

Tensorflow.js运行Python下训练的CNN模型(创新实训日志一)_第4张图片

然后将tensorflowTest1.html文件夹拖到浏览器中运行,按F12打开控制台可以观察到结果:

Tensorflow.js运行Python下训练的CNN模型(创新实训日志一)_第5张图片
可见代码成功得以运行,即成功在浏览器中利用CNN实现了手写体的识别。

你可能感兴趣的:(创新实训,Tensorflow.js,CNN)