pytorch基本操作与基本流程

1、网络构建

2、网络训练

3、网络结构和参数的保存

4、保存文件的重新导入

import os

import matplotlib.pyplot as plt
import numpy as np
import torch

# Synthetic regression data: 100 evenly spaced points on [-1, 1].
x = torch.linspace(-1, 1, 100).unsqueeze(1)           # shape (100, 1)
# Noisy targets: x^2 + x^5 plus uniform noise in [0, 0.2); shape (100, 1).
y = x.pow(2) + x.pow(5) + 0.2 * torch.rand(x.size())


# Network dimensions. Fix: the original declared n_input=6 / n_hidden=12 but
# never used them, and they contradicted the actual architecture (1 -> 10 -> 1,
# matching the (100, 1) data). Align the constants and use them in the net.
n_input = 1
n_hidden = 10
n_output = 1

# Build the network: Linear -> ReLU -> Linear regression head.
net1 = torch.nn.Sequential(
        torch.nn.Linear(n_input, n_hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(n_hidden, n_output)
    )
# SGD over net1's parameters; MSE loss for regression.
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()

# Train for 50 steps, recording the scalar loss at each step.
loss_reply = np.empty(50)
for t in range(50):
    prediction = net1(x)               # forward pass
    loss = loss_func(prediction, y)    # compute MSE error
    loss_reply[t] = loss.item()        # .item() replaces legacy loss.data.numpy()
    optimizer.zero_grad()              # clear accumulated gradients
    loss.backward()                    # backpropagate
    optimizer.step()                   # update parameters

# Persist the trained model in two ways: the whole net, and the parameters alone.
# Bug fix: torch.save raises FileNotFoundError if the target directory is missing.
os.makedirs('./model', exist_ok=True)
torch.save(net1, './model/net.pkl')                       # save entire net (structure + weights)
torch.save(net1.state_dict(), './model/net_params.pkl')   # save only the parameters

# Visualize training: data vs. fitted curve (left) and loss history (middle).
plt.figure(1, figsize=(10, 3))

plt.subplot(1, 3, 1)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.detach().numpy(), 'r-', lw=5)

plt.subplot(1, 3, 2)
plt.title('Loss')
plt.plot(loss_reply, 'k-', lw=0.5)



####################################################################################
# Reload the saved model.
# NOTE(review): torch.load unpickles the file — only load files from trusted sources.
net2 = torch.load('./model/net.pkl')
# Restoring the parameters alone also works (shown for demonstration; redundant
# right after loading the entire net above).
net2.load_state_dict(torch.load('./model/net_params.pkl'))

# Continue training the reloaded net.
# Bug fix: the original reused `optimizer`, which holds net1's parameters, while
# backward() put gradients on net2's parameters — so step() applied zeroed
# gradients to net1 and net2 was never updated. Build a fresh optimizer on net2.
optimizer2 = torch.optim.SGD(net2.parameters(), lr=0.5)

loss_reply2 = np.empty(50)
for t in range(50):
    prediction2 = net2(x)               # forward pass
    loss = loss_func(prediction2, y)    # compute MSE error
    loss_reply2[t] = loss.item()        # .item() replaces legacy loss.data.numpy()
    optimizer2.zero_grad()              # clear gradients on net2's parameters
    loss.backward()                     # backpropagate
    optimizer2.step()                   # update net2's parameters

# Right panel: loss history of the reloaded network; then render the figure.
plt.subplot(1, 3, 3)
plt.title('reload-Loss')
plt.plot(loss_reply2, 'k-', lw=0.5)
plt.show()

 

Tags: pytorch, python