用tensorflow.js实现浏览器内的手写数字识别

简介

Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预测放在前端,置于浏览器内。

本文会用一个简单的demo介绍如何从零开始训练一个tensorflow模型,并在浏览器内实现手写数字识别,最终效果大约如下:


用tensorflow.js实现浏览器内的手写数字识别_第1张图片
手写数字识别示例

本文会假设你有基本的python和JavaScript的知识。项目的完整代码参考github。

准备

项目代码的目录结构如下:


用tensorflow.js实现浏览器内的手写数字识别_第2张图片
项目目录结构

整个结构大概分成server和web两个部分,分别是服务端和浏览器端的代码。

我们的流程大概如下:

  1. 下载训练数据集,用python的tensorflow训练模型,并保存模型文件。
  2. 使用python的flask启动服务,使模型文件可以作为本地服务的静态文件被访问。
  3. 在网页html内,用canvas创建一个可以随意涂抹的画布,并能够获取画布上的像素信息。
  4. 在JavaScript脚本内导入tf.js,载入训练模型,通过模型计算画布上的信息的预测结果,并显示在图表上。

我们需要的所有依赖如下:

python:

建议使用3.5以上的版本。我不能保证在<3.5的版本中它是否能正常工作。Tensorflow的兼容性问题一向令人头疼。注意在mac和linux上默认的python是python2。

  • numpy —— 一个知名的python数学计算库,在矩阵和数组运算方面非常强大
  • tensorflow —— 机器学习库,直接用pip安装的是cpu版本。如果你的pc有一个足够好的独立显卡,可以试试tensorflow-gpu。它可以使训练的速度更快。但tensorflow-gpu的配置方法比较复杂。我们的模型比较简单,即使用cpu训练也不会耗时太久。
  • tensorflowjs —— 用于导出并保存可以被浏览器使用的模型文件
  • flask —— 一个轻量级的python网络服务框架
  • flask-cors —— 用于支持flask跨域请求的一个库

这个demo内已经包含一个已经训练好的模型,所以你如果并不想自己再训练一次,可以不安装tensorflow和tensorflowjs。所有这些依赖都可以通过pip安装。

JavaScript:

你不需要特别安装任何东西,因为我们的库都是通过链接导入的。

  • tf.js —— 它就是本文要介绍的,尽管只会涉及它的极小的一点。
  • fabric.js —— 可选,用于比较方便地构造画布。
  • Chart.js —— 可选,只是用来画出下边的图表的。你也可以不要它,如果你对这种可视化的结果不感兴趣。
浏览器:

反正在chrome浏览器里是能跑起来的……

训练

项目文件里面已经包含了一个训练好的模型,位于{项目路径}/server/models/mnist文件夹内。

我们使用MNIST数据集来训练模型。MNIST是一个知名的手写数字识别的数据集。对很多机器学习的初学者而言,这很可能是他们接触到的第一个数据集。这个数据集中包含60000张训练图片以及10000张测试图片,每张图片都是一个28×28像素的手写数字图片。如下图所示:

用tensorflow.js实现浏览器内的手写数字识别_第3张图片
mnist.png

MNIST用一个28×28的矩阵来代表这样的一张数字图片,矩阵内的每个元素表示对应点位置的灰度,在0~255之间。

下载数据:

事实上,你可以跳过下载数据这一步而直接开始训练,因为在训练函数中会自动下载数据,但鉴于国内糟糕的网络环境,我还是建议你先把数据手动下载下来。我会优先从本地读取数据。

下载地址:mnist.npz
下载完成后保存在路径{项目路径}/server/datasets/mnist.npz的位置。npz是numpy的一种数据压缩格式。文件大小大概11m。然后我们用load_data函数载入数据:

import numpy as np
from tensorflow.keras import layers, datasets

def load_data(path):
    try:
        with np.load(path) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']
            x_train, x_test = x_train/255.0, x_test/255.0
            return (x_train, y_train), (x_test, y_test)
    except FileNotFoundError:
        return datasets.mnist.load_data()

其中,x_train是一个60000×28×28的3维向量,代表60000张图片;y_train是长度60000的向量,每一项代表对应图片的实际数字,是一个0~9的整数。x_test,y_test是测试集上的对应数据,测试集大小为10000。注意x_train, x_test = x_train/255.0, x_test/255.0这一步是把每个灰度数字转换为一个0~1之间的小数。

训练模型:

我们使用tensorflow.keras的接口来实现一个简单的卷积神经网络(Convolutional Neural Network, CNN)模型。它包含了一个卷积层,一个池化层,和两个全连接层。我不会这里解释全部概念——对新手来说,它们过于令人困惑而费解。而且你也不需要在这里理解它。如果你真的很想从直观上把握它的话,你可以试试这篇博客:An Intuitive Explanation of Convolutional Neural Networks。它有点长,但为此花一些时间依然是值得的。

在server/train.py文件下可以看到训练函数的代码:

from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, datasets
import tensorflowjs as tfjs

def train_modle(data):
    (x_train, y_train), (x_test, y_test) = data
    model = Sequential([
        layers.Reshape((28, 28, 1), input_shape=(28, 28)),
        layers.Conv2D(16, (5, 5), padding='valid', input_shape=(28, 28, 1), activation='relu'),
        layers.MaxPooling2D(pool_size=2),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(x_train, y_train, epochs=5, batch_size=64)
    model.evaluate(x_test, y_test)
    tfjs.converters.save_keras_model(model, model_path)

从上至下,这个训练函数做的:

  • 读取从npz中或自动下载的训练数据。
  • 构造模型的层序列,tensorflow.keras是以layer这种对象组织计算过程的。每一层输出是下一层的输入,最后的输出就是模型的输出。它每层依次是:
    • Reshape。注意我们的输入的数据对应的是图片的每个像素的灰度,是没有‘深度’的。而卷积层要求的输入必须是有‘深度’的。所以我们首先为数据额外地加一个‘深度’为1的第三个维度。
    • 卷积层。这一层的作用是提取一个图片的每一个点周围的‘局部特征’,并传递给下一层。我们需要16种特征,对每一种特征,我们用一个5×5大小的矩阵,‘扫描’图片,并据此计算出一个值。所以这一步,是把每个点都映射到一个长16的向量,来代表这个点的16种不同的局部特征。
    • 池化层。这一步是为了降低数据的大小。在所有的相邻的2×2的的范围内,我们只保留其中的最大值。
    • Dropout。在训练过程中,每次更新参数时,随机地把一部分输入节点忽略掉。这是一种防止过拟合的简单技巧。它只会应用于训练时,不会用在预测上。
    • Flatten。输入展平成一个一维数组。如果你的下一层是全连接层,那么这一步是必要的(除非你想把输出格式搞得一团糟)。
    • Dense。大小为128的全连接层。上一层的所有点都与这层的所有点相连。
    • Dense。大小为10的最末端的全连接层,它的输出就是模型的预测结果,对应一张图片是每个数字的概率。
  • 损失函数是用于估计模型预测结果和正确结果的偏差的函数。比如实际的数字为2,那么我们期待的结果应该是[ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ],即除第3位为1外其他的都是0;而我们的预测结果可能是[0.2, 0.3, 0.1,...]。我们这里使用交叉熵算法来评估两种概率分布间的差别。训练的目的就是使得这样的损失函数的值尽量接近0。
  • 优化函数决定了在预测结果和正确结果的特定偏差下,应该如何更新参数。这里我们使用adam优化器。
  • fit。使用训练集来训练。我们每次在大小60000数据中取出64个作为一批,计算损失函数并优化参数。在整个数据集上,重复5次。
  • evaluate。使用测试集来评估训练结果。只计算损失函数,不做参数优化。
  • 最后一步,是把模型的训练结果保存成文件,在预测时可以调用。

撇开数学上的概念理解不谈,一般初学者在训练过程中最容易让人弄错的地方是数据的格式(shape)。

运行文件,开始训练

python server/train.py

如果你的环境配置正确,你应该会看到这样的输出:

60000/60000 [==============================] - 11s 185us/sample - loss: 0.1896 - acc: 0.9453
Epoch 2/5
60000/60000 [==============================] - 14s 225us/sample - loss: 0.0678 - acc: 0.9791
Epoch 3/5
60000/60000 [==============================] - 13s 221us/sample - loss: 0.0504 - acc: 0.9840
Epoch 4/5
60000/60000 [==============================] - 14s 233us/sample - loss: 0.0377 - acc: 0.9881
Epoch 5/5
60000/60000 [==============================] - 14s 231us/sample - loss: 0.0301 - acc: 0.9900
10000/10000 [==============================] - 1s 93us/sample - loss: 0.0360 - acc: 0.9879

在我的i5 cpu电脑上,整个训练过程大约耗时不到1分钟。
这个输出的结果显示了每一个epoch的耗时、损失函数的值和准确率。最后一行是在测试集上的结果。可以看到,我们的训练结果在测试集上有 98.79% 的准确率。同时,在{项目路径}/server/models/mnist内的文件也会被覆盖更新。你可以调整模型的结构和条件,多试几次,来评估不同条件下的训练结果。

{项目路径}/server/models/mnist下有两个文件,一个很小的model.json文件和另一个大小约1m以上的.bin文件。model.json文件可以直接打开,里面包括了模型的一些总体信息,如模型的结构和参数文件的位置,也就是.bin文件,这个文件记录了这个模型训练出来的所有参数。另外,如果你已经改动过模型的结构或者其他条件重新训练,那么这样的参数文件可能不止一个。

服务

我们已经训练好了模型,但这个模型文件是不能直接被浏览器载入使用的,因为现代浏览器一般都会阻止js直接读取本地文件内容。并且在设计上,这个模型文件也应该是保存在服务端而不是客户端。
我们需要做的,是启动一个服务,并使得这个模型成为这个服务的静态资源,这样js就可以通过请求拉取文件内容。
在server目录下的main.py文件:

from flask import Flask
from flask_cors import CORS

app = Flask(__name__,
            static_url_path='/models', 
            static_folder='models')

cors = CORS(app) 

@app.route("/")
def hello():
    return "Hello World!"

if __name__ == '__main__':
    app.run(debug=True)

这是一个非常简单的flask应用代码。在这个应用中,我们把url路径/models映射到了文件目录models(相对于本文件),外界通过{host}/models就能访问到models内的文件。
在项目的根目录下,用命令行启动这个文件:

python server/main.py

如果一切正常的话,你应该会看到这样的输出

* Serving Flask app "main" (lazy loading)
* Environment: production
WARNING: This is a development server. Do not use it in a production deployment
Use a production WSGI server instead.
* Debug mode: on
* Restarting with stat
* Debugger is active!
* Debugger PIN: 267-971-636
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

现在,你可以打开 http://127.0.0.1:5000/ 或者 http://localhost:5000/,如果你在屏幕上看到了“Hello World!”,就说明服务已经启动成功了。ctrl+c可以退出服务。
此时如果你打开http://localhost:5000/models/mnist/model.json就可以看到我们之前训练出来的模型的json文件。
另外,注意在代码中,我们还加了一句cors = CORS(app),这是为了让这个服务接受跨域请求。本文在这里不会展开讨论这个问题,简单地说:如果在js脚本中试图请求拉取的后端资源的协议或域名与js本身的不一致,那么浏览器会阻止这个请求——这是一种安全保护策略,除非你加了这行代码让后端资源接受跨域。

预测

我们在html头引入这几个库:

    
        
        
        
        
    

最后一个model.js是本地的js文件,我们会把模型的导入和数据预测函数都封装在这里。

在model.js文件里,导入模型:

const MODEL_URL = 'http://localhost:5000/models/mnist/model.json' 

var loadModel = (async function() {
    window.model = await tf.loadLayersModel(MODEL_URL);
    console.log('load model')
    return model;
})
loadModel();

MODEL_URL 就是模型的model.json文件的url地址。用tf.loadLayersModel函数来载入模型并绑定在window上。当你在浏览器控制台里看到 'load model',模型就载入成功了。注意tf.loadLayersModel是异步函数,它返回的是一个Promise对象,你需要用await或者.then()的回调式方法来获取载入的模型对象。

在html里,添加一个id="canvas"的canvas和两个按钮,一个用于识别,另一个用于清空canvas。

        

在html内

你可能感兴趣的:(用tensorflow.js实现浏览器内的手写数字识别)