tensorflow.js基础入门

前提

链接: tensorflow.js官网
官网改版以后文档已经比较完善,有兴趣的同学一起来学习吧.
链接: 超简单视频教程
本文所有代码都是从视频里学的,复制下来保存成html文件,就可以训练自己的数字识别模型了.

基础概念的个人理解

机器学习: 例如:通常写一段程序去界定考试成绩是否及格,if (score>=60) 及格 else if (score<60) 不及格.
机器学习是这样实现功能的:找一堆成绩和对应标签的数据集合
60-及格,63-及格,64-及格,72-及格…
54-不及格,44-不及格,59-不及格…
很多很多数据输入到模型中去,机器自己慢慢就会在及格线上划出一条边界,以后即便是遇到没见过的数据,它也能界定出这个数据到底属于及格还是不及格.
这么费事就只实现了一个if else的功能?数据量小并且数据结构简单时if else没问题,但逻辑判断如果遇到没见过的数据,程序就会报错或崩溃,类似图片或声音这些复杂数据类型就没办法去界定了.

tensor: 可以理解成N维数组.
使用面向对象思想写程序,就是要把事物抽象成具有属性和方法的对象,tensor就是把事物抽象成数学模型能够处理的数据类型.在tensorflow里面就是多维数组,几维数组就是几阶张量.
flow: tensor类型的数据在模型里面正向反向来回流动,模型类似人类的神经网络,一层一层的传递.模型也可以简单看成一个带有参数的函数,使用训练数据训练时,模型就会调整自身参数逐渐找到60分及格线.

代码

<!DOCTYPE html>
<html lang="en">
<head>
	<meta charset="UTF-8">
	<title>tfjs</title>
	<script src="https://cdn.bootcss.com/vue/2.6.10/vue.min.js"></script>
	<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>
</head>
<body>
<div id="app" @mouseup="onMouseup">

	<canvas 
		ref="draw" width="200px" height="200px" 
		style="border-style: dashed;" @mousedown="onMousedown" 
		@mousemove="onMousemove"  
		>
	</canvas><br>
	<button @click="clearCanvas">清空</button><br><br>
	<canvas 
		width="28px" height="28px" 
		style="border-style: solid;border-color: red; background-color: black;"
		ref="preview"
	>
	</canvas><br><br>

	关联数字:<input type="text" v-model="targetNum">
	<button @click="training">训练</button>

	<h3>识别</h3>
	<button @click="onRegNum">预测</button>
	<p>{{reg}}</p>

</div>

<script>
window.onload = ()=> {
	new Vue({
		el: '#app',
		data: {
			reg: 'xxx',
			drawing: false,
			draw: null,
			preview: null,
			model: null,
			targetNum: 0,
		},
		mounted() {
			let c2d = this.draw = this.$refs.draw.getContext('2d')
			this.preview = this.$refs.preview.getContext('2d')
			c2d.lineWidth = 20
			c2d.lineCap = 'round'
			c2d.lineJoin = 'round'
			this.model = tf.sequential({
				layers: [
					tf.layers.inputLayer({
						inputShape: [784]
					}),
					tf.layers.dense({units: 10}),  //输出空间,10个数字
					tf.layers.softmax()    //输出空间所有值之和为1
				]
			})
			this.model.compile({
			  optimizer: 'sgd',    //优化器
			  loss: 'categoricalCrossentropy', //损失函数
			  metrics: ['accuracy']  //logs里的acc
			})
		},
		methods: {
			getImageData() {
				let image = this.preview.getImageData(0,0,28,28)
				let data = image.data
				let pixelData = []
				let color
				for (let i = 0; i < data.length; i+=4) {
					color = (data[i]+data[i+1]+data[i+2])/3
					//空白的地方保持为0,有颜色的地方才有值,特征才明显,训练出来的效果更好,转成只有0和1,效果最好
					pixelData.push(Math.round((255-color)/255))
				}
				//长度转成784,单通道图片
				return pixelData
			},
			onMousedown(e) {
				this.drawing = true
				this.draw.beginPath()
				this.draw.moveTo(e.offsetX,e.offsetY)
			},
			onMousemove(e) {
				if (this.drawing) {
					this.draw.lineTo(e.offsetX,e.offsetY)
					this.draw.stroke()
				}
			},
			onMouseup(e) {
				this.drawing = false
				//this.preview.clearRect(0,0,28,28)
				this.preview.fillStyle = 'white'  //实际是透明的,必须填充白色
				this.preview.fillRect(0,0,28,28)
				this.preview.drawImage(this.$refs.draw,0,0,28,28)
			},
			clearCanvas() {
				this.draw.clearRect(0,0,200,200)
			},
			async training() {
				let data = this.getImageData()
				//tf.tensor(data).print()
				//[1,0,0,0,0,0,0,0,0,0]  代表1
				//[0,1,0,0,0,0,0,0,0,0]  代表2...
				//生成上面的一阶张量
				let targetTensor = tf.oneHot(parseInt(this.targetNum),10)
				console.log('start train')
				await this.model.fit(tf.tensor([data]),tf.tensor([targetTensor.arraySync()]),{
					epochs: 30, //训练次数
					callbacks: { //每次的回调
						onEpochEnd(epoch,logs) {
							console.log(epoch,logs)
						}
					}
				})
				console.log('end train')
			},
			async onRegNum() {
				let data = this.getImageData()
				let predictions = this.model.predict(tf.tensor([data])) //结果也是tensor
				//获取tensor第一层里的最大值的index,正好就是数字本身
				let result = predictions.argMax(1).arraySync()[0]
				this.reg = result
			}
		}

	})
}
</script>
</body>
</html>

照着视频里敲出来的代码,此代码没有页面样式,看着丑陋一些,复制所有代码保存成html格式,用浏览器打开,再打开控制台就可以开始训练模型了,页面一刷新,模型就要重新开始训练.
电脑有独显的话会优先使用gpu去计算,但是比较老的独显好像支持的并不好,如果看到控制台里的loss一直不变,可能就是这个原因.需要设置浏览器禁用gpu加速.
tensorflow.js基础入门_第1张图片
图片数据其实就是一个个像素点组成的张量,正常情况就会如图所示,虚线框内画一个数字,关联数字里面填入你写的数字,点击训练控制台就会输出损失值,值应该越来越小.
0到9每个数字手写10遍左右训练,基本上就能得到一个还算可以的模型.
再手写几个数字,使用预测按钮看看值是不是对的.

结论

这是机器学习的hello world级代码,其他主要功能如官网模型部分的介绍
tensorflow.js基础入门_第2张图片
入门的话对模型不用过多关注,主要是tensor类型的数据处理,以及如何突出数据的特征.
以上只是粗浅的理解,需要配合官方文档结合代码,如果能够把代码理解透彻,基本上就算是入门啦.
理解有误的地方还请多多指正.
机器学习连入门都很难,普通程序员也不怎么会用到,至少了解一下机器到底是怎么学习的,毕竟以后的世界可能任何行业都脱离不开机器学习.

你可能感兴趣的:(tensorflow)