TensorFlow.js图片分类的迁移学习

来源:https://codelabs.developers.google.com/codelabs/tensorflowjs-teachablemachine-codelab/index.html

1.简介

在此代码实验室中,您将学习如何构建一个简单的“可教学的机器”,这是一个自定义图像分类器,您将使用TensorFlow.js(一个功能强大且灵活的Java机器学习库)在浏览器中进行训练。首先,您将加载并运行一个流行的预训练模型MobileNet,以在浏览器中进行图像分类。然后,您将使用一种称为“转移学习”的技术,该技术使用预先训练的MobileNet模型引导我们的训练,并对其进行自定义以针对您的应用程序进行训练。

该代码实验室将不会讲授可教机器应用背后的理论。如果您对此感到好奇,请查看本教程。
您将学到什么

  • 如何加载预先训练的MobileNet模型并预测新数据
  • 如何通过网络摄像头做出预测
  • 如何使用MobileNet的中间激活来在您使用网络摄像头动态定义的一组新类上进行迁移学习

因此,让我们开始吧!

2.要求

要完成此代码实验室,您将需要:

  1. Chrome的最新版本或其他现代浏览器。
  2. 文本编辑器,可以通过Codepen或Glitch之类在您的计算机上本地运行,也可以在网络上运行。
  3. 了解HTML,CSS,JavaScript和Chrome DevTools(或您喜欢的浏览器devtools)。
  4. 对神经网络的高级概念理解。如果您需要介绍或复习,请考虑观看3blue1brown的视频或Ashi Krishnan的Java深度学习视频。

注意:如果您在CodeLab信息亭中,我们建议使用glitch.com完成此Codelab。我们为您设置了一个入门项目,以重新混合以加载tensorflow.js。

3.加载TensorFlow.js和MobileNet模型

在编辑器中打开index.html并添加以下内容:


  
    
    
    
  
  
    

4.设置MobileNet以在浏览器中进行推断

接下来,在代码编辑器中打开/创建文件index.js,并包含以下代码:

let net;

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Make a prediction through the model on our image.
  const imgEl = document.getElementById('img');
  const result = await net.classify(imgEl);
  console.log(result);
}

app();

5.在浏览器中测试MobileNet推理

要运行该网页,只需在Web浏览器中打开index.html。如果使用的是云控制台,只需刷新预览页面即可。

您应该在开发人员工具的Javascript控制台中看到一只狗的图片,这是MobileNet的最高预测!请注意,下载模型可能需要一点时间,请耐心等待!

图像是否正确分类?

还值得注意的是,这也可以在手机上使用!

6.在浏览器中通过网络摄像头图像运行MobileNet推理

现在,让我们使其更具交互性和实时性。让我们设置网络摄像头,以对通过网络摄像头拍摄的图像进行预测。

首先设置网络摄像头视频元素。打开index.html文件,并在部分中添加以下行,并删除用于加载狗图像的标签:


打开index.js文件,并将webcamElement添加到文件的顶部

const webcamElement = document.getElementById('webcam');

现在,在之前添加的app()函数中,您可以通过图像删除预测,而是创建一个无限循环,该无限循环通过网络摄像头元素进行预测。

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');
  
  // Create an object from Tensorflow.js data API which could capture image 
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);
  while (true) {
    const img = await webcam.capture();
    const result = await net.classify(img);

    document.getElementById('console').innerText = `
      prediction: ${result[0].className}\n
      probability: ${result[0].probability}
    `;
    // Dispose the tensor to release the memory.
    img.dispose();

    // Give some breathing room by waiting for the next animation frame to
    // fire.
    await tf.nextFrame();
  }
}

如果您在网页中打开控制台,现在应该会看到MobileNet预测,并具有在摄像头中收集的每个帧的概率。

这些可能是荒谬的,因为ImageNet数据集看起来与通常出现在网络摄像头中的图像不太相似。一种测试方法是,将手机上的狗的图片放在笔记本电脑摄像头前。

7.在MobileNet预测之上添加自定义分类器

现在,让我们使其更有用。我们将使用网络摄像头动态创建自定义的3类对象分类器。我们将通过MobileNet进行分类,但是这次我们将对特定摄像头图像进行模型的内部表示(激活),并将其用于分类。

我们将使用一个称为“ K最近邻居分类器”的模块,该模块可以有效地使我们将网络摄像头图像(实际上是其MobileNet激活)放入不同的类别(或“类别”)中,并且当用户要求做出预测时,只需选择与我们预测的活动最相似的课程。

在index.html的标记的导入末尾添加KNN分类器的导入(您仍将需要MobileNet,因此请不要删除该导入):

...

...

在video元素下方的index.html中为每个按钮添加3个按钮。这些按钮将用于向模型添加训练图像。

...



...

在index.js的顶部,创建分类器:

const classifier = knnClassifier.create();

更新app函数:

async function app() {
  console.log('Loading mobilenet..');

  // Load the model.
  net = await mobilenet.load();
  console.log('Successfully loaded model');

  // Create an object from Tensorflow.js data API which could capture image 
  // from the web camera as Tensor.
  const webcam = await tf.data.webcam(webcamElement);

  // Reads an image from the webcam and associates it with a specific class
  // index.
  const addExample = async classId => {
    // Capture an image from the web camera.
    const img = await webcam.capture();

    // Get the intermediate activation of MobileNet 'conv_preds' and pass that
    // to the KNN classifier.
    const activation = net.infer(img, 'conv_preds');

    // Pass the intermediate activation to the classifier.
    classifier.addExample(activation, classId);

    // Dispose the tensor to release the memory.
    img.dispose();
  };

  // When clicking a button, add an example for that class.
  document.getElementById('class-a').addEventListener('click', () => addExample(0));
  document.getElementById('class-b').addEventListener('click', () => addExample(1));
  document.getElementById('class-c').addEventListener('click', () => addExample(2));

  while (true) {
    if (classifier.getNumClasses() > 0) {
      const img = await webcam.capture();

      // Get the activation from mobilenet from the webcam.
      const activation = net.infer(img, 'conv_preds');
      // Get the most likely class and confidences from the classifier module.
      const result = await classifier.predictClass(activation);

      const classes = ['A', 'B', 'C'];
      document.getElementById('console').innerText = `
        prediction: ${classes[result.label]}\n
        probability: ${result.confidences[result.label]}
      `;

      // Dispose the tensor to release the memory.
      img.dispose();
    }

    await tf.nextFrame();
  }
}

现在,当您加载index.html页面时,可以使用公共对象或面部/身体手势来捕获这三个类中的每一个的图像。每次单击“添加”按钮之一时,会将一幅图像添加到该班级作为训练示例。在执行此操作时,模型将继续对即将到来的网络摄像头图像进行预测,并实时显示结果。

8.可选:扩展示例

现在尝试添加另一个不表示任何操作的类!

9.您学到了什么

在此代码实验室中,您使用TensorFlow.js实现了一个简单的机器学习Web应用程序。您加载并使用了预先训练的MobileNet模型对网络摄像头中的图像进行分类。然后,您可以自定义模型,以将图像分为三个自定义类别。

请确保访问js.tensorflow.org以获取更多示例和带有代码的演示,以了解如何在应用程序中使用TensorFlow.js。

你可能感兴趣的:(TensorFlow.js图片分类的迁移学习)