1、网络构建
2、网络训练
3、网络结构和参数的保存
4、保存文件的重新导入
import torch
import matplotlib.pyplot as plt
import numpy as np
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + x.pow(5) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
n_input = 6
n_hidden = 12
n_output = 1
# 构建网络架构
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 优化器
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
# 网络训练
loss_reply = np.empty(50)
for t in range(50):
prediction = net1(x) # 前向传播
loss = loss_func(prediction, y) # 计算误差
loss_reply[t] =loss.data.numpy()
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传递
optimizer.step() # 参数更新优化
# 保存网络结构及其参数 单独保存
torch.save(net1, './model/net.pkl') # save entire net
torch.save(net1.state_dict(), './model/net_params.pkl') # save only the parameters
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.subplot(132)
plt.title('Loss')
plt.plot(loss_reply, 'k-', lw=0.5)
####################################################################################
# 模型导入
net2 = torch.load('./model/net.pkl')
# 参数导入
net2.load_state_dict(torch.load('./model/net_params.pkl'))
# 继续训练
loss_reply2 = np.empty(50)
for t in range(50):
prediction2 = net2(x) # 前向传播
loss = loss_func(prediction2, y) # 计算误差
loss_reply2[t] =loss.data.numpy()
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传递
optimizer.step() # 参数更新优化
plt.subplot(133)
plt.title('reload-Loss')
plt.plot(loss_reply2, 'k-', lw=0.5)
plt.show()