Pytorch 模型训练步骤

目录

1、导入必要库

2、加载数据

3、构建网络

4、训练模型

5、保存模型参数

        1)、仅仅保存和加载模型参数

        2)、保存和加载整个模型

        3)、保存多个模型参数


1、导入必要库

import torch
from torch import optim, nn
import torch.utils.data as Data

2、加载数据

x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)

# 注意:x的数据类型是 torch.FloatTensor
# y的数据类型是 torch.LongTensor
# x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # FloatTensor = 32-bit floating
# y = torch.cat((y0, y1), ).type(torch.LongTensor)    # LongTensor = 64-bit integer

# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)

# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=3,      # mini batch size
    shuffle=True,               # 要不要打乱数据 (打乱比较好)
    num_workers=0,              # 多线程来读数据
)

3、构建网络

# 定义网络结构 build net
class Net(torch.nn.Module):
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net, self).__init__()
        
        self.fc1 =torch.nn.Linear(n_feature,n_hidden)
        self.fc2 =torch.nn.Linear(n_hidden,n_output)

    # 定义一个前向传播过程函数
    def forward(self, x):
        
        x=F.relu(self.fc1(x))
        x=self.fc2(x)
        return x
# 实例化一个网络为 model
model = Net(n_feature=1,n_hidden=10,n_output=10)
print(model)

4、训练模型

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss() 
 
# 训练模型
model.train()
for epoch in range(5):
    for step, (b_x, b_y) in enumerate(loader): 
        output = model(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
# 测试模型
model.eval()
for step, (b_x, b_y) in enumerate(loader):
    output = model(b_x)
    loss = loss_func(output, b_y)
    
     _, pred_y = torch.max(output.data, 1)
    correct = (pred_y == b_y).sum()
    total = b_y.size(0)
    print('Epoch: ', step, '| test loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % (float(correct)/total))

5、保存模型参数

        1)、仅仅保存和加载模型参数

# 保存模型参数
torch.save(model.state_dict(), './path/model.pkl')
# 读取模型参数
model.load_state_dict(torch.load('./path/model.pkl'))

        2)、保存和加载整个模型

# 保存整个模型
torch.save(model,  './path/model.pkl')
# 加载整个模型
model = torch.load('./path/model.pkl')

        3)、保存多个模型参数

# 多个模型参数保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# 模型参数加载
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

你可能感兴趣的:(一些小代码,参数解释)