KerasJS初探

简介

Keras是一款非常流行的深度学习模型开发框架,基于python,语法简洁,封装程度高,只需十几行代码就可以构建一个深度神经网络。
Keras.js是一个可以在浏览器中运行深度神经网络的JS框架,支持CPU,GPU计算。区别于Keras,Keras.js只能运行已经调试好的模型,无法进行模型训练。


KerasJS初探_第1张图片
KerasJS

KerasJS初探_第2张图片
KerasJS流程

模型

借鉴这篇文章,开发一个识别圣诞老人的神经网络。本文不涉及Keras的开发细节,感兴趣的同学可以去原文查看。这里直接给出python代码

def build_model():
    model = models.Sequential()
    model.add(layers.Conv2D(20,(5,5),activation='relu',input_shape=(128,128,3)))
    model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2)))
    model.add(layers.Conv2D(50,(5,5),activation='relu',padding='same'))
    model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(500,activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
                  loss='binary_crossentropy',
                  metrics=['acc'])
    return model

数据

标注数据是AI模型的原料,数据搜集特别是图片搜集是前端可以介入的一个环节。基于React,开发了一款chrome图片批量下载插件GetThemAll,方便我们进行标记图片搜集。

KerasJS初探_第3张图片
GetThemAll

插件 地址,安装好插件。然后去谷歌图片搜索“santa”, 使用插件标记不需要的图片,然后下载到本地的santa文件夹,通过谷歌图片可以搜集到400张圣诞老人的图片。
接着我们再下载一些非圣诞老人的图片,搜索“object”,同样的使用GetThemAll插件下载大约400张图片到本地的non_santa文件夹中。
除了训练数据集,我们还需要一个测试数据集用来衡量模型的泛化能力。在本地新建一个test文件夹,把刚刚准备好的训练集里面的最后100张圣诞老人图片移到test文件夹下的santa文件中,同样的,移动100张非圣诞老人图片到non_stanta文件中。
KerasJS初探_第4张图片
数据集结构

有了标记数据,我们就可以进行模型训练啦。具体的训练过程请见pyton代码,这里直接给出训练的结果,蓝点表示训练数据集准确率,蓝线表示测试数据集准备率,模型有着明显的High Variance问题,不过这个bug留给深度学习的专家们解决吧,这里就假设这个模型可用。


KerasJS初探_第5张图片
训练结果

迁移

上一步训练出的模型keras_santa.h5(h5是文件后缀,和HTML5没啥关系)不能直接给KerasJS使用,需要通过KerasJS提供的转换工具转换后,方可被KerasJS加载解析。

./encoder.py keras_santa.h5

转换后,得到了keras_santa.bin文件,20M左右,这个文件包含了神经网络结构和所有参数,可以被KerasJS加载。

KerasJS

通过上面的步骤,我们得到了一个训练完成的CNN神经网络以及全部参数,这个网络结构和参数全部保存在keras_santa.bin文件中。接下来,我们只需要在浏览器中复原上面的神经网络,然后就可以开始做预测啦。
使用webpack配合React,搭建一套简单的开发环境。做好了基础工作,就可以开始第一步开发,加载神经网络模型文件keras_santa.bin:

const model = new KerasJS.Model({
   filepath: 'http://localhost:3000/keras_santa.bin',
   gpu: false
})
//KerasJS提供模型加载进度接口,考虑到模型文件体积非常大,这个接口会经常用到
model.events.on('loadingProgress', (progress) => {
      this.setState({
        loadingtitle: '模型加载',
        progress: parseInt(progress)
      })
})

使用上面的模型做预测前,需要将数据转化成模型能够接受的数据格式。这个圣诞老人网络需要数据输入格式为(128,128,3),也即是图片需要为128x128分辨率,只能包含RGB三个分量。


KerasJS初探_第6张图片
输入数据

借助canvas,可以实现图片分辨率转换:

_updateImageSrc(imgid) {
    const ctx = this.refs.canvas.getContext('2d');
    const imgdom = document.createElement('img');
    imgdom.src = `http://localhost:3000/${imgid}.jpeg`
    this.setState({
      prediction:0
    })
    imgdom.onload = ()=>{
      ctx.drawImage(imgdom,0,0,128,128)
      const imagedata = ctx.getImageData(0,0,128,128)
      const processeddata = ImageDataUtils.preprocess(imagedata)
      setTimeout(()=>{
        this.doPrediction(processeddata)
      },100);
    }
  }

注意preprocess,通过canvas获取到的图片资源包含了rgba四个维度,prepross返回这4个维度中的前3个维度,也即rgb,同时将数据标准化:

export default class ImageDataUtils {
  static preprocess(imageData) {
    const {
      width,
      height,
      data
    } = imageData;
    const dataTensor = ndarray(new Float32Array(data),[width,height,4])
    const dataProcessedTensor = ndarray(new Float32Array(width*height*3),[width,height,3])
    //从[0,255]转化到[0,1]
    ops.divseq(dataTensor,255)
    //获取R数据
    ops.assign(dataProcessedTensor.pick(null,null,0),dataTensor.pick(null,null,0))
    //获取G数据
    ops.assign(dataProcessedTensor.pick(null,null,1),dataTensor.pick(null,null,1))
    //获取B数据
    ops.assign(dataProcessedTensor.pick(null,null,2),dataTensor.pick(null,null,2))
    const preprocessedData = dataProcessedTensor.data
    return preprocessedData
  }   
}

最后,使用上面返回的数据做预测

async doPrediction(imagedata) {
    if(!this.model) return;
    const inputname = this.model.inputLayerNames[0]
    const inputdata = {[inputname]: imagedata}
    const prediction = await this.model.predict(inputdata)
    this.setState({
      prediction: prediction.output[0]
    })
  }
KerasJS初探_第7张图片
预测

注意

可以看到,KerasJS在预测过程中,整个页面无法响应用户操作。这是因为神经网络计算过程中占用了大量CPU资源,从而致使页面卡顿。下一篇文章中,我们将介绍如何使用WebGL,将计算过程转移到GPU,达到实现前端高性能计算的目的。

相关资源

1.Image classification with Keras and deep learning, Adrain Rosebrock

  1. GetThemAll, eeandrew
  2. React Keras,eeandrew

你可能感兴趣的:(KerasJS初探)