用pytorch建立一个简单的线性回归模型

本文通过pytorch建立一个简单的线性回归模型, 主要用于理解pytorch的基本训练流程.

如下面的线性回归问题:

import numpy as np 
import matplotlib.pyplot as plt 

X_train = np.arange(10, dtype='float32').reshape((10, 1))
y_train = np.array([1.0, 1.5, 2.8, 4, 3, 5.2, 7.2, 6.9, 8.5, 9.5], dtype='float32')

plt.plot(X_train, y_train, 'o', markersize=8)
plt.xlabel('x')
plt.ylabel('y')
plt.show()

用pytorch建立一个简单的线性回归模型_第1张图片 

首先, 需要把数据变成 pytorch可读取的方式: 

from torch.utils.data import TensorDataset, DataLoader

#标准化
X_train_norm = (X_train- np.mean(X_train)) / np.std(X_train)
X_train_norm = torch.from_numpy(X_train_norm)
y_train = torch.from_numpy(y_train)
train_ds = TensorDataset(X_train_norm, y_train)
batch_size = 1 
train_dl= DataLoader(train_ds, batch_size, shuffle=True)

接着需要定义 loss function, model, 以及 optimizer

import torch.nn as nn 
import torch 

loss_fn = nn.MSELoss(reduction='mean')
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

loss function 用的是torch.nn下面的 mean squar error

model 用的线性模型: inputsize outputsize 为1

optimizer 用的SGD, learning rate=0.001

最后训练模型

num_epochs = 200
for epoch in range(num_epochs):
    for x_batch, y_batch in train_dl:
        #1.产生 输出
        pred = model(x_batch)[:, 0]
        #2. 计算 损失
        loss = loss_fn(pred, y_batch)
        #3. 计算梯度
        loss.backward()
        #4. 更新参数
        optimizer.step()
        #5.梯度reset to 0 
        optimizer.zero_grad()

 来看看训练得到的参数:

print(model.weight.item(), model.bias.item())

2.690967321395874 4.879624843597412

可视化:

X_test = np.linspace(0, 9, num=100, dtype='float32').reshape(-1, 1)
X_test_norm = (X_test-np.mean(X_train))/np.std(X_train)
X_test_norm = torch.from_numpy(X_test_norm) 
y_pred = model(X_test_norm).detach().numpy()
fig = plt.figure(figsize=(12, 5))
ax = fig.add_subplot(1, 1, 1)
plt.plot(X_train_norm, y_train, 'o', markersize= 8)
plt.plot(X_test_norm, y_pred, '--', lw=2)
plt.legend(['Train examples', 'Linear regress'], fontsize=12)
ax.set_xlabel('x', size=12)
ax.set_ylabel('y', size=12)
plt.show()

用pytorch建立一个简单的线性回归模型_第2张图片

 参考自:  Machine Learning with PyTorch and Scikit-Learn Book  by Sebastian Raschka

你可能感兴趣的:(pytorch,线性回归,深度学习)