嵌入机器学习的微信小程序教程(五)——模型保存与小程序加载

模型保存与小程序加载

    • 机器学习模型的保存
    • 使用tensorflowjs转换模型
    • 小程序加载机器学习模型
      • 上传模型文件
      • 创建云函数
        • 下载依赖
        • 导入依赖
        • 载入模型

由于微信小程序对代码包有大小限制,如果选择嵌入模型,我们需要将模型转换为TensorFlow.js可用的 web 格式模型。然后将模型存入云平台,通过云函数加载模型,实现预测分类识别等。
当然将机器学习封装成api等方法可以,这里我们展示模型嵌入的方法。

机器学习模型的保存

Model对象提供了几种方法。

  • save() 保存了模型结构,模型参数和优化器参数。配合load_model() 使用
  • save_wights() 仅保存了模型权重,后续并不能仅加载权重进行训练和测试。配合load_weights() 使用。
  • to_json() 仅保存了模型结构。配合model_from_json() 使用

前两种的存储格式为hdf5,最后一种将模型存储为json文件。

from keras.models import load_model

model.save('my_model.h5')
new_model = load_model('my_model.h5')

model.load_weights(checkpoint_path)
model = model.model_from_json(json_string)

SavedModel 则是保存 TensorFlow 模型的默认格式。

import tensorflow as tf
  
tf.saved_model.save(pretrained_model, "/tmp/")
loaded = tf.saved_model.load("/tmp/")

使用tensorflowjs转换模型

我们需要将python构建的模型载入到以JavaScript为主的小程序中,就要借助tensorflowjs中的tfjs-converter进行模型转换。
在python中可以写为

import tensorflowjs as tfjs

tfjs.converters.save_keras_model(tojs, 'model/')

也可以通过运行一下命令

$ tensorflowjs_converter --input_format=keras /tmp/model.h5 /tmp/tfjs_model

将路径为 /tmp/model.h5 的模型转换并输出 model.json 文件及其二进制权重文件到目录 tmp/tfjs_model/ 中。
转换脚本会产生两种文件:

  • model.json (数据流图和权重清单)
  • group1-shard*of* (二进制权重文件)

更多的细节和不同类型模型转换可以查询tfjs-converter的自述文件

小程序加载机器学习模型

上传模型文件

打开云开发控制台,将转换好的模型文件上传到云存储。并将模型文件放到同一个文件夹中。
嵌入机器学习的微信小程序教程(五)——模型保存与小程序加载_第1张图片
嵌入机器学习的微信小程序教程(五)——模型保存与小程序加载_第2张图片
包括记录模型结构的json文件,和权重bin文件。

创建云函数

在小程序工程文件目录栏,右键单击cloudfunctions目录,新建Node.js云函数

下载依赖

右键单击新建的云函数,选择“在终端打开”,为云函数安装所需要依赖:
node.js基础依赖,wx-server-sdk,以及tfjs-node

npm安装命令分别是:

npm install
npm install wx-server-sdk@latest @tensorflow/tfjs-node fetch-wechat **加粗样式**

安装wx-server-sdk依赖才能进行云函数的本地调试。
可以使用镜像地址加快下载速度

npm install -g cnpm --registry=https://registry.npm.taobao.org
//cnpm代替npm
cnpm install @tensorflow/tfjs-node --save

导入依赖

const cloud = require('wx-server-sdk')
var fetchWechat = require('fetch-wechat')
const tf = require('@tensorflow/tfjs-node')

在云函数的初始化init()中补全云开发环境

cloud.init({
  env: 'wxnlp-tqlzu',
  traceUser: true,
})

载入模型

const model = await tf.loadGraphModel(model_path)
const model = await tf.loadLayersModel(model_path);

model_path为模型在云平台的存储路径,这里我们用用模型的json文件的路径。
因为加载json 文件后,函数将请求对应的json 文件引用的.bin文件。相应的.bin文件需要和json 文件在同一个文件夹中。这个工具依赖于fetch方法

你可能感兴趣的:(微信小程序,tensorflowjs)