TensorFlow.js课程笔记(四)

前言

终于到了最后一章,这里就是相当于一个大作业,设计一个TensorFlow.js程序能够识别《石头剪刀布》中的各个手势。升级版就是添加了Spock和Lizard另外两个手势,有兴趣的可以看一下维基百科。

TensorFlow.js课程笔记(四)_第1张图片

这里主要用到了迁移学习的知识,将mobilenet识别图像特征的能力迁移到我们的应用中,然后根据实际的需求进行改写。

石头剪刀布

这个demo是调用电脑摄像头,你把手势一张张拍下来自己打标签,然后通过训练使其能够正确识别剪刀石头布。因此mobilenet很大,要小飞机下半天这里就没法加载了,界面如下(只能这么写出来了):
TensorFlow.js课程笔记(四)_第2张图片
点击三个按钮制作数据集,我左右手各收集了两张,注意背景不要太乱,然后点击训练就行。最后开始识别的话就会实时读取摄像头数据进行判断。

迁移学习实现图像特征提取

async function loadMobilenet() {
  const mobilenet = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json');
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  return tf.model({inputs: mobilenet.inputs, outputs: layer.output});
}

这里就使用mobilenet进行迁移学习,我们不需要mobilenet的分类能力,只需要其提取特征的能力即可,因此将conv_pw_13_relu这一层作为输出,来提取图像的特征。

利用特征构建自己的分类器

获取了上面的图像特征之后,我们便可以搭建自己的分类器了

async function train() {
  dataset.ys = null;
  dataset.encodeLabels(3);
  model = tf.sequential({
    layers: [
      tf.layers.flatten({inputShape: mobilenet.outputs[0].shape.slice(1)}),
      tf.layers.dense({ units: 100, activation: 'relu'}),
      tf.layers.dense({ units: 3, activation: 'softmax'})
    ]
  });
  const optimizer = tf.train.adam(0.0001);
  model.compile({optimizer: optimizer, loss: 'categoricalCrossentropy'});
  let loss = 0;
  model.fit(dataset.xs, dataset.ys, {
    epochs: 10,
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        loss = logs.loss.toFixed(5);
        console.log('LOSS: ' + loss);
        }
      }
   });
}

这里是一个三层的神经网络,最后一层输出层是softmax对三个手势进行概率的输出。

要说明一下这里的dataset.xs。ys就是给手势打的标签,用onehot进行编码,xs是已经通过mobilenet进行特征提取后的输出,所以才会是上面的inputShape。
构建数据集时的关键代码如下,可以看到xs是mobilenet进行特征提取后的输出:

dataset.addExample(mobilenet.predict(img), label);

将训练集喂入后就可以将我们自己的分类器训练出来了。

完成构建进行预测

async function predict() {
  while (isPredicting) {
    const predictedClass = tf.tidy(() => {
      const img = webcam.capture();
      const activation = mobilenet.predict(img);
      const predictions = model.predict(activation);
      return predictions.as1D().argMax();
    });
    const classId = (await predictedClass.data())[0];
    var predictionText = "";
    switch(classId){
		case 0:
			predictionText = "I see Rock";
			break;
		case 1:
			predictionText = "I see Paper";
			break;
		case 2:
			predictionText = "I see Scissors";
			break;
	}
	document.getElementById("prediction").innerText = predictionText;
			
    
    predictedClass.dispose();
    await tf.nextFrame();
  }
}

这里就可以知道整个预测流程,抓取相机的一帧,然后先喂入mobilenet获取特征,然后喂入分类器得到三个手势的结果,输出最大可能性的那个。

你可能感兴趣的:(js,machine,learning)