原链接:https://medium.com/tensorflow/introducing-tensorflow-js-machine-learning-in-javascript-bf3eab376db
声明:博客中所有图片均来自于原博客。
Tensorflow.js的发布,真的是一件让程序猿兴奋的事情,Tensorflow.js是一个开源的库,我们可以使用JavaScript或者高level的API在浏览器中定义代码,训练代码,运行机器学习模型。如果你是一个JavaScript的开发者,但是在机器学习方便还入坑未深,Tensorflow.js是一个你可以开始你的机器学习之旅的东东。再或者,你是一个机器学习的开发者,但是你是个Tensorflow.js的菜鸡,你可以通过浏览器端的机器学习,学到很多js的东西。(完了,两个我都是半吊子,卒……)。在这篇博文中,我们会简要介绍一下Tensorflow.js,然后会给一些你可以学习的资源。
在浏览器上运行机器学习程序,打开了很多新世界的大门,比如交互式的机器学习!如果你看过赖在TensorFlow Developer Summit的,直播的话(YouTube上的视频),在Tensorflow.js的talk中,你会发现 @dsmilkov和@nsthorat训练好了一个模型,来使用人的行为控制游戏,这里面使用到了计算机视觉和webcam,这个游戏完全是在浏览器下面的,牛逼吧!你可以自己试着玩一玩,链接丢在下面,你也可以在这里找到源码。
使用摄像头来控制游戏的示例,游戏链接在这儿
还有另外一个游戏,叫做Emoji Scavenger Hunt,不过这个是在手机浏览器上的一个游戏
Emoji Scavenger Hunt是另外一个基于Tensorflow.js做的应用,你可以在手机上打开,也可以在这里找到源码
在浏览器上运行的ML,站在用户的角度来说,你无需安装任何库和驱动。只需要打开一个web页面,你的程序就可以跑起来了。除此之外,你的计算是通过GPU加速的,Tensorflow.js会自动调用WebGL,当你的电脑有GPU支持的时候,会加速你页面场景的运行。用户也可能是从移动设备上打开的web页面,如果用的是移动设备的话,有个好处是,你能够充分使用手机上的传感器来获取数据,比方说重力加速度。此外,所有的数据都存储在客户端,使得数据交互的延迟性比较低,同时,数据使用也比较私密。
如果你使用Tensorflow.js开发应用程序的话,你可能需要考虑三个步骤:
1、你可以导入一个已经训练好的模型,来作为预训练的模型接口。如果你之前离线训练了Tensorflow或者Keras的模型,你可以将这个模型转化为Tensorflow.js的格式,也可以将这个模型放到浏览器上,来作为浏览器应用程序的接口;
2、你也可以重新训练导入的模型。之前提到过的Pac-Man的demo,就是这么干的,你也可以用迁移学习的方法,用很少量的数据来再训练你的现有模型,迁移学习也是一种能够让你用很少的数据训练出一个精确模型的方法;
3、你也可以在浏览器端写一个模型。你可以使用Tensorflow.js来定义,训练,使用模型。Tensorflow.js提供了封装的很好的接口,如果你熟悉Keras的话,这些接口你用起来会觉得非常熟悉,非常容易上手。
看一下源码吧:
如果你愿意的话,你可以直接看示例和教程来开始你的Tensorflow.js之旅。这些会给你展示下,怎么用将python训练的模型导出来使用,还介绍了怎么完全用JavaScript来定义和训练模型。让我们来先看一段代码,这是一个简单定义了神经网络来对花进行分类的代码,这段代码很像Tensorflow.org里的示例。
import * as tf from ‘@tensorflow/tfjs’;
const model = tf.sequential();
model.add(tf.layers.dense({inputShape: [4], units: 100}));
model.add(tf.layers.dense({units: 4}));
model.compile({loss: ‘categoricalCrossentropy’, optimizer: ‘sgd’});
define-model.js
我们在JavaScript中使用的所有的API,都可以与Keras中的示例模型无缝对接(包括Dense,CNN,LSTM等等)。我们还能够使用下面类似的与Keras兼容的API来训练我们的模型。
await model.fit(
xData, yData, {
batchSize: batchSize,
epochs: epochs
});
train-model.js
这个模型就可以用来做预测了:
// Get measurements for a new flower to generate a prediction
// The first argument is the data, and the second is the shape.
const inputData = tf.tensor2d([[4.8, 3.0, 1.4, 0.1]], [1, 4]);
// Get the highest confidence prediction from our model
const result = model.predict(inputData);
const winner = irisClasses[result.argMax().dataSync()[0]];
// Display the winner
console.log(winner);
predict.js
Tensorflow.js也有一些比较底层的API(比如之前的deeplearn.js),也支持Eager execution。你可以在Tensorflow的开发者峰会上,看到更多的讨论。
总览Tensorflow.js的API。Tensorflow.js由WebGL进行支持,提供high-level的网络层API来定义模型,也提供一些底层的API来进行线性代数运算和自动的求导。Tensorflow.js支持Tensorflow保存的模型和Keras保存的模型。
这个问题问的很好!Tensorflow.js是一个机器学习的JavaScript工具生态系统,它继承于deeplearn.js,现在呢,deeplearn.js部分,已经成为了Tensorflow.js的核心部分。Tensorflow.js也包括一些网络层的API,这些更高级一些的API用来使用Core来构建机器学习模型,也有一些可以直接导入Tensorflow的SavedModels和Keras hdf5模型的API。更多关于该问题的解答,请详见FAQ。
学习更多关于Tensorflow.js的知识,请访问项目主页,你可以看到一些教程,还有一些示例。你还可以观看Tensorflow开发者峰会的视频(YouTube上的),还可以在follow Tensorflow的推特。
谢谢欣赏,很期待看到你用Tensorflow.js做事情。如果你喜欢的话,也可以粉一下Tensorflow.js团队的几个成员:@dsmilkov, @nsthorat, 和@sqcai。