节选自《深度学习TensorFlow.js:浏览器实战篇》第八章,已获授权。
在前面的章节,我们讨论了各种JavaScript概念和运行在浏览器上的各种深度学习框架。在本章中,我们将所有的知识付诸于实践,证明该技术的潜力。
注意,本书所有源代码放在https://github.com/backstopmedia/deep-learning-browser。你也能在https://reii-nakano.github.io/tfjs-rock-paper-scissors/访问Rock Paper Scissors游戏示例。也能在https://reiinakano.github.io/tfjs-lstm-text-generation/访问文本生成模型的示例。
本章的每个小节都代表一个完整的TensorFlow.js应用,每个应用都能先看一个在线演示。然后我们详细讲解项目中使用到的算法。最后给出程序运行的完整解释,洙行代码讨论TensorFlow.js的细节。
我们希望你能把这些小项目当作指引,用在你自己的深度学习模型和业务逻辑当中。
在本小节,我们使用TensorFlow.js在webcam上实现玩石头剪刀布游戏。在进行详细的解释之前,我们先去Github页面看看它是如何运行的。如果你前面玩过Google的Teachable Machine,那你会注意到这里的训练机制是相同的。为了教浏览器识别“石头”手势,点击摄像头打出“石头”手势(握紧拳头),然后点击“Train Rock”按钮获取截图。当你玩石头剪刀布游戏时,训练好的机器学习模型能够探测手势。为了训练的模型更稳定,你要确保浏览器获取到不同的手势。注意,你不需要使用手势去区分手头、剪刀和布。
即使你不训练模型,浏览器也会持续的扫描webcam并分类为石头、剪刀或者布。机器学习模型的尺寸小使得模型训练和分类预测都可以实时进行。一旦你训练好三种手势的模型,你就可以开始在浏览器上玩石头剪刀布游戏。
为了理解代码,我们需要掌握预测算法的细节。
手势识别算法重要的特征之一是尺寸小和推断速度快。如果浏览器需要下载100MB的神经网络权重,那么你的所有用户都会抱怨。另外,如果他需要十秒钟预测一个手势,那也很难实时预测。幸运地是,这些条件神经网络模型都满足。SqueezeNet模型是专门设计尽可能小,并达到可接受的图像识别。根据原始的论文,在ImageNet比赛中,SqueezeNet模型只需要0.5MB的存储空间即可达到AlexNet模型一样水平的准确度,这对我们的应用已经足够了。SqueezeNet模型对内存和处理能力有限制的环境非常实用,比如手机或者浏览器。
我们这里不仅仅依赖SqueezeNet模型,另外一个重要的特征是可以从很少的数据量中学习到特征。本小节的例子中,每个手势只需要大约50张图片即可达到可接受的预测效果。ImageNet中包含百万级的图片,每个类别中有几百张图片。那么相对于ImageNet,我们的模型如何用这么少的训练集达到很好的效果呢?答案是一种称之为迁移学习(transfer learning)的技术。
迁移学习更一般的理念是,考虑一个模型如何解决当前问题时,可以使用不同问题但相关的问题训练出来的知识。
迁移学习的一般流程很简单。首先,你在合适的海量数据集上训练你的神经网络模型。这些数据要和实际的数据集尽可能的相似,比如,图像识别的图片。对于图像识别任务来说,数据源一般是ImageNet。在训练完模型之后,你可以切出模型的最后几层(一般取一到两层),接着运行自己的图片。换句话讲,你会为每张图片获得一个中间layer的输出,而不是根据ImageNet的类别来对你的图片进行分类。这些输出是你自己的图片通过预训练的ImageNet网络模型抽取的特征。该网络能解析出输入图片中泛化的相关特征。我们的图片越接近于ImageNet图片集,其生成的特征效果越好。
做完上面的步骤后,我们可以使用抽取的特征来训练不同的分类器,可能是我们的类别。常用的方法是在特征抽取器后增加一个全联接神经网络,并进行模型训练,但是这时要冻结原始神经网络的参数,只更新新增加网络的权重。
迁移学习在收集领域数据非常困难的情况下是相当有意义的,比如医疗图像处理。
在我们的应用中,我们使用抽取的特征来训练一个K最近邻(K-Nearest Neighbor,KNN)分类器,而不是在预训练的ImageNet SqueezeNet模型基础上增加一个神经网络。K最近邻分类器是给定一个训练数据集,对新输入的样本,在训练数据集中找到与该样本最邻近的K个样本(K个邻居), 这K个样本的多数属于某个类,就把该新样本分类到这个类别。这只需要矩阵乘法就可以计算,在TensorFlow.js中只用单个张量操作。因为训练一个KNN分类器比训练神经网络模型要快得多(你需要做的只是将训练样本增加到矩阵)。对于学习少量数据集,我们在浏览器上可以进行实时模型训练。
下面做一个简单的总结,我们的模型如下:
使用预训练的ImageNet SqueezeNet模型,我们用它的最后两层layer作为webcam图片的特征抽取器
我们使用抽取的特征作为K最近邻分类器的输入,训练为三个分类:石头、剪刀和布
为了对图片进行推断,我们在SqueezeNet模型上运行,将抽取的特征输入新训练的KNN分类器探测手势。就这么简单。
我们使用Yarn按照项目的所有依赖。对于从来没用过Yarn的用户,它是Javascript广泛使用的依赖管理器。虽然使用基础相当直观,如果你想理解如何使用,可以查看Yarn的文档。
定义应用的依赖的主要文件是package.json,存放在代码仓库root下。我们定义项目的元数据,比如,name、 version和license。需要注意的部分是dependencies项,它罗列出项目的依赖,使得其它.js文件很容易的引用这些依赖库。我们项目重要的两个依赖是:deeplearn 0.5.0和deeplearn-knn-image- classifier 0.3.0,其中deeplearn包含TensorFlow.js,deeplearn-knn-image- classifier包含所有前面小节我们讨论的模型的代码,封装成一个单独的、易用的NPM包。
如果想添加一个NPM包,只需简单地在仓库root下运行yarn add <package-name>。该命令会自动下载这个NPM包以及其依赖,并更新package.json和yarn.lock文件。
package.json也包含一些开发应用的脚本。第一个重要的脚本是prep,它能从仓库的root调起yarn prep。当你克隆代码仓库Yarn下载项目所有依赖时,该脚本会第一次运行。也会同时创建dist文件夹,它会存储构建过程创建的文件。另一个重要的脚本是调用yarn start,它会在localhost:9966开启开发服务,监控你的源代码变化并自动更新你的应用。这是一个高效的开发循环。最后,yarn build和yarn deploy使用browserify和uglify-js编译你的各种.js文件,生成单个较小的、生产环境使用的.js文件(为真实环境发布应用做准备)。
我们开始检查应用的源代码。因为本书是基于浏览器的深度学习,所以我们只关注应用中相应的部分。但是无需担忧,深度学习无关的代码尽可能用原生的JavaScript,没有使用像Vue.js或者React的外部框架。如果你计划在应用中使用这些框架,你也可以很容易在TensorFlow.js代码中使用这些外部框架。
让我们看一下deeplearn-knn-image-classifier包中的KNNImageClassifier类,该类创建神经网络,下载预训练模型权重,为每个训练图片调整KNN模型,并对新图片进行推断。
在项目中root目录下的main.js文件,我们定义一个Main类,并在浏览器窗口加载时实例化。Main类的构造器会初始化应用的所有变量的代码。在构造器函数constructor的末尾处,我们看到下面的代码:
// Instantiate the knn model
this.knn = new KNNImageClassifier(NUM_CLASSES, TOPK);
// Load knn model
this.knn.load().then(() => this.start());
第一行代码创建一个KNNImageClassifier对象,并分配给this.knn。KNNImageClassifier的构造器需传入两个参数:numClasses和k。numClasses定义模型期望分类的类别数。在本例中,numClasses为3(每种手势一个类别)。k是KNN算法模型的参数,它定义模型决定一个样本分类时所要考虑的邻居数。
第二行代码调用KNNImageClassifier的load函数。load函数用来下载预训练的SqueezeNet模型的权重。你将注意到这里then函数的使用,这说明load函数是一个异步函数,其返回一个Promise对象。当SqueezeNet模型的权重下载完成时,Promise对象决定执行。这时我们将调用this.start()开始TensorFlow.js迭代训练过程。
Main.start()函数定义如下:
start() {
this.video.play();
this.timer = requestAnimationFrame(() => this.animate());
}
上面的代码做了两件事:this.video.play()开启webcam流。this.animate()调用 TensorFlow.js迭代训练的第一次迭代。你会注意到,我们用requestAnimationFrame封装this.animate()调用。requestAnimationFrame是一个异步函数,当浏览器打开时requestAnimationFrame函数会调用传入的函数。这能确保在迭代训练时同步更新浏览器的视口。你也会注意到this.animate()末尾的一行代码:
this.timer = requestAnimationFrame(() => this.animate());
所以,在this.animate()的单个迭代的最后,我们会等待浏览器刷新它的视口,然后调用迭代训练的下一个迭代。这个常规的模式会确保,在更多的张量排队等待GPU处理时,浏览器得到合适的渲染。如果没有该模式浏览器会挂住,渲染web页面不可用。
你也应该注意到了,我们将requestAnimationFrame的返回结果分配给this.timer变量。虽然在本例中我们并没有使用该变量,但是它会基于某些事件给我们停止/暂停迭代训练的选项。stop函数会暂停我们的迭代训练,代码如下:
stop(){
this.video.pause();
cancelAnimationFrame(this.timer);
}
下面让我们看一下迭代训练中每个迭代都做了什么。在animate()函数中,我们从下面这行代码讲起:
const image = dl.fromPixels(this.video);
fromPixels函数的功能是把浏览器图片转化成一个3D张量,该张量包含图片的像素亮度。fromPixels函数可以从 ImageData、HTMLImageElement、HTMLCanvasElement或者HTMLVideoElement抓取图片。在本例子中,我们传入webcam的HTMLVideoElement。fromPixels函数把webcam的当前显示图片转换成一个3D张量,以供给其它TF.js函数使用。
接着,我们看下面一段代码:
// Train class if one of the buttons is held down
if (this.training != -1) {
// Add current image to classifier
this.knn.addImage(image, this.training);
}
当训练按钮被点击时,上面的代码会检测是否在训练三种手势其中的一种,并增加图片到KNN模型。这步很容易用KNNImageClassifier实例的addImage函数实现。addImage函数传入新训练图片的3D张量和相应的分类。KNNImageClassifier在SqueezeNet模型基础上处理图片,输入特征抽取的结果,并将其增加到训练样本的数组。
迭代训练的下一个代码块如下:
const exampleCount = this.knn.getClassExampleCount();
if (Math.max(...exampleCount) > 0) {
this.knn.predictClass(image)
.then((res)=>{
// Do something with our model's prediction `res`
})
.then(()=> image.dispose());
}else{
image.dispose();
}
我们调用this.knn.getClassExampleCount()获取每个分类的图片数目。如果我们对至少一张图片进行了模型训练,那么我们会继续并使用模型进行图片预测。
为了预测一张图片的分类,我们传入一个3D张量到KNN图片分类器的predictClass函数。predictClass函数是一个异步函数,提供的图片进行推断,并返回一个Promise。Promise会决定推断的结果。predictClass函数紧跟的.then函数调用会定义一个函数,当推断完成会执行该函数。在本例子中,我们使用推断的结果更新UI上相应的变量、文本和图片。因为.then函数也会在传入的函数完成时返回一个Promise,所以我们用另外一个.then函数链式地调用函数。这时我们调用图片的3D张量对象的dispose()方法,它会释放指定部分张量的GPU的内存。如果不这么操作,随着迭代训练每次迭代都会持续地分配图片张量对象,我们会出现内存泄漏 。
最后注意,如果我们没有对单个类别进行训练,那么同时也会忽略对当前图片的推断,并用image.dispose()丢弃图片张量对象。
下面总结一下,TensorFlow.js 的迭代训练过程如下:
从摄像头抓取一张图片,并使用tf.fromPixels 函数将其转换成一个3D张量
检查我们当前是否在处理某个手势。如果是,则用KNNImageClassifier.addImage函数增加图片和相应的类别到我们的模型
检查我们的模型当前是否在训练最近的一个手势。如果是,使用KNNImageClassifier.predictClass函数推断当前处理的图片。基于这个结果去更新类别的变量和UI元素
使用张量对象的.dispose()方法丢弃图片
使用requestAnimationFrame,调用this.animate() 运行迭代训练的下一次迭代。在我们不断地迭代之前,requestAnimationFrame确保在浏览器时间重绘视口
回到第一步
到目前为止我们还剩下两个函数没讨论:startGame和resolveGame函数。这两个函数包括在浏览器上运行石头剪刀布游戏的有效代码。它们处理游戏的流程,监控TensorFlow.js迭代过程中设置的中间变量,检查用户当前在摄像头做的哪种手势,并相应的更新UI。然而,这里我们并不去深究两个函数,因为它们并不包含TensorFlow.js相关的代码。理解这两个函数留作课外练习。