前面是讲JavaScript如何搭建模型并运行的,但是实际情况是大家一般不会直接用JavaScript进行构建,而是先用Python进行本地的测试,而且就目前的情况来说,Python构建的PC端应用程序明显更多。
但是TensorFlow.js早就想到了这个问题,因此他们构建了一些工具,能够支持开发者基于本地Python进行训练后移植到Web端。
这里举了两个最常用领域的例子。
其实就是一些脏话的检测,英文版的,模型结构和参数已经训练好放在了网上。
直接引入训练好的模型
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/toxicity">script>
然后直接调用模型,可以对七种不同形式的脏话进行检测。
const threshold = 0.9;
toxicity.load(threshold).then(model => {
const sentences = ['you suck'];
model.classify(sentences).then(predictions => {
console.log(predictions);
for(i=0; i<7; i++){
if(predictions[i].results[0].match){
console.log(predictions[i].label +
" was found with probability of " +
predictions[i].results[0].probabilities[1]);
}
}
});
});
mobilenet也是别人已经训练好的,放入一张图片,然后输出物品属于哪一类的概率,输出概率最高的三项。
通过网页右边的Network我们可以看到,其实引入的模型是下载了model.json这个文件,将模型的结构进行构建,然后将model.json关联的五个参数文件都分别下载了下来。上面的不规范语言检测其实也是一样的。
这一部分其实很实用,因为一般都不会直接就在浏览器上开工,而是在本地先将模型跑通后移植到浏览器上。
首先先安装tensorflow.js相关的python库
pip install tensorflowjs
import numpy as np
import tensorflow as tf
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=1, input_shape=[1])
])
model.compile(optimizer='sgd', loss='mean_squared_error')
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
model.fit(xs, ys, epochs=500)
import time
saved_model_path = "./{}.h5".format(int(time.time()))
model.save(saved_model_path)
这里就是普通的训练过程了,因为是线性拟合所以代码比较简单了,其他网络其实也是一样的。
最重要的是要将模型保存成h5文件。
接着就可以直接使用工具将模型转化成浏览器可以读取的格式了。
一般来说,会生成一个model.json文件,用来构建模型的结构以及表明后面的参数文件;参数文件有一个或多个,主要看网络是不是很复杂,如果很复杂的话将会生成多个参数文件。因为这里的例子网络很简单,所有也就只有一个参数文件。
在命令行或者jupyter处执行
tensorflowjs_converter --input_format=keras {saved_model_path} ./
浏览器直接加载model.json文件即可调用
async function run(){
const MODEL_URL = 'http://127.0.0.1:8887/model.json';
const model = await tf.loadLayersModel(MODEL_URL);
console.log(model.summary());
const input = tf.tensor2d([10.0], [1,1]);
const result = model.predict(input);
alert(result)
}
run();