Mxnet中的Gluoncv提供标准VOC和COCO数据集上的预训练模型、数据读取类和训练程序,如果我们想使用model_zoo里面的预训练模型,并在自己的数据集上微调,则需要调整一些程序,下面介绍在自有的VOC格式的数据集上训练Yolov3的方法。
主干网络仍使用Gluoncv提供的官方Yolov3训练程序,下面链接中的train_yolo3.py:
https://gluon-cv.mxnet.io/model_zoo/detection.html#yolo-v3
官方程序中给出了读取VOC和COCO数据集的方法,假如我们已经有类VOC或者COCO数据集的自有数据集,可以通过继承Gluoncv中的数据读取类的方式来读取自己的数据集,并且保留所有属性。原始训练文件train_yolo3中的get_dataset函数如下所示,其中使用了VOCDetection和COCODetection类,我们以VOC为例。
def get_dataset(dataset, args):
if dataset.lower() == 'voc':
train_dataset = gdata.VOCDetection(
splits=[(2007, 'trainval'), (2012, 'trainval')])
val_dataset = gdata.VOCDetection(
splits=[(2007, 'test')])
val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
#省略COCO和其它处理
return train_dataset, val_dataset, val_metric
VOC采用的数据读取为VOCDetection类,定义方式如下所示,可以看到其中定义了类别name,并且有指定数据路径(root),因此我们可以继承VOCDetection类来满足需求。
class VOCDetection(VisionDataset):
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'voc'),
splits=((2007, 'trainval'), (2012, 'trainval')),
transform=None, index_map=None, preload_label=True):
super(VOCDetection, self).__init__(root)
……
@property
def classes(self):
"""Category names."""
try:
self._validate_class_names(self.CLASSES)
except AssertionError as e:
raise RuntimeError("Class names must not contain {}".format(e))
return type(self).CLASSES
首先我们在训练文件中需要import VOCDetection类,假设我们自己的VOC格式的数据集种类为['1', '2','3', '4'],可以按照如下方式定义VOCLike类读取自己的数据。
from gluoncv.data import VOCDetection
classes_name = ['1', '2','3', '4']
class VOCLike(VOCDetection):
CLASSES = classes_name
def __init__(self, root, splits, transform=None, index_map=None, preload_label=True):
super(VOCLike, self).__init__(root, splits, transform, index_map, preload_label)
然后将训练文件中的get_dataset函数修改为如下所示,root参数可指定数据集路径,可根据需求指定splits中的名字,如splits=[(2007, 'trainval'), (2012, 'trainval')]等方式。
def get_dataset(dataset, args):
if dataset.lower() == 'voc':
train_dataset = VOCLike(root='VOCdevkit', splits=((2007, 'trainval'),))
val_dataset = VOCLike(root='VOCdevkit', splits=((2007, 'test'),))
val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=classes)
elif dataset.lower() == 'coco':
#……
if args.num_samples < 0:
args.num_samples = len(train_dataset)
if args.mixup:
from gluoncv.data import MixupDetection
train_dataset = MixupDetection(train_dataset)
return train_dataset, val_dataset, val_metric
2. 修改网络输出类别
Gluoncv提供在VOC和COCO上的预训练模型,因此我们可以方便地使用预训练模型在自己的数据集上微调参数。官方提供两种方法,第一种为get_model下VOC预训练模型,然后通过reset_class设置成自己需要的类别;第二种使用get_model中的cuctom直接设置成需要的类别,并复用VOC参数。
(1)使用VOC然后reset_class
train_yolo3训练程序中有get_model部分,原始代码如下:
if args.syncbn and len(ctx) > 1:
net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm, norm_kwargs={'num_devices': len(ctx)})
async_net = get_model(net_name, pretrained_base=False)
else:
net = get_model(net_name, pretrained_base=True)
async_net = net
其中get_model的参数pretrained_base表示加载imagenet上预训练的基础主干网络参数,如果想加载VOC上训好的检测模型参数,测需要将pretrained_base改为pretrained,在加载模型之后,需要使用reset_class函数更改预训练模型以满足自己数据集类别。
if args.syncbn and len(ctx) > 1:
net = get_model(net_name, pretrained=True, norm_layer=gluon.contrib.nn.SyncBatchNorm, norm_kwargs={'num_devices': len(ctx)})
async_net = get_model(net_name, pretrained=True)
else:
net = get_model(net_name, pretrained=True)
async_net = net
net.reset_class(classes_name)
async_net.reset_class(classes_name)
(2)使用cuctom,根据官方finetune_detection.py示例,效果和上面一样(未测试)。
net = gcv.model_zoo.get_model(net_name, classes= classes_name, pretrained_base=False, transfer='voc')
cuctom定义如下,可以发现实现和上面一样,也是根据transfer参数get_model预训练模型,然后reset_class。
from ...model_zoo import get_model
net = get_model(
'yolo3_mobilenet0.25_' +
str(transfer),
pretrained=True,
**kwargs)
reuse_classes = [x for x in classes if x in net.classes]
net.reset_class(classes, reuse_weights=reuse_classes)
reset_class中的resue_weights参数,可以根据需要修改输出层类别数量,并且选择复用的参数。比如我们自己的数据集中有部分种类和VOC一样,那这部分输出层参数就可以直接复用VOC预训练模型的参数,其它不一样的类的输出分支重新初始化。再比如我们现在只想要一个检测行人的模型,如果模型中包含其它无用类,会使得模型有冗余,我们可以使用resue_weights参数将reset_class后的输出层直接复用VOC训练参数,将模型修改为只检测行人而不用重新微调。