原文
Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预测放在前端,置于浏览器内。
本文会用一个简单的demo介绍如何从零开始训练一个tensorflow模型,并在浏览器内实现手写数字识别,最终效果大约如下:
手写数字识别示例
本文会假设你有基本的python和JavaScript的知识。项目的完整代码参考github。
项目代码的目录结构如下:
项目目录结构
整个结构大概分成server和web两个部分,分别是服务端和浏览器端的代码。
我们的流程大概如下:
我们需要的所有依赖如下:
python:
建议使用3.5以上的版本。我不能保证在<3.5的版本中它是否能正常工作。Tensorflow的兼容性问题一向令人头疼。注意在mac和linux上默认的python是python2。
这个demo内已经包含一个已经训练好的模型,所以你如果并不想自己再训练一次,可以不安装tensorflow和tensorflowjs。所有这些依赖都可以通过pip安装。
JavaScript:
你不需要特别安装任何东西,因为我们的库都是通过链接导入的。
浏览器:
反正在chrome浏览器里是能跑起来的……
项目文件里面已经包含了一个训练好的模型,位于{项目路径}/server/models/mnist
文件夹内。
我们使用MNIST数据集来训练模型。MNIST是一个知名的手写数字识别的数据集。对很多机器学习的初学者而言,这很可能是他们接触到的第一个数据集。这个数据集中包含60000张训练图片以及10000张测试图片,每张图片都是一个28×28像素的手写数字图片。如下图所示:
mnist.png
MNIST用一个28×28的矩阵来代表这样的一张数字图片,矩阵内的每个元素表示对应点位置的灰度,在0~255之间。
下载数据:
事实上,你可以跳过下载数据这一步而直接开始训练,因为在训练函数中会自动下载数据,但鉴于国内糟糕的网络环境,我还是建议你先把数据手动下载下来。我会优先从本地读取数据。
下载地址:mnist.npz
下载完成后保存在路径{项目路径}/server/datasets/mnist.npz
的位置。npz是numpy的一种数据压缩格式。文件大小大概11m。然后我们用load_data
函数载入数据:
import numpy as np
from tensorflow.keras import layers, datasets
def load_data(path):
try:
with np.load(path) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
x_train, x_test = x_train/255.0, x_test/255.0
return (x_train, y_train), (x_test, y_test)
except FileNotFoundError:
return datasets.mnist.load_data()
其中,x_train是一个60000×28×28的3维向量,代表60000张图片;y_train是长度60000的向量,每一项代表对应图片的实际数字,是一个0~9的整数。x_test,y_test是测试集上的对应数据,测试集大小为10000。注意x_train, x_test = x_train/255.0, x_test/255.0
这一步是把每个灰度数字转换为一个0~1之间的小数。
训练模型:
我们使用tensorflow.keras的接口来实现一个简单的卷积神经网络(Convolutional Neural Network, CNN)模型。它包含了一个卷积层,一个池化层,和两个全连接层。我不会这里解释全部概念——对新手来说,它们过于令人困惑而费解。而且你也不需要在这里理解它。如果你真的很想从直观上把握它的话,你可以试试这篇博客:An Intuitive Explanation of Convolutional Neural Networks。它有点长,但为此花一些时间依然是值得的。
在server/train.py文件下可以看到训练函数的代码:
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, datasets
import tensorflowjs as tfjs
def train_modle(data):
(x_train, y_train), (x_test, y_test) = data
model = Sequential([
layers.Reshape((28, 28, 1), input_shape=(28, 28)),
layers.Conv2D(16, (5, 5), padding='valid', input_shape=(28, 28, 1), activation='relu'),
layers.MaxPooling2D(pool_size=2),
layers.Dropout(0.2),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=64)
model.evaluate(x_test, y_test)
tfjs.converters.save_keras_model(model, model_path)
从上至下,这个训练函数做的:
[ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]
,即除第3位为1外其他的都是0;而我们的预测结果可能是[0.2, 0.3, 0.1,...]
。我们这里使用交叉熵算法来评估两种概率分布间的差别。训练的目的就是使得这样的损失函数的值尽量接近0。撇开数学上的概念理解不谈,一般初学者在训练过程中最容易让人弄错的地方是数据的格式(shape)。
运行文件,开始训练
python server/train.py
如果你的环境配置正确,你应该会看到这样的输出:
60000/60000 [==============================] - 11s 185us/sample - loss: 0.1896 - acc: 0.9453
Epoch 2/5
60000/60000 [==============================] - 14s 225us/sample - loss: 0.0678 - acc: 0.9791
Epoch 3/5
60000/60000 [==============================] - 13s 221us/sample - loss: 0.0504 - acc: 0.9840
Epoch 4/5
60000/60000 [==============================] - 14s 233us/sample - loss: 0.0377 - acc: 0.9881
Epoch 5/5
60000/60000 [==============================] - 14s 231us/sample - loss: 0.0301 - acc: 0.9900
10000/10000 [==============================] - 1s 93us/sample - loss: 0.0360 - acc: 0.9879
在我的i5 cpu电脑上,整个训练过程大约耗时不到1分钟。
这个输出的结果显示了每一个epoch的耗时、损失函数的值和准确率。最后一行是在测试集上的结果。可以看到,我们的训练结果在测试集上有 98.79% 的准确率。同时,在{项目路径}/server/models/mnist
内的文件也会被覆盖更新。你可以调整模型的结构和条件,多试几次,来评估不同条件下的训练结果。
在{项目路径}/server/models/mnist
下有两个文件,一个很小的model.json文件和另一个大小约1m以上的.bin文件。model.json文件可以直接打开,里面包括了模型的一些总体信息,如模型的结构和参数文件的位置,也就是.bin文件,这个文件记录了这个模型训练出来的所有参数。另外,如果你已经改动过模型的结构或者其他条件重新训练,那么这样的参数文件可能不止一个。
我们已经训练好了模型,但这个模型文件是不能直接被浏览器载入使用的,因为现代浏览器一般都会阻止js直接读取本地文件内容。并且在设计上,这个模型文件也应该是保存在服务端而不是客户端。
我们需要做的,是启动一个服务,并使得这个模型成为这个服务的静态资源,这样js就可以通过请求拉取文件内容。
在server目录下的main.py文件:
from flask import Flask
from flask_cors import CORS
app = Flask(__name__,
static_url_path='/models',
static_folder='models')
cors = CORS(app)
@app.route("/")
def hello():
return "Hello World!"
if __name__ == '__main__':
app.run(debug=True)
这是一个非常简单的flask应用代码。在这个应用中,我们把url路径/models
映射到了文件目录models
(相对于本文件),外界通过{host}/models
就能访问到models
内的文件。
在项目的根目录下,用命令行启动这个文件:
python server/main.py
如果一切正常的话,你应该会看到这样的输出
* Serving Flask app "main" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment
Use a production WSGI server instead.
* Debug mode: on
* Restarting with stat
* Debugger is active!
* Debugger PIN: 267-971-636
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
现在,你可以打开 http://127.0.0.1:5000/ 或者 http://localhost:5000/,如果你在屏幕上看到了“Hello World!”,就说明服务已经启动成功了。ctrl+c可以退出服务。
此时如果你打开http://localhost:5000/models/mnist/model.json就可以看到我们之前训练出来的模型的json文件。
另外,注意在代码中,我们还加了一句cors = CORS(app)
,这是为了让这个服务接受跨域请求。本文在这里不会展开讨论这个问题,简单地说:如果在js脚本中试图请求拉取的后端资源的协议或域名与js本身的不一致,那么浏览器会阻止这个请求——这是一种安全保护策略,除非你加了这行代码让后端资源接受跨域。
我们在html头引入这几个库:
最后一个model.js是本地的js文件,我们会把模型的导入和数据预测函数都封装在这里。
在model.js文件里,导入模型:
const MODEL_URL = 'http://localhost:5000/models/mnist/model.json'
var loadModel = (async function() {
window.model = await tf.loadLayersModel(MODEL_URL);
console.log('load model')
return model;
})
loadModel();
MODEL_URL
就是模型的model.json文件的url地址。用tf.loadLayersModel函数来载入模型并绑定在window上。当你在浏览器控制台里看到 'load model',模型就载入成功了。注意tf.loadLayersModel是异步函数,它返回的是一个Promise对象,你需要用await
或者.then()
的回调式方法来获取载入的模型对象。
在html里,添加一个id="canvas"
的canvas和两个按钮,一个用于识别,另一个用于清空canvas。
在html内内加上:
var fabric_canvas = new fabric.Canvas('canvas', {backgroundColor: "#000000"});
fabric_canvas.renderTop();
fabric_canvas.isDrawingMode = true;
fabric_canvas.freeDrawingBrush.width = 12;
fabric_canvas.freeDrawingBrush.color = "#ffffff";
var recognize = async function() {
var results = await predict('canvas');
console.log(results);
}
var clear_canvas = function() {
fabric_canvas.clear();
}
我们使用fabric.js来构造可以任意涂抹的画图。这并不是必要的,只是可以少写一点代码。
在识别图片recognize
函数内,我们调用了一个predict函数,并传入了canvas的id。我们希望这个函数返回的结果就是预测结果。
const width = 28;
const height = 28;
var predict = async function(id) {
var model = window.model;
var canvas = document.getElementById(id);
var example = this.load_img(canvas);
var prediction = await model.predict(example).data();
var results = Array.from(prediction);
return results
}
var load_img = function(img) {
var tensor = tf.browser.fromPixels(img)
.resizeNearestNeighbor([width, height])
.mean(2)
.expandDims()
.toFloat()
.div(255.0)
return tensor;
};
predict函数的逻辑也相当直接了当:
.data
方法中取出预测结果。这个结果默认是Float32Array类型,可以转换为Array。与之前训练模型类似,最麻烦的地方是数据格式shape的处理,我们在load_img里有这么几步:
.mean
方法是求平均值的方法,用它我们把第3维的颜色转换为灰度。考虑到我们的图片是黑白的,它在3个颜色上应该是一样的,所以在这里我们也可以用.min
或.max
(最小值、最大值)来计算灰度。此时的大小是28×28。.expandDims
加上一维。此时大小为1×28×28。这里可以用.reshape([1, 28, 28])
来达到同样的效果。toFloat
,把tensor的元素转换为Float类型。.div(255.0)
。现在我们在canvas上写数据,再点击recognize按钮,就能在浏览器的控制台里看到预测的结果:
Array(10) [ 2.229090443993517e-15, 1.264737121454973e-12, 6.231850036009234e-10, 0.9999980926513672, 7.358067470207216e-14, 7.870837634982308e-7, 3.1836545118929527e-13, 6.341550395916329e-9, 8.096231454146618e-7, 1.0121870008816813e-10 ]
这个长度为10的数组表示模型预测canvas上的图片是0~9之间每个数字的概率。在我的项目里我还加了一个直方图表来表示这个数据,本文略过此处。
一些其他的值得注意的地方:
所有的代码都在这里:digits-recognition-tfjs。作者十分感谢这篇博客:Recognizing Digits using TensorFlow.js in Google Chrome,它对本文启发很大。
如果你觉得这篇文章有帮助的话,记得赞赏。