简介
Keras是一款非常流行的深度学习模型开发框架,基于python,语法简洁,封装程度高,只需十几行代码就可以构建一个深度神经网络。
Keras.js是一个可以在浏览器中运行深度神经网络的JS框架,支持CPU,GPU计算。区别于Keras,Keras.js只能运行已经调试好的模型,无法进行模型训练。
模型
借鉴这篇文章,开发一个识别圣诞老人的神经网络。本文不涉及Keras的开发细节,感兴趣的同学可以去原文查看。这里直接给出python代码
def build_model():
model = models.Sequential()
model.add(layers.Conv2D(20,(5,5),activation='relu',input_shape=(128,128,3)))
model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2)))
model.add(layers.Conv2D(50,(5,5),activation='relu',padding='same'))
model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2)))
model.add(layers.Flatten())
model.add(layers.Dense(500,activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
loss='binary_crossentropy',
metrics=['acc'])
return model
数据
标注数据是AI模型的原料,数据搜集特别是图片搜集是前端可以介入的一个环节。基于React,开发了一款chrome图片批量下载插件GetThemAll
,方便我们进行标记图片搜集。
插件 地址,安装好插件。然后去谷歌图片搜索“santa”, 使用插件标记不需要的图片,然后下载到本地的santa文件夹,通过谷歌图片可以搜集到400张圣诞老人的图片。
接着我们再下载一些非圣诞老人的图片,搜索“object”,同样的使用GetThemAll插件下载大约400张图片到本地的non_santa文件夹中。
除了训练数据集,我们还需要一个测试数据集用来衡量模型的泛化能力。在本地新建一个test文件夹,把刚刚准备好的训练集里面的最后100张圣诞老人图片移到test文件夹下的santa文件中,同样的,移动100张非圣诞老人图片到non_stanta文件中。
有了标记数据,我们就可以进行模型训练啦。具体的训练过程请见pyton代码,这里直接给出训练的结果,蓝点表示训练数据集准确率,蓝线表示测试数据集准备率,模型有着明显的High Variance问题,不过这个bug留给深度学习的专家们解决吧,这里就假设这个模型可用。
迁移
上一步训练出的模型keras_santa.h5
(h5是文件后缀,和HTML5没啥关系)不能直接给KerasJS使用,需要通过KerasJS提供的转换工具转换后,方可被KerasJS加载解析。
./encoder.py keras_santa.h5
转换后,得到了keras_santa.bin
文件,20M左右,这个文件包含了神经网络结构和所有参数,可以被KerasJS加载。
KerasJS
通过上面的步骤,我们得到了一个训练完成的CNN神经网络以及全部参数,这个网络结构和参数全部保存在keras_santa.bin文件中。接下来,我们只需要在浏览器中复原上面的神经网络,然后就可以开始做预测啦。
使用webpack配合React,搭建一套简单的开发环境。做好了基础工作,就可以开始第一步开发,加载神经网络模型文件keras_santa.bin:
const model = new KerasJS.Model({
filepath: 'http://localhost:3000/keras_santa.bin',
gpu: false
})
//KerasJS提供模型加载进度接口,考虑到模型文件体积非常大,这个接口会经常用到
model.events.on('loadingProgress', (progress) => {
this.setState({
loadingtitle: '模型加载',
progress: parseInt(progress)
})
})
使用上面的模型做预测前,需要将数据转化成模型能够接受的数据格式。这个圣诞老人网络需要数据输入格式为(128,128,3),也即是图片需要为128x128分辨率,只能包含RGB三个分量。
借助canvas,可以实现图片分辨率转换:
_updateImageSrc(imgid) {
const ctx = this.refs.canvas.getContext('2d');
const imgdom = document.createElement('img');
imgdom.src = `http://localhost:3000/${imgid}.jpeg`
this.setState({
prediction:0
})
imgdom.onload = ()=>{
ctx.drawImage(imgdom,0,0,128,128)
const imagedata = ctx.getImageData(0,0,128,128)
const processeddata = ImageDataUtils.preprocess(imagedata)
setTimeout(()=>{
this.doPrediction(processeddata)
},100);
}
}
注意preprocess,通过canvas获取到的图片资源包含了rgba四个维度,prepross返回这4个维度中的前3个维度,也即rgb,同时将数据标准化:
export default class ImageDataUtils {
static preprocess(imageData) {
const {
width,
height,
data
} = imageData;
const dataTensor = ndarray(new Float32Array(data),[width,height,4])
const dataProcessedTensor = ndarray(new Float32Array(width*height*3),[width,height,3])
//从[0,255]转化到[0,1]
ops.divseq(dataTensor,255)
//获取R数据
ops.assign(dataProcessedTensor.pick(null,null,0),dataTensor.pick(null,null,0))
//获取G数据
ops.assign(dataProcessedTensor.pick(null,null,1),dataTensor.pick(null,null,1))
//获取B数据
ops.assign(dataProcessedTensor.pick(null,null,2),dataTensor.pick(null,null,2))
const preprocessedData = dataProcessedTensor.data
return preprocessedData
}
}
最后,使用上面返回的数据做预测
async doPrediction(imagedata) {
if(!this.model) return;
const inputname = this.model.inputLayerNames[0]
const inputdata = {[inputname]: imagedata}
const prediction = await this.model.predict(inputdata)
this.setState({
prediction: prediction.output[0]
})
}
注意
可以看到,KerasJS在预测过程中,整个页面无法响应用户操作。这是因为神经网络计算过程中占用了大量CPU资源,从而致使页面卡顿。下一篇文章中,我们将介绍如何使用WebGL,将计算过程转移到GPU,达到实现前端高性能计算的目的。
相关资源
1.Image classification with Keras and deep learning, Adrain Rosebrock
- GetThemAll, eeandrew
- React Keras,eeandrew