pytorch从入门到精通系列之线性回归

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

之前的文章有提到,pytorch的入门让人无比愉悦的。但毕竟是深度学习,我说的门槛是和tensorflow比,本身深度学习对于微积分,线性代数和概率论还是有一定要求的。当然,还是那句话,其实深度学习涉及的数学原理,比什么SVM,CRF要简单太多,而且深度学习的变化和应用场景比前者要多得多,所以学习深度学习是很有意义的。

pytorch从入门到精通系列之线性回归_第1张图片

很多官方或非官方的demo,都是长篇大论,因为要演示一个例子,需要准备数据,其实大部分的代码是在准备数据,预处理,与深度学习本身没有太大关系。这样看起来不直观,把初学者都给吓住了。所以,打算写一个入门系列,就是手动造一些简单的数据,就是为了演示用。

 

本文要实现线性函数的拟合,就是比如y = 2*x+3这样的线性函数。首先标准的步骤,导入torch和Variable。

import torch
from torch.autograd import Variable

然后准备简单的数据,x=1,2,3, y = 2,4,6, 线性关系就是y=2*x。注意这里的tensor是3行一列,然后转为Variable。

x_data = Variable(torch.Tensor([[1.0],[2.0],[3.0]]))
y_data = Variable(torch.Tensor([[2.0],[4.0],[6.0]]))

然后就可以定义网络了,所有的模型都需要继承自nn.Module。然后定义一个线性层,参数为1,1。实现forward函数,用作训练迭代。

import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear = nn.Linear(1,1)

    def forward(self,x):
        y_pred = self.linear(x)
        return y_pred

直接实例化模型,并定义损失函数,这里按我们高数里学的,使用最小二乘法。然后使用随机梯度下降对所有可优化参数进行优化,lr是learning_rate学习速率,也就是梯度变化步长。

model = Model()
criterion = nn.MSELoss(size_average=False)
params = model.parameters()
optimizer = torch.optim.SGD(params=params,lr=0.01)

开发循环迭代500次,用x_data通过model得到y_pred,也是一个3x1的矩阵。然后通过criterion求loss值。

for epoch in range(500):
    y_pred = model(x_data)#3x1
    loss = criterion(y_pred,y_data)
    #这里可以把loss打印出来看,值都是一样的,因为我们没有优化参数。使用data[0]是因为loss.data是一个list
    print(epoch,loss.data[0])

    #然后是三步标准的,导数归0,反向求导,反向传播  
optimizer.zero_grad()    
    loss.backward()
    optimizer.step()
这里值得注意就是每个epoch,梯度都要归零。
print('训练完成!')
x_test = Variable(torch.Tensor([[4.0]]))
y_test = model(x_test)
#y_test是一个Variable,包括.data是一个FloatTensor,size=1x1
print('x_test:%d,预测值为%f'%(x_test.data[0][0],y_test.data[0][0]))

最后使用训练好的模型去测试,看下结果:

训练完成!

x_test:4,预测值为7.975748

 

关于作者:魏佳斌,互联网产品/技术总监,北京大学光华管理学院(MBA),特许金融分析师(CFA),资深产品经理/码农。偏爱python,深度关注互联网趋势,人工智能,AI金融量化。致力于使用最前沿的认知技术去理解这个复杂的世界。

扫描下方二维码,关注:AI量化实验室(ailabx),了解AI量化最前沿技术、资讯。

pytorch从入门到精通系列之线性回归_第2张图片

转载于:https://my.oschina.net/u/1996852/blog/1576226

你可能感兴趣的:(数据结构与算法,python,人工智能)