Pytorch学习日记:线性回归

线性回归也叫regression,它是一个比较简单的模拟线性方程式的模型。线性方程式我们应该都学过,类似这样:
Y = w X + b Y=wX+b Y=wX+b
其中w是系数,b是位移,它是一条笔直的斜线。
Pytorch学习日记:线性回归_第1张图片
那么我们假设给定一条模拟直线的点,每个点偏移这条直线很小的范围,我们要用到随机函数来模拟这个随机的偏移。
首先可以定义一个随机种子,随机种子基本不影响随机数的值,也可以不定义随机种子。随机数值再0~1之间。

例如:torch.manual_seed(1)
设置随机种子为1.

size = 10
0.2*torch.rand(size)

这里我们不打算使用pytorch的随机函数,毕竟numpy中已经提供了随机函数,下面的数据是生成200个X和Y,模拟参数w为0.5.

代码:

import numpy as np
from numpy import random
import matplotlib.pyplot as plt
import torch
X = np.linspace(-1,1,200)  #在指定的间隔内返回均匀间隔的数字
Y = 0.5 * X + 0.2 * np.random.normal(0,0.05,(200,))  #np.random.normal:产生0前后0.05之间正态分布
# 的200个随机值的array

plt.scatter(X,Y)
plt.show()
#将X, Y转成200 batch大小,1维度的数据
X = torch.Variable(torch.Tensor(X.reshape((200,1))))
Y = torch.Variable(torch.Tensor(Y.reshape((200,1))))

图形:

Pytorch学习日记:线性回归_第2张图片
注意:这里将输入数据转换成(batch_size,dim)格式的数据,添加了一个批次的维度。
现在的任务是给定这些散列点(x,y)对,模拟出这条直线来。这是一个简单的线性模型,我们先用一个简单的1->1的Linear层试试看。

示例代码:

import torch
import numpy as np
import matplotlib.pyplot as plt
from numpy import random

#plot code
X = np.linspace(-1,1,200) # 从-1~1中等间隔输出200个值
Y = 0.5 * X + 0.2 * np.random.normal(0,0.05,(200,))
# plt.scatter(X,Y)
# plt.show()
#将X,Y转成200 batch大小,1维度的数据
X = torch.tensor(torch.Tensor(X.reshape(200,1)),requires_grad=True)
Y = torch.tensor(torch.Tensor(Y.reshape(200,1)),requires_grad=True)
print("Variable(X):",X)


model = torch.nn.Sequential(torch.nn.Linear(1,1),)  #
optimizer = torch.optim.SGD(model.parameters(),lr=0.5)
loss_function = torch.nn.MSELoss()

# training_code
for i in range(300):
    prediction = model(X)
    loss = loss_function(prediction, Y)
    optimizer.zero_grad()  # initialize optimizer
    loss.backward()
    optimizer.step() # step:实现优化器,用来更新优化器的参数

#plot code
plt.figure(1,figsize=(10,3))
plt.subplot(131)
plt.title('model')
plt.scatter(X.data.numpy(), Y.data.numpy())
print("X.data.numpy():",X.data.numpy())
plt.plot(X.data.numpy(),prediction.data.numpy(),'r-',lw = 5)
plt.show()

图形:
Pytorch学习日记:线性回归_第3张图片

本文是作者在学习书籍《Pytorch深度学习实战》一书时做的笔记,才疏学浅,如有错误还望指正。

你可能感兴趣的:(python,深度学习,机器学习,人工智能,numpy)