pytorch: 学习笔记5, pytorch实现线性回归2(简单实现)

pytorch实现线性回归简单实现
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.utils.data as Data

# 显示数据
def show(sample, labels):
    print('show')
    plt.scatter(sample, labels, 1)
    plt.show()

# 定义模型 1
class LinearNet(torch.nn.Module):
    def __init__(self, n_feature):
        super(LinearNet, self).__init__()
        self.linear = torch.nn.Linear(n_feature, 1)

    # forward 定义前向传播
    def forward(self, x):
        y = self.linear(x)
        return y


def main():
    # 1, 生成 1000*2 数据
    num_inputs = 2
    num_examples = 1000
    true_w = [2, -3.4]
    true_b = 4.2
    features = torch.randn(num_examples, num_inputs, dtype=torch.float32)
    labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
    labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)
    # show(features[:, 0].numpy(), labels.numpy())
    # show(features[:, 1].numpy(), labels.numpy())

    # 2, 按照batch_size读取数据
    batch_size = 10
    # 将训练数据的特征和标签组合
    dataset = Data.TensorDataset(features, labels)  # PyTorch提供了data包来读取数据
    # 随机读取小批量
    data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
    # 查看一下读取的数据:
    # for X, y in data_iter:
    #     print(X, y)
    #     break

    # 3.1, 定义模型
    net = LinearNet(num_inputs)
    print(net)  # 使用print可以打印出网络的结构
    for param in net.parameters():  # 通过net.parameters()来查看模型所有的可学习参数,此函数将返回一个生成器。
        print(param)
    # # 3.2, 定义模型 method2: Sequential是一个有序的容器,网络层将按照在传入Sequential的顺序依次被添加到计算图中
    # net2 = torch.nn.Sequential(
    #     torch.nn.Linear(num_inputs, 1)
    #     # 此处还可以传入其他层
    #     )
    # # 3.3, 定义模型 method3
    # net3 = torch.nn.Sequential()
    # net3.add_module('linear', torch.nn.Linear(num_inputs, 1))
    # # net3.add_module ......
    # print(net2, '\n', net3)

    # 4。1, 初始化模型参数: 通过init.normal_将权重参数每个元素初始化为随机采样于均值为0、标准差为0.01的正态分布。偏差会初始化为零
    torch.nn.init.normal_(net.linear.weight, mean=0, std=0.01)  # net 的成员变量 linear
    torch.nn.init.constant_(net.linear.bias, val=0)  # 也可以直接修改bias的data: net[0].bias.data.fill_(0)
    # 4.2 初始化模型参数 对应3.2: net[0]这样根据下标访问子模块的写法只有当net是个ModuleList或者Sequential实例时才可以
    # torch.nn.init.normal_(net2[0].weight, mean=0, std=0.01)  # net 的成员变量 linear
    # torch.nn.init.constant_(net2[0].bias, val=0)  # 也可以直接修改bias的data: net[0].bias.data.fill_(0)

    # 5, 定义损失函数
    loss = torch.nn.MSELoss()  # 均方误差损失

    # 6, 定义优化算法
    optimizer = torch.optim.SGD(net.parameters(), lr=0.03)
    print(optimizer)
    # 调整学习率
    for param_group in optimizer.param_groups:
        param_group['lr'] *= 0.1  # 学习率为之前的0.1倍

    # 7, 训练模型: 通过调用optim实例的step函数来迭代模型参数
    num_epochs = 10
    for epoch in range(1, num_epochs + 1):
        for X, y in data_iter:
            output = net(X)
            l = loss(output, y.view(-1, 1))
            optimizer.zero_grad()  # 梯度清零,等价于net.zero_grad()
            l.backward()
            optimizer.step()
        print('epoch %d, loss: %f' % (epoch, l.item()))

    # 8, 比较学到的模型参数和真实的模型参数
    print(true_w, '\n', net.linear.weight)  # net.linear.weight: 第0层的权重
    print(true_b, '\n', net.linear.bias)  # net.linear.bias: 第0层的偏差


if __name__ == '__main__':
    main()
结果:

LinearNet(
(linear): Linear(in_features=2, out_features=1, bias=True)
)
Parameter containing:
tensor([[-0.2861, 0.2106]], requires_grad=True)
Parameter containing:
tensor([-0.5740], requires_grad=True)
SGD (
Parameter Group 0
dampening: 0
lr: 0.03
momentum: 0
nesterov: False
weight_decay: 0
)
epoch 1, loss: 12.505808
epoch 2, loss: 2.363238
epoch 3, loss: 0.923611
epoch 4, loss: 0.326501
epoch 5, loss: 0.143060
epoch 6, loss: 0.025092
epoch 7, loss: 0.015814
epoch 8, loss: 0.003047
epoch 9, loss: 0.000669
epoch 10, loss: 0.000169
[2, -3.4]
Parameter containing:
tensor([[ 1.9910, -3.3912]], requires_grad=True)
4.2
Parameter containing:
tensor([4.1899], requires_grad=True)

Process finished with exit code 0

参考学习,把学习中的知识整合,并非自己实现。
参考:https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.3_linear-regression-pytorch

你可能感兴趣的:(pytorch,神经网络,python)