【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST

笔记(1)中利用tensorflow.js完成了机器学习中曲线拟合的任务,这篇笔记将实现一个经典的机器学习问题——CNN识别手写数字集MNIST。参考官方示例Training on Images: Recognizing Handwritten Digits with a Convolutional Neural Network,修改部分代码并用echarts改写vega。

1、定义mnist数据类

import * as tf from '@tensorflow/tfjs';

const IMAGE_SIZE = 784;//图片大小28*28
const NUM_CLASSES = 10;//类别数
const NUM_DATASET_ELEMENTS = 65000;//总样本数
const NUM_TRAIN_ELEMENTS = 55000;//训练样本数
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;//测试样本数

const MNIST_IMAGES_SPRITE_PATH = './src/mnist_images.png';//mnist图像
const MNIST_LABELS_PATH = './src/mnist_labels_uint8';//mnist图像对应的类别

export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;
        const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
            datasetBytesBuffer,
            i * IMAGE_SIZE * chunkSize * 4,
            IMAGE_SIZE * chunkSize
          );
          ctx.drawImage(img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
      batchSize,
      [this.trainImages, this.trainLabels],
      () => {
        this.shuffledTrainIndex = (this.shuffledTrainIndex + 1) % this.trainIndices.length;
        return this.trainIndices[this.shuffledTrainIndex];
      }
    );
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(
      batchSize,
      [this.testImages, this.testLabels],
      () => {
        this.shuffledTestIndex = (this.shuffledTestIndex + 1) % this.testIndices.length;
        return this.testIndices[this.shuffledTestIndex];
      }
    );
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();
      const image = data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);
      const label = data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}

mnist数据类在构造器内声明两个index,分别是训练过程的洗牌index和测试过程的洗牌index。引入洗牌index是为了防止模型训练受到传入图像顺序的影响。假设不洗牌,先将所有1的图像传入模型进行训练,那么此时训练的模型将学会预测1的手写体;之后传入所有2的图像,则模型将切换到仅预测2(这样会最小化损失函数);则这样永远无法完整的对全部数据集进行预测。

之后定义方法load(),该方法将图片mnist_images.png进行切割,并从mnist_labels_uint8中找到对应的label。之后的nextTrainBatch和nextTestBatch将分别返回训练样本和测试样本。

2、CNN的构建

import * as tf from '@tensorflow/tfjs';
import {MnistData} from './data';
import * as ui from './ui';

const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],
  kernelSize: 5,
  filters: 8,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'varianceScaling'
}));
model.add(tf.layers.maxPooling2d({
  poolSize: [2, 2],
  strides: [2, 2]
}));
model.add(tf.layers.conv2d({
  kernelSize: 5,
  filters: 16,
  strides: 1,
  activation: 'relu',
  kernelInitializer: 'varianceScaling'
}));
model.add(tf.layers.maxPooling2d({
  poolSize: [2, 2],
  strides: [2, 2]
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({
  units: 10,
  kernelInitializer: 'varianceScaling',
  activation: 'softmax'
}));

const LEARNING_RATE = 0.15;
const optimizer = tf.train.sgd(LEARNING_RATE);
model.compile({
  optimizer: optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

首先定义模型为tf.sequential(),该模型中张量将连续地从一层传递到下一层。之后分别加入卷积层、池化层、卷积层、池化层、flatten层(将输入展为向量)和dense层(完全连接层)。

其中卷积层是2维卷积;inputShape是传入数据的维数(第二个卷积层可以不指定inputShape,tf将从前一层的输出推断该值),三个值分别为行、列、深度,mnist图像格式为28*28像素点,深度为1,因为只有一个颜色通道;kernelSize是应用到输入数据上的滑动滤波器窗口的大小,5代表5*5矩形卷积窗口;filters指滤波器窗口的数量,8代表有8个滤波器;strides指滑动窗口的步长,1代表每次将以1像素为单位滑动滤波器;activation为激活函数,relu代表线性整流函数,函数形状如下所示;kernelInitializer用于随机初始化模型权重,这里使用VarianceScaling方法初始化模型。

【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST_第1张图片

池化层使用的是二维最大池,通过计算该层每个滑动窗口的最大值来降低纬度。其中poolSize是滑动窗口的大小,[2, 2]代表2*2的矩形窗口;strides代表滑动窗口移动的步长,[2, 2]代表窗口将在水平和垂直方向是以2像素为单位进行移动。

flatten层将上一层的输出平铺到一个矢量上。dense层(完全连接层)将执行最终的分类任务。其中units是输出的激活数,10代表将有10种不同的输出,满足mnist的10种分类(数字0-9);kernelInitializer设为VarianceScaling初始化方法;分类任务的最后一层激活函数activation通常设为softmax,该函数将10维输出向量归一化为概率分布,以便我们知道该样本属于10个类中每个类的概率。

定义学习率为0.15,优化器为随机梯度下降法(SGD);损失函数为categoricalCrossentropy,即分类任务的交叉熵;评价指标为准确率accuracy,即所有预测中正确预测的百分比。之后编译模型。

3、模型训练

const BATCH_SIZE = 64;
const TRAIN_BATCHES = 150;
const TEST_BATCH_SIZE = 1000;
const TEST_ITERATION_FREQUENCY = 5;

async function train() {
  ui.isTraining();

  const lossValues = [];
  const accuracyValues = [];

  for (let i = 0; i < TRAIN_BATCHES; i++) {
    const batch = data.nextTrainBatch(BATCH_SIZE);

    let testBatch;
    let validationData;
    if (i % TEST_ITERATION_FREQUENCY === 0) {
      testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
      validationData = [
        testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]), testBatch.labels
      ];
    }

    const history = await model.fit(
      batch.xs.reshape([BATCH_SIZE, 28, 28, 1]),
      batch.labels,
      {batchSize: BATCH_SIZE, validationData, epochs: 1}
    );
    const loss = history.history.loss[0];
    const accuracy = history.history.acc[0];
    lossValues.push([i, loss]);

    if (testBatch != null) {
      accuracyValues.push([i, accuracy]);
    }

    batch.xs.dispose();
    batch.labels.dispose();
    if (testBatch != null) {
      testBatch.xs.dispose();
      testBatch.labels.dispose();
    }

    await tf.nextFrame();
  }
  ui.plot(lossValues, accuracyValues);
}

async function showPredictions() {
  const testExamples = 100;
  const batch = data.nextTestBatch(testExamples);

  tf.tidy(() => {
    const output = model.predict(batch.xs.reshape([-1, 28, 28, 1]));

    const axis = 1;
    const labels = Array.from(batch.labels.argMax(axis).dataSync());
    const predictions = Array.from(output.argMax(axis).dataSync());
    
    ui.showTestResults(batch, predictions, labels);
  });
}

let data;
async function load() {
  data = new MnistData();
  await data.load();
}

async function mnist() {
  await load();

  await train();

  showPredictions();
}


mnist();

训练时首先调用ui.isTraining()将training输出到文档,之后进行TEST_BATCH_SIZE = 1000轮迭代,每次记录损失函数的值及准确率,并在训练结束后进行可视化展示。之后调用showPredictions方法将测试样本的预测值与实际像素图输出到文档。

4、可视化部分

const echarts = require('echarts');

const statusElement = document.getElementById('status');
const imagesElement = document.getElementById('images');

export function isTraining() {
  statusElement.innerText = 'Training...';
}

export function showTestResults(batch, predictions, labels) {
  statusElement.innerText = 'Testing...';

  const testExamples = batch.xs.shape[0];
  let totalCorrect = 0;
  for (let i = 0; i < testExamples; i++) {
    const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]);

    const div = document.createElement('div');
    div.className = 'pred-container';

    const canvas = document.createElement('canvas');
    canvas.className = 'prediction-canvas';
    draw(image.flatten(), canvas);

    const pred = document.createElement('div');

    const prediction = predictions[i];
    const label = labels[i];
    const correct = prediction === label;

    pred.className = `pred ${(correct ? 'pred-correct' : 'pred-incorrect')}`;
    pred.innerText = `pred: ${prediction}`;

    div.appendChild(pred);
    div.appendChild(canvas);

    imagesElement.appendChild(div);
  }
}

const lossChart = echarts.init(document.getElementById('lossChart'));
const accuracyChart = echarts.init(document.getElementById('accuracyChart'));
export function plot(lossValues, accuracyValues) {
  lossChart.setOption({
    title: {
      text: 'Loss Values'
    },
    xAxis: {
      type: 'value'
    },
    yAxis: {
      type: 'value'
    },
    series: [{
      name: 'loss',
      type: 'line',
      data: lossValues
    }]
  });
  accuracyChart.setOption({
    title: {
      text: 'Accuracy Values'
    },
    xAxis: {
      type: 'value'
    },
    yAxis: {
      type: 'value'
    },
    series: [{
      name: 'accuracy',
      type: 'line',
      data: accuracyValues
    }]
  });
}

export function draw(image, canvas) {
  const [width, height] = [28, 28];
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(width, height);
  const data = image.dataSync();
  for (let i = 0; i < height * width; ++i) {
    const j = i * 4;
    imageData.data[j + 0] = data[i] * 255;
    imageData.data[j + 1] = data[i] * 255;
    imageData.data[j + 2] = data[i] * 255;
    imageData.data[j + 3] = 255;
  }
  ctx.putImageData(imageData, 0, 0);
}

【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST_第2张图片

BUG1:

  lossChart.setOption({
    series: [{
      name: 'loss',
      type: 'line',
      data: lossValues
    }]
  });

当绘图时仅传入series参数时,Echarts报错:

Uncaught (in promise) TypeError: Cannot read property 'get' of undefined

此时将xAxis、yAxis等添加进options中即可。

  lossChart.setOption({
    title: {
      text: 'Loss Values'
    },
    xAxis: {
      type: 'value'
    },
    yAxis: {
      type: 'value'
    },
    series: [{
      name: 'loss',
      type: 'line',
      data: lossValues
    }]
  });

BUG2:

【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST_第3张图片

报错 Uncaught Error: Unsupported core optimizer type: t

原因是相关依赖未正确安装,cd到package.json同目录,运行yarn命令安装相关包。

完整程序见我的github,具体步骤为:

step1 新建文件夹,cmd输入git clone [email protected]:orangecsy/tfjs-exercise.git,cd 2进入文件夹2;

step2 cmd输入webpack,打包;

step3 cd dist进入dist文件夹,cmd中输入http-server(需先npm install http-server)或使用webpack配置开发服务器;

step4 浏览器中输入http://127.0.0.1:8080/,即为结果。

你可能感兴趣的:(【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST)