PyTorch实现线性回归

概念

线性回归是分析一个变量与另外一(多)个变量之间关系的方法

  • 因变量:y
  • 自变量:x
  • 关系:线性
  • 表达式:y = wx + b
  • 目的:求解w和b

求解步骤:

  1. 确定模型
    Model:y = wx + b
  2. 选择损失函数
    均方差MSE: 1 m ∑ i = 1 m ( y i − y i ^ ) 2 \frac{1}{m}\sum_{i=1}^{m}(y_i - \hat{y_i})^2 m1i=1m(yiyi^)2
  3. 求解梯度并更新w,b
    梯度下降法:
    w = w – LR * w.grad
    b = b – LR * w.grad
    LR为步长,学习率
import torch
import matplotlib.pyplot as plt

torch.manual_seed(10)  # 初始化随机数种子,保证结果可以复现
lr = 0.1  # 学习率
# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1))  # torch.randn(20, 1)加入噪声
# 初始化w和b
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 开始迭代
for i in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pre = torch.add(wx, b)  # 预测值
    # 计算损失
    loss = (0.5 * (y - y_pre) ** 2).mean()  # 乘以0.5是为了求导过程中消除平方2的影响,mean()求均值
    # 反向传播
    loss.backward()  # 自动求导,得到梯度
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    # 绘图
    if loss.data.numpy() < 1:
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), y_pre.data.numpy(), "r-", lw=5)
        plt.text(2, 10, "loss=%.4f" % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.xlim(1.5, 10)
        plt.ylim(8, 28)
        plt.title("i:{}  w:{}  b:{}".format(i, w.data.numpy(), b.data.numpy()))
        plt.pause(0.5)
        break

PyTorch实现线性回归_第1张图片

你可能感兴趣的:(PyTorch)