使用PyTorch实现线性回归算法

使用torch已有的框架进行搭建

  1. import
import numpy as np
import torch
from torch.utils import data
  1. 人造数据
    使用线性模型参数 w = [2, -3.4]T、b = 4.2和噪声项noise生成数据集及其表现:
    y = Wx + b + noise
def synthetic_data(w, b, num_examples):
    """生成 y = wx + b + noise"""
    X = torch.normal(0,1, (num_examples, len(w))) # 第三个是out的shape
    y = torch.matmul(X, w) + b   ## matmul 比起mm或者mv,会自己判断是m还是v
    y += torch.normal(0, 0.01, y.shape) ## 制造噪声
    return X, y.reshape((-1,1))


true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
  1. 调用框架中现有的API来读取数据
def load_array(data_arrays, batch_size, is_train=True):
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

next(iter(data_iter))

[tensor([[ 1.4412, 0.0381],
[-0.5712, -1.4606],
[ 0.9690, -0.1091],
[-0.1664, -1.0892],
[-0.6292, -0.6229],
[-0.9086, -1.1979],
[ 0.2717, -0.7527],
[-1.1563, 1.5425],
[-0.8354, -0.2888],
[ 0.4858, -0.6182]]), tensor([[ 6.9568],
[ 8.0240],
[ 6.5040],
[ 7.5922],
[ 5.0577],
[ 6.4421],
[ 7.3112],
[-3.3624],
[ 3.5135],
[ 7.2626]])]

  1. 使用框架预定义好的层构造训练算法
from torch import nn ## neural network

net = nn.Sequential(nn.Linear(2,1))  
## 定义框架,nn.Sequential是整体框架,里面放了一个Linear的层
## nn.Linear(in,out), in:输入的dimension,out:输出的dimension
  1. 初始化模型参数
net[0].weight.data.normal_(0, 0.01) ## net[0] 代表第一层
net[0].bias.data.fill_(0)
  1. 定义损失函数
# 计算均方误差使用的是 MSELoss类,也即为平方L2范数

loss = nn.MSELoss()
  1. 定义优化算法
# 实例化SGD实例

trainer = torch.optim.SGD(net.parameters(), lr=0.03)
  1. 数据训练
# 训练代码

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        trainer.zero_grad()  #每次将grad清零
        l.backward()
        trainer.step()  # Performs a single optimization step.
    l = loss(net(features), labels)
    print(f'epoch {epoch +1}, loss {l:f}')

epoch 1, loss 0.000339
epoch 2, loss 0.000112
epoch 3, loss 0.000111

你可能感兴趣的:(pytorch,算法,线性回归)