整合DL4J训练模型与Web工程

一、前言

    上一篇博客《有趣的卷积神经网络》介绍如何基于deeplearning4j对手写数字识别进行训练,对于整个训练集只训练了一次,正确率是0.9897,随着迭代次数的增加,网络模型将更加逼近训练集,下面是对训练集迭代十次的评估结果,总之迭代次数的增加会更加逼近模型(注:增加迭代次数有时也会发生过拟合,有时候也并非很奏效,具体情况具体分析)。

 Accuracy:        0.9919
 Precision:       0.9919
 Recall:          0.9918
 F1 Score:        0.9918

二、导读

    1、web环境搭建

    2、基于canvas构建前端画图界面

    3、整合dl4j训练模型

三、web环境搭建

    1、eclipse  new一个Maven project ,填好maven坐标,packaging选war

org.dl4j
digitalrecognition
0.0.1-SNAPSHOT
war

    2、配置Jar包依赖,由于servlet-api一般由web容器提供,所以scope为provided,这样不会被打入war包里。


		
			org.springframework
			spring-webmvc
			4.3.4.RELEASE
		
		
			javax.servlet
			servlet-api
			2.5
			provided
		
		
			com.fasterxml.jackson.core
			jackson-core
			2.5.3
		

		
			com.fasterxml.jackson.core
			jackson-annotations
			2.5.3
		

		
			com.fasterxml.jackson.core
			jackson-databind
			2.5.3
		
		
			commons-fileupload
			commons-fileupload
			1.3.1
		
		
			org.deeplearning4j
			deeplearning4j-core
			0.9.1
		
		
			org.nd4j
			nd4j-native-platform
			0.9.1
		
	

    3、为了开发方便,不用把web工程部署到外置web容器,所以在开发时用mavan tomcat插件是比较方便的。运行时mvn tomcat7:run即可


		
			
				org.apache.tomcat.maven
				tomcat7-maven-plugin
				2.2
				
					UTF-8
					/
					8080
					org.apache.coyote.http11.Http11NioProtocol
					1000
					100
				
			
		
	

    4、web常规配置web.xml,filter、servlet、listener这里就略去了。

四、前端canvas画图实现

    1、html元素、css




数字识别


	
	
	
识别结果:

    2、js代码实现在canvas画布连线操作,并将图片转化为base64格式,ajax发送给后端,这里画布的大小是280px,所以图片到了后端,需要缩小至十分之一。


    整体呈现的界面如下,可以画图。

整合DL4J训练模型与Web工程_第1张图片

五、后端java代码

@RequestMapping("/digitalRecognition")
@Controller
public class DigitalRecognitionController implements InitializingBean {
	private MultiLayerNetwork net;

	@ResponseBody
	@RequestMapping("/predict")
	public int predict(@RequestParam(value = "img") String img) throws Exception {
		String imagePath= generateImage(img);//将base64图片转化为png图片
		imagePath= zoomImage(imagePath);//将图片缩小至28*28
		DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
		ImageRecordReader testRR = new ImageRecordReader(28, 28, 1);
		File testData = new File(imagePath);
		FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS);
		testRR.initialize(testSplit);
		DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, 1);
		testIter.setPreProcessor(scaler);
		INDArray array = testIter.next().getFeatureMatrix();
		return net.predict(array)[0];
	}

	private String generateImage(String img) {
		BASE64Decoder decoder = new BASE64Decoder();
		String filePath = WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png";
		try {
			byte[] b = decoder.decodeBuffer(img);
			for (int i = 0; i < b.length; ++i) {
				if (b[i] < 0) {
					b[i] += 256;
				}
			}
			OutputStream out = new FileOutputStream(filePath);
			out.write(b);
			out.flush();
			out.close();
		} catch (Exception e) {
			e.printStackTrace();
		}
		return filePath;
	}
	
	private String zoomImage(String filePath){
		String imagePath=WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png";
		try {
			BufferedImage bufferedImage = ImageIO.read(new File(filePath));
			Image image = bufferedImage.getScaledInstance(28, 28, Image.SCALE_SMOOTH);
			BufferedImage tag = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
			Graphics g = tag.getGraphics();
			g.drawImage(image, 0, 0, null); // 绘制处理后的图
			g.dispose();
			ImageIO.write(tag, "png",new File(imagePath));
		} catch (Exception e) {
			e.printStackTrace();
		}
		return imagePath;
	}
	

	@Override
	public void afterPropertiesSet() throws Exception {
		net = ModelSerializer.restoreMultiLayerNetwork(new File(WebConstant.WEB_ROOT + "model/minist-model.zip"));
	}

}

    代码说明:

    1、InitializingBean是spring bean生命周期中的一个环节,spring构建bean的过程中会执行afterPropertiesSet方法,这里用这个方法来加载已经定型的网络。

      2、generateImage是用来将前端传过来的base64串转化为png格式。

      3、zoomImage方法将前端的280*280缩小至28*28和训练数据一致,并存到webroot的upload目录下。

     4、predict进行预测,将转化好的28*28的图片读取出来,张量化,把像素点的值压缩至0到1,预测,最后结果是一个数组,由于只有一张图片,取数组的第一个元素即可。

六、测试,mvn tomcat7:run,浏览器访问http://localhost:8080即可玩手写数字识别了

    整合DL4J训练模型与Web工程_第2张图片

           整合DL4J训练模型与Web工程_第3张图片

    测试结果马马虎虎,大体上实现了基本功能。

    git地址:https://gitee.com/lxkm/dl4j-demo/tree/master/digitalrecognition

    快乐源于分享。

 

 

 

 

 

 

你可能感兴趣的:(deeplearning4j)