前言:
目前我在做车辆目标检测任务,虽然对实时性的要求不高,但是对检测的准确性有比较高的要求.使用yolo ,retinanet 神经网络进行检测的时候发现, 喂数据的多少,很影响检测的结果.不论是做什么任务,数据一直都是一个比较头痛的问题. ssd是一个优秀的网络模型.在数据增强方法做了很多处理,例如裁剪,明亮强度等.我在github上面,找到了ssd源码,https://github.com/amdegroot/ssd.pytorch, 理解数据增强代码逻辑,并对自己的数据集进行增强处理.
augmentation.py 源码理解
源码的文件结构是很清晰的. 路径utils/augmentation.py就是数据增强的代码.这份源码每个函数做的工作其实看函数名就很清楚. 例如class RandomBrightness 就是对图片随机增加亮度.
class RandomBrightness(object):
def __init__(self, delta=32):
assert delta >= 0.0
assert delta <= 255.0
self.delta = delta
def __call__(self, image, boxes=None, labels=None):
if random.randint(2):
delta = random.uniform(-self.delta, self.delta)
image += delta
return image, boxes, labels
我想重点是这个class SSDAugmentation
class SSDAugmentation(object):
def __init__(self, size=300, mean=(104, 117, 123)):
self.mean = mean
self.size = size
self.augment = Compose([
ConvertFromInts(),
ToAbsoluteCoords(),
PhotometricDistort(),
Expand(self.mean),
RandomSampleCrop(),
RandomMirror(),
ToPercentCoords(),
Resize(self.size),
SubtractMeans(self.mean)
])
def __call__(self, img, boxes, labels):
return self.augment(img, boxes, labels)
不论在上面的 class RandomBrightness ,SSDAugmentation,还是文件中的其他 class ,都定义了一个call 方法,我查了这个函数的使用方法,发现是python 的魔法方法,作用是让类的对象也能够作为一个函数被调用.理解了这个魔法方法,我才发现这份代码真的是写的很美,很值得我去学习. 赞叹完了,还有另外一个地方要注意, 那就是self.augmentation = Compose([.....]) , 成员变量 augmentation 是Compose的一个对象,所以调用 SSDAugmentation 的call方法时候,就会执行 Compose 类的call 方法,对图像进行一系列的数据处理.
数据格式转换
在ssd的源码中,支持voc,coco的数据格式,因此我也把自己的数据集提前转成voc的格式.ssd 读取voc数据的代码在data/voc0712.py, 从class VOCDetection这个类开始阅读,便可以知道整个处理流程.
class VOCDetection(data.Dataset):
"""VOC Detection Dataset Object
input is image, target is annotation
Arguments:
root (string): filepath to VOCdevkit folder.
image_set (string): imageset to use (eg. 'train', 'val', 'test')
transform (callable, optional): transformation to perform on the
input image
target_transform (callable, optional): transformation to perform on the
target `annotation`
(eg: take in caption string, return tensor of word indices)
dataset_name (string, optional): which dataset to load
(default: 'VOC2007')
"""
def __init__(self, root,
# image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
image_sets=[('2007', 'trainval')],
transform=None, target_transform=VOCAnnotationTransform(),
dataset_name='VOC0712'):
self.root = root
self.image_set = image_sets
self.transform = transform
self.target_transform = target_transform
self.name = dataset_name
self._annopath = osp.join('%s', 'Annotations', '%s.xml')
self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
self.ids = list()
for (year, name) in image_sets:
rootpath = osp.join(self.root, 'VOC' + year)
for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
self.ids.append((rootpath, line.strip()))
for i in range( len(self.ids)):
self.pull_item(i)
编写自己的脚本
逻辑理解后,代码实现和结果见下.
from utils.augmentations import SSDAugmentation
from data import myvoc0712 as myvoc
from data import config
from scipy import misc
import cv2
import random
import numpy as np
import argparse
parser = argparse.ArgumentParser(
description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
type=str, help='VOC or COCO')
parser.add_argument('--dataset_root', default=myvoc.VOC_ROOT,
help='Dataset root directory path')
args = parser.parse_args()
cfg = config.voc
aug = SSDAugmentation(cfg['min_dim'],config.MEANS)
dataset = myvoc.VOCDetection(root=args.dataset_root,
transform=SSDAugmentation(cfg['min_dim'],
config.MEANS))