首先是包导入、数据生成:
这里做的是线性的拟合,所以准备了一些数据,并定义了一个线性函数。
# 导入文件
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# make data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 进行解压,
y = x.pow(2) + 0.2 * torch.rand(x.size()) # 添加噪声
print (x)
print (x.size())
print (y)
网络的定义:
这里是定义网络,只有一个隐藏层,和输出层,使用 relu 作为激活函数。在__init__中定义了网络的基本元素,在函数fowward中定义了前向计算过程,也就是网络的结构。
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
# 前项传播过程
def forward(self, x):
hidden_x = self.hidden(x)
hidde_x_out = F.relu(hidden_x) # 使用relu 进行激活
out = self.predict(hidde_x_out)
return out
net = Net(n_feature=1, n_hidden=10, n_output=1) # 网络初始化
print (net)
定义优化器、损失函数并进行训练(拟合):
optim = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
plt.ion()
for step in range(1000):
prediction = net(x)
loss = loss_func(prediction, y)
optim.zero_grad() # 清空梯度
loss.backward() # 计算梯度
optim.step() # 应用梯度,并更新参数
if step % 10 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), "r-", lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
torch 中的线性结构可以参考这篇文章:https://blog.csdn.net/u012936765/article/details/52671156
上面介绍了如何定义并使用简单的神经网络,这里给出另外一种更加简单的定义神经网络的方式。可以进行对比:
首先是自定义的神经网络:
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
net1 = Net(1, 10, 1)
然后,重点是下面这种方式:使用提供的简单的方式进行定义
net2 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
两种网络的作用是相同的,只不过定义的方式不同,第二种更加简便些。一般考虑使用第二种的定义方式。
上面介绍了如何简单定义连续序列,并且给出了如何进行训练一个网络。当我们的模型训练完毕之后,如果需要后面进行使用,就涉及到报错和重载,这里进行简单介绍:
模型的保存主要使用了:torch.save() 函数
模型的载入主要使用了:torch.load() 函数
保存模型的时候可以保存整个模型(网络结构,和参数),同样也可只保存参数,这样就能见效模型体积。
只保存参数:torch.save(net1.state_dict(), "net1_params.pkl") # 主要保存的是状态字典
保存模型和参数: torch.save(net1, "net1.pkl") # 这样保存体积比较大
示例代码:(如果已经理解可直接跳过此部分代码看下一小节)
# save or reload model
import torch
import matplotlib.pyplot as plt
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# save
def save():
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optim = torch.optim.SGD(net1.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for step in range(200):
predict = net1(x)
loss = loss_func(predict, y)
optim.zero_grad()
loss.backward()
optim.step()
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title("Net1")
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), predict.data.numpy(), c="red", lw=2)
# svae
torch.save(net1, "net.pkl")
torch.save(net1.state_dict(), "net_params.pkl")
# restore net
def restore_net():
net2 = torch.load("net.pkl")
prediction = net2(x)
plt.subplot(132)
plt.title("Net2")
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), c="red", lw=2)
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
state_dict = torch.load("net_params.pkl")
print(state_dict)
net3.load_state_dict(state_dict)
prediction = net3(x)
plt.subplot(133)
plt.title("Net3")
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), c= "red", lw=2)
save()
restore_net()
restore_params()
如果批量加载数据可以使用torch提供的torch.utils.data 包。
具体的步骤:
例如:
加载数据:
BATCH_SIZE = 4
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
print(x, y)
torch_dataset = Data.TensorDataset(x, y)
print(torch_dataset)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2
)
使用数据:
print(len(loader))
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print(
'Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy()
)