(一)使用pytorch搭建模型并训练
为了学习一下使用TensorRT进行推理的全过程,便想着写一个TensorRT推理手写数字分类的小例程。这个例程包括使用pytorch进行LeNet网络的搭建、训练、保存pytorch格式的模型(pth)、将模型(pth)转为onnx通用格式、使用tensorRT解析onnx模型进行推理等。
本节介绍使用pytorch进行手写数字分类网络的搭建,并进行训练。
结构说明:输入是单通道的12828的灰度图像,经过卷积、池化、卷积、池化后shape变为5044(50为通道数)。将其展平后维度为1*800,然后连接一个维度为500的线性层C5,C5层的输出经过ReLU函数激活后再连接一个维度为10的线性层C6,C6层的输出就为网络的输出。
一般来说,我们要求的是输入图片属于某一类的概率,所有我们要将C6的输出通过softmax函数进行转换。
新建model.py文件,包含以下代码:
# 搭建网络模型
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary #用来打印网络层的信息
# from torchkeras import summary module 'torch.backends' has no attribute 'mps'
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
self.fc1 = nn.Linear(800, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.max_pool2d(self.conv1(x), kernel_size=2, stride=2)
x = F.max_pool2d(self.conv2(x), kernel_size=2, stride=2)
x = x.view(-1, 800) # 将其展平
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1) # 要使用NLLLoss()损失函数,所以输出要先经过log_softmax
if __name__ == "__main__":
net = Net()
summary(net, (1,1,28,28))
新建train.py,包含以下代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from model import Net
import numpy as np
import os
import torch.utils.data
from random import randint
class MnistModel(object):
def __init__(self):
self.batch_size = 64 # 训练batch_size
self.test_batch_size = 100 # 测试batch_size
self.learning_rate = 0.0025 #学习率
self.sgd_momentum = 0.9
self.log_interval = 100
# 构造数据
self.train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"./tmp/mnist/data",
train=True,
download=True,
transform=transforms.Compose( # 预处理:对训练数据只进行标准化
[transforms.ToTensor(),
transforms.Normalize((0.1307),(0.3081,))])
),
batch_size=self.batch_size,
shuffle=True,
num_workers=4,
timeout=600,
)
self.test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"./tmp/mnist/data",
train=False,
transform=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307),(0.3081,))])
),
batch_size = self.test_batch_size,
shuffle = True,
num_workers=4,
timeout=600,
)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.network = Net()
self.network.to(self.device)
def learn(self, num_epochs=2): # 训练两个epoch
#
# Train the network for a single epoch
def train(epoch):
self.network.train()
optimizer = optim.SGD(self.network.parameters(), lr=self.learning_rate, momentum=self.sgd_momentum) # 使用SGD优化器
for batch, (data, target) in enumerate(self.train_loader):
data, target = Variable(data.to(self.device)), Variable(target.to(self.device))
optimizer.zero_grad()
output = self.network(data)
loss = F.nll_loss(output, target).to(self.device)
loss.backward()
optimizer.step()
if batch % self.log_interval == 0: #每100个batch打印一次信息
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch,
batch * len(data),
len(self.train_loader.dataset),
100.0 * batch / len(self.train_loader),
loss.data.item(),
)
)
# Test the network
def test(epoch):
self.network.eval()
test_loss = 0
correct = 0
for data, target in self.test_loader:
with torch.no_grad():
data, target = Variable(data.to(self.device)), Variable(target.to(self.device))
output = self.network(data)
test_loss += F.nll_loss(output, target).data.item()
pred = output.data.max(1)[1] # 输出最大值的索引为预测的类别
correct += pred.eq(target.data).cpu().sum()
test_loss /= len(self.test_loader)v # 测试集每一个batch的平均损失
print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss, correct, len(self.test_loader.dataset), 100.0 * correct / len(self.test_loader.dataset)
)
)
for e in range(num_epochs):
train(e + 1)
test(e + 1)
train_model = MnistModel()
train_model.learn()
总的来说,训练代码中没有太值得让人注意的地方。如果非要说有,那我觉得以下三点可能是要注意的地方:
可以看到,两个epoch后,模型的准确率为99%,所以我选择停止训练,然后保存模型。
在pytorch中保存模型有两种形式,一种是保存整个网络,一种是只保存网络中的参数。
保存整个网络的方法如下:
# 保存整个网络
torch.save(net, path)
# 加载网络
model = torch.load(path)
只保存网络中的参数的方法如下:
# 保存
torch.save(net.state_dict(), path)
# 加载
model = model.load_state_dict(torch.load(path))
在这个demo中,我们只需要在train.py后加上
torch.save(net.state_dict(), './model.pth')
就可以保存模型为model.pth文件。
本节我们进行了模型的搭建、训练以及保存模型。下一节我们将介绍如何将我们保存的pth文件转为onnx通用格式,同时对我们转成的onnx文件进行检查和验证。