[tensorflow.js Study Notes (3)] Transfer Learning: Playing Pac-Man with a Webcam

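Note (2) used tensorflow.js to implement a classic machine learning problem: recognizing the MNIST handwritten digit set with a CNN. This note uses a webcam to recognize images and classify them as up, down, left, or right in order to play Pac-Man. It is based on the official example "Transfer learning - Train a neural network to predict from webcam data", with some of the code modified.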

1. First, load the pretrained model, mobilenet

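async function loadMobilenet() {
  // Load the pretrained mobilenet from a local model.json.
  // (tf.loadModel is the pre-1.0 name; tfjs 1.x renamed it tf.loadLayersModel.)
  const mobilenet = await tf.loadModel('./model.json');
  // Pick the intermediate activation layer to use as the new output.
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  // New model: original inputs, intermediate activations as outputs.
  return tf.model({inputs: mobilenet.inputs, outputs: layer.output});
}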

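The tf.model returned by this function keeps mobilenet's original inputs, and its output is mobilenet's "conv_pw_13_relu" layer. As a rule of thumb, the later layers of a pretrained model carry more of what the model has learned, so a layer close to the end of the network is usually the better choice.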
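
To make the "same inputs, intermediate output" idea concrete, here is a toy analogy in plain JavaScript. It does not use the tensorflow.js API: the layer list, the layer names conv_1, conv_2 and classifier, and the helper truncateAt are all hypothetical and exist only for illustration.

```javascript
// Toy stand-in for a layered model: each layer has a name and an apply function.
// This is NOT the tensorflow.js API; every name here is hypothetical.
const layers = [
  {name: 'conv_1', apply: x => x * 2},
  {name: 'conv_2', apply: x => x + 3},
  {name: 'classifier', apply: x => (x > 10 ? 'big' : 'small')},
];

// Builds a function that feeds the input through every layer up to and
// including the named one, so the result is an intermediate output.
function truncateAt(allLayers, layerName) {
  const index = allLayers.findIndex(layer => layer.name === layerName);
  if (index === -1) {
    throw new Error(`Unknown layer: ${layerName}`);
  }
  const kept = allLayers.slice(0, index + 1);
  return input => kept.reduce((value, layer) => layer.apply(value), input);
}

const featureExtractor = truncateAt(layers, 'conv_2');
console.log(featureExtractor(4)); // 4 * 2 + 3 = 11
```

In the real code above, tf.model plays the role of truncateAt: it reuses the layers between mobilenet.inputs and layer.output and drops everything after them.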

2. Define the webcam class, Webcam

The contents of webcam.js are as follows.

import * as tf from '@tensorflow/tfjs';

export class Webcam {
  constructor(webcamElement) {
    this.webcamElement = webcamElement;
  }

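  capture() {
    // tf.tidy disposes the intermediate tensors created inside the callback.
    return tf.tidy(() => {
      // Read the current video frame as a [height, width, 3] tensor.
      const webcamImage = tf.fromPixels(this.webcamElement);
      const croppedImage = this.cropImage(webcamImage);
      // Add a batch dimension of size 1.
      const batchedImage = croppedImage.expandDims(0);
      // Scale pixel values from [0, 255] to approximately [-1, 1].
      return batchedImage.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));
    });
  }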

  cropImage(img) {
    // Take the largest centered square from the frame.
    const size = Math.min(img.shape[0], img.shape[1]);
    const centerHeight = img.shape[0] / 2;
    const beginHeight = centerHeight - (size / 2);
    const centerWidth = img.shape[1] / 2;
    const beginWidth = centerWidth - (size / 2);
    return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
  }

  adjustVideoSize(width, height) {
    // Resize the element so that it keeps the aspect ratio of the video stream.
    const aspectRatio = width / height;
    if (width >= height) {
      this.webcamElement.width = aspectRatio * this.webcamElement.height;
    } else {
      this.webcamElement.height = this.webcamElement.width / aspectRatio;
    }
  }

  async setup() {
    return new Promise((resolve, reject) => {
      // Fall back to vendor-prefixed versions of getUserMedia in older browsers.
      const navigatorAny = navigator;
      navigator.getUserMedia = navigator.getUserMedia ||
          navigatorAny.webkitGetUserMedia ||
          navigatorAny.mozGetUserMedia ||
          navigatorAny.msGetUserMedia;
      if (navigator.getUserMedia) {
        navigator.getUserMedia(
            {video: true},
            stream => {
              // Attach the stream via srcObject; current browsers do not
              // accept a MediaStream in URL.createObjectURL.
              this.webcamElement.srcObject = stream;
              this.webcamElement.addEventListener('loadeddata', () => {
                this.adjustVideoSize(
                    this.webcamElement.videoWidth,
                    this.webcamElement.videoHeight);
                resolve();
              }, false);
            },
            error => {
              document.querySelector('#no-webcam').style.display = 'block';
              // Reject so that callers awaiting setup() are not left pending.
              reject(error);
            });
      } else {
        reject();
      }
    });
  }
}

The constructor is passed the webcam element from the DOM:

const webcam = new Webcam(document.getElementById('webcam'));
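
The arithmetic inside cropImage() and capture() can be checked without a camera or tensorflow.js. The sketch below repeats the same calculations on plain numbers; centerCropBox and normalizePixel are hypothetical helper names that do not appear in webcam.js, and the 480 x 640 frame size is only an assumed example.

```javascript
// Plain-number version of the arithmetic in cropImage() and capture().
// centerCropBox and normalizePixel are hypothetical helpers for illustration.

// Returns the slice origin and slice size of the largest centered square
// in an image of the given height and width (3 color channels assumed).
function centerCropBox(height, width) {
  const size = Math.min(height, width);
  const beginHeight = height / 2 - size / 2;
  const beginWidth = width / 2 - size / 2;
  return {begin: [beginHeight, beginWidth, 0], size: [size, size, 3]};
}

// Maps an 8-bit pixel value to roughly [-1, 1], like div(127).sub(1).
function normalizePixel(value) {
  return value / 127 - 1;
}

// An assumed 480 x 640 (height x width) frame: 80 columns are cut from each side.
const box = centerCropBox(480, 640);
console.log(box.begin, box.size); // begin [0, 80, 0], size [480, 480, 3]

// 0 maps to -1, 127 maps to 0, and 254 maps to 1.
console.log(normalizePixel(0), normalizePixel(127), normalizePixel(254));
```

Because the divisor is 127 rather than 127.5, a pixel value of 255 maps to slightly more than 1, which is why the range is only approximately [-1, 1].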
(To be continued)
