pytorch学习2

分类问题

手写数字数据集

pytorch学习2_第1张图片

其中,每个数字图片大小是28 x 28,矩阵中每个元素的大小为[0,1]区间的灰度值,将二维矩阵拉平(flat)为一维784,数据量不变,这样能忽略上下位置相关性,甚至左右位置相关性也可忽略,再插入一个维度变为[1,784]
pytorch学习2_第2张图片

线性模型能解决吗

一个简单的线性模型为:y = w * x + b
但对于手写数字来说,用一个简单的线性模型,是不可能解决问题的。
故用以上三个线性函数进行嵌套

pytorch学习2_第3张图片
其中 d1 = 784,d3 = 10,中间矩阵转置、相乘、相加过程暂时抽象理解一下。
pytorch学习2_第4张图片

H3作为最后一个输出,要如何计算loss。
最后的Label是0~9,可以让H3的第一维度数字1表示照片数量,第二个1表示是数字“1”。
pytorch学习2_第5张图片
使用one-hot编码,避免数字编码具有大小关系。
pytorch学习2_第6张图片
若H3为[0.1 0.8 0.01 … 0],它与“1”的欧式距离计算如上图。

小结:在这里插入图片描述
H1作为H2的输入,H2作为H3的输入
pred采用十维向量表示,与真实编码数字向量作欧式距离计算,优化这个计算,理论上便能找到最优解。
pytorch学习2_第7张图片

非线性模型

即使通过嵌套线性模型增强了表达能力,但整体模型仍为线性。人脑之所以能很简单地识别出数字样式,是因为人脑有很强的非线性表达能力,对于线性模型来说,很难完成这样的任务。

解决:在每个函数之后添加非线性部分
类似于生物学上的神经元,输出不是多个输入的求和,而是存在阈值,控制输出结果,如relu。pytorch学习2_第8张图片
pytorch学习2_第9张图片

梯度下降解决

在这里插入图片描述
找到一组w,b参数,对于一个新的x,使得其在pred上的映射无线接近于真实值y。

w,b在这里由三组参数构成
pytorch学习2_第10张图片
pytorch学习2_第11张图片
给到一个新的x,在经过三组w,b的线性模型和激活函数的计算后,得到的pred结果是[1,10]的矩阵,其中值的大小表示所在位置索引数字的判断概率大小,通过argmax()函数实现,最终结果输出的是最大0.8概率对应的索引数字“1”。

你可能感兴趣的:(pytorch,学习,人工智能)