Pytorch是目前非常流行的大规模矩阵计算框架,上手简易,文档详尽,最新发表的深度学习领域的论文中有多半是以pytorch框架来实现的,足以看出其易用性和流行度。
这篇文章将以yolov3为例,介绍pytorch中如何实现一个网络的训练和推断。
这一部分主要讲解一下,在pytorch中构建一个深度学习网络,需要包含哪些部分,各部分都起了什么作用。不同的框架的实现方式会有许多不同,但基本都包含这些部分。在以下的讲解中我隐去了一些具体的实现细节,如果想详细了解,可以前往Pytorch-YOLOv3这个github了解,我的讲解代码也是以它为基础改编的,两个版本配合着看能更好地了解和上手。
数据集在网络的训练过程是必须的。通常在训练脚本中,会看到类似下面的这样一行代码。
dataloader = torch.utils.data.DataLoader(Dataset(train_path))
其中的Dataset是自定义的一个类,train_path是训练数据集的路径。Dataset通常定义在命名为datasets的文件内,当然也有以VOCDataset、COCODataset来命名的,其作用都是相同的,定义一个数据集类,以便pytorch调用。下面给出一个Dataset的类定义模板,该模板为yolov3的框架所使用。
class Dataset(Dataset):
def __init__(self, img_dir, label_dir):
self.img_files = glob.glob(os.path.join(img_dir, '*.*'))
self.label_files = glob.glob(os.path.join(label_dir, '*.*'))
def __getitem__(self, index):
# === 图片 ===
# 读取图片
img_path = self.img_files[index % len(self.img_files)].rstrip()
img = np.array(Image.open(img_path))
img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255) # 将numpy.array的格式转为torch.Tensor格式,并转换通道
# 图像预处理(可选)
# 做一些诸如pad、resize之类的操作
# === 标签 ===
# 获取标签文件路径
label_path = self.label_files[index % len(self.img_files)].rstrip()
# 解析标签文件(可选)
# 读取label_path的文件然后解析,也可直接返回label_path
return img, label_path
def __len__(self):
return len(self.img_files)
基本上所有的Dataset类都会包含init、getitem、len这三个函数,在getitem函数中,一般会包含图像预处理和标签预处理,也有些是把这两部分放在外部处理,getitem只获取图像和标签文件路径,值得注意的是,有不少的框架对getitem进行了重载,所以你可能没有找到getitem函数,但是有其他函数能代替getitem的作用。
在DL框架中models是一个最为重要的部分,它实现了整个网络的整体结构和具体细节,在一些通用型的大型项目框架内,通常会把这部分拆分成多个modules进行实现,而在一些小项目里,models也可能仅仅用一个文件来实现。这里我还是以yolov3的models来举例介绍。
# === 读取cfg配置文件 ===
def create_modules(cfg):
# 根据配置文件进行解析
return module_list
# === yolo层定义 ===
class YOLOLayer(nn.Module):
def __init__(self, cfg):
super(YOLOLayer, self).__init__()
def forward(self, x, targets=None):
if targets is not None:
# === 训练阶段 ===
# 计算loss,根据输入的x的结果与targets进行计算,最后得到loss
return x, loss
else:
# === 推断阶段 ===
# 根据输入的x计算出预测结果
return x
# === darknet网络结构定义 ===
class Darknet(nn.Module):
def __init__(self, cfg):
super(Darknet, self).__init__()
self.module_list = create_modules(cfg)
def forward(self, x, targets=None):
losses= []
for module in self.module_list:
if module is not 'YOLO':
x = module(x)
else:
# === 训练阶段 ===
if is_training:
x, loss = module(x, targets)
losses.append(loss)
# === 推断阶段 ===
else:
x = module(x)
return x, losses
在网络结构和yolo层定义中,init和forward这两个函数是必须的,事实上这两个函数也是torch内置已经定义过了的,这里这样写实际是重载了这两个函数。有些项目中可能会把训练和推断的forward函数拆分成两个函数,函数名字也改变了,实际运用时要注意。
训练脚本可以说是网络中最为关键的部分,它直接影响了模型的性能和鲁棒性。基本上不同网络的训练脚本均有不同之处,但是均可以达到一定的效果。一个训练脚本一般包含dataloader、optimizer、model三个部分,运用这三个部分构成train迭代循环过程。
# 构建model,模型结构
model = Darknet(model_config_path)
model.apply(weights_init_normal)
model.train()
if cuda:
model = model.cuda()
# 设置dataloader ,数据集加载器
# batch_size根据显存大小调整,shuffle是指是否打乱数据集的读取顺序,num_workers是指用多少个线程读取数据集
dataloader = torch.utils.data.DataLoader(
Dataset(img_dir, label_dir), batch_size=16, shuffle=True, num_workers=4
)
# 设置optimizer,优化器
# 优化器的种类有非常多,建议新手使用Adam,因为这是一个自适应调整学习率的优化器,不需要设置很多参数
# 如果需要精调模型,或者对这方面比较熟练,可以使用SGD+Momentum优化器
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
# 主循环train过程
total_epoch = 10
for epoch in range(total_epoch):
for batch_i, (imgs, targets) in enumerate(dataloader):
# 注意,输入的图像必须进行通道转换,我这里忽略了这个步骤,因为我之前已经在Dataset部分实现了
# 这里的imgs的shape应该为(B, C, H, W),B为batch_size,C为通道,H为高,W为宽
imgs.requires_grad = True # imgs的requires_grad属性必须为True,而targets的requires_grad属性为False(默认为False)
if cuda:
imgs = imgs.cuda()
targets = targets.cuda()
optimizer.zero_grad()
_, loss = model(imgs, targets)
loss.backward()
optimizer.step()
print('epoch:', epoch, 'batch:', batch_i, 'loss:', loss.detach().cpu().numpy())
if epoch % 1 == 0:
torch.save(model.state_dict(), 'backup.pth')
以上就是训练脚本中所包含的基本部分,关于loss的计算,有些project把它放在了forward函数里面,也是没问题的,只要注意进行计算imgs的requires_grad必须为True就可以了。
推断脚本相对于训练来说比较简单,基本上大同小异,只要模型结构没错基本上输出结果都是相同的。
# 加载模型
model = Darknet(model_config_path)
params = torch.load('backup.pth')
model.load_state_dict(params)
model.eval()
# 读取图片和推断
img = cv2.imread(path)
img = torch.from_numpy(img.transpose(2, 0, 1)).float().div(255).unsqueeze(0)
with torch.no_grad():
out, _ = model(img)
# 处理out,例如进行nms和结果显示,该部分省略
以上就是在pytorch中构建模型和训练、推断的主要过程,这个博客主要目的是帮大家理解这个过程,所以对于一些具体实现细节我没有给出,想详细了解的可以去Pytorch-YOLOv3这个github上进行了解,后续我有时间也会公布一个我个人的yolov3的pytorch版本。如有疑问也可以在下面评论,我有空会回复,谢谢。