代码来自:链接
1、在pytorch 中继承自己数据集的常用框架
class Yolo_dataset(Dataset):
def __init__(self, label_path, cfg, train=True):
def __getitem__(self, index):
def __len__(self):
train_dataset = Yolo_dataset(config.train_label, config, train=True)
val_dataset = Yolo_dataset(config.val_label, config, train=False)
train_loader = DataLoader()
val_loader = DataLoader()
2、重点在__getitem__(self, index)
2.1、pytorch 读取一帧图片,获取相对应的boxes
if not self.train:
return self._get_val_item(index)
img_path = self.imgs[index]
bboxes = np.array(self.truth.get(img_path), dtype=np.float)
img_path = os.path.join(self.cfg.dataset_dir, img_path)
2.2、需要读取use_mixup+1帧图片
for i in range(use_mixup + 1):
if i != 0:
img_path = random.choice(list(self.truth.keys()))
bboxes = np.array(self.truth.get(img_path), dtype=np.float)
img_path = os.path.join(self.cfg.dataset_dir, img_path)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if img is None:
continue
2.3、为数据增强生成相对应的随机数据
oh, ow, oc = img.shape
#cfg.jitter = 0.2 , 是指抖动原始图片长宽高的0.2
dh, dw, dc = np.array(np.array([oh, ow, oc]) * self.cfg.jitter, dtype=np.int)
#Cfg.hue = .1 生成 -0.1~0.1之间的随机数
dhue = rand_uniform_strong(-self.cfg.hue, self.cfg.hue)
#Cfg.saturation = 1.5 有0.5的概率,生成1~1.5之间的数,有0.5的概率生成0到1之间的数
dsat = rand_scale(self.cfg.saturation)
#同上
dexp = rand_scale(self.cfg.exposure)
#x轴向左或向右最多抖动宽的0.2倍
pleft = random.randint(-dw, dw)
pright = random.randint(-dw, dw)
# y轴向上或向下最多抖动高的0.2倍
ptop = random.randint(-dh, dh)
pbot = random.randint(-dh, dh)
flip = random.randint(0, 1) if self.cfg.flip else 0
if (self.cfg.blur):
tmp_blur = random.randint(0, 2) # 0 - disable, 1 - blur background, 2 - blur the whole image
if tmp_blur == 0:
blur = 0
elif tmp_blur == 1:
blur = 1
else:
blur = self.cfg.blur
if self.cfg.gaussian and random.randint(0, 1):
gaussian_noise = self.cfg.gaussian
else:
gaussian_noise = 0
2.4、letter_box是保持图片本身的尺寸
因为我们最终要将图片处理为网络输入的大小,下面函数就是处理在resize的过程中,怎样保持原始图片的比例
if self.cfg.letter_box:
# 原始图片的宽高比
img_ar = ow / oh
#输入网络的宽高比
net_ar = self.cfg.w / self.cfg.h
#原始图片的宽高比 与 输入网络的宽高比
result_ar = img_ar / net_ar
if result_ar > 1: # sheight - should be increased
#说明原始图片比较宽,如果暴力resize到网络的大小,宽需要压缩好多
#固定宽,为了保证网络长宽比,宽应该是多少
oh_tmp = ow / net_ar
#为了得到上面计算的宽的值,需要各向上、向下移动多少
delta_h = (oh_tmp - oh) / 2
#ptop 和 pbot是抖动时,高需要上下移动的值
#delta_h,delta_h是保证原始图片的宽高比下,再需要移动的值。
ptop = ptop - delta_h
pbot = pbot - delta_h
else: # swidth - should be increased
ow_tmp = oh * net_ar
delta_w = (ow_tmp - ow) / 2
pleft = pleft - delta_w
pright = pright - delta_w
2.5、下面是着重理解数据增强中的jitter是如何工作的,假设图像的宽高为ow和oh,那么可以这样写[0, 0, ow, oh], 上面我们计算了pleft,pright,ptop,pbot,其实就是左上角(0,0) 向中心点移动pleft和ptop,坐标变为(0+pleft, 0+ ptop),同理右下角 (ow, oh)向中心点移动pright和pbot之后,坐标变为(ow - pright, oh - pbot),即抖动之后
[0, 0, ow, oh] 变为 [pleft, ptop, ow - pright, oh - pbot]
所以新的宽高变为 :
swidth = ow - pleft - pright
sheight = oh - ptop - pbot
进入到函数 image_data_augmentation中
计算[0, 0, ow, oh] 和 [pleft, ptop, ow - pright, oh - pbot]的交集new_src_rect
oh, ow, _ = img.shape
pleft, ptop, swidth, sheight = int(pleft), int(ptop), int(swidth), int(sheight)
# crop
src_rect = [pleft, ptop, swidth + pleft, sheight + ptop] # x1,y1,x2,y2
img_rect = [0, 0, ow, oh]
new_src_rect = rect_intersection(src_rect, img_rect) # 交集
这个交集new_src_rect其实就是我们在原始图片img中选中的数据块
dst_rect 是在新的数据块上的坐标
dst_rect = [max(0, -pleft), max(0, -ptop), max(0, -pleft) + new_src_rect[2] - new_src_rect[0],max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
生成新的图像数据
cropped = np.zeros([sheight, swidth, 3])
cropped[:, :, ] = np.mean(img, axis=(0, 1))
cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = img[new_src_rect[1]:new_src_rect[3], new_src_rect[0]:new_src_rect[2]]
具体过程可以参考下图:
2.6、下面着重理解数据增强中的mosaic,主要是函数
out_img, out_bbox = blend_truth_mosaic(out_img, ai, truth.copy(), self.cfg.w, self.cfg.h, cut_x, cut_y, i, left_shift, right_shift, top_shift, bot_shift)
out_img是数据处理后的输出,大小是cfg.w x self.cfg.h x 3
ai 是原始图片数据增强之后的输出
truth.copy() 里面放的是ground truth boxes
上面的黑色框输出结果,是mosaic的输入结果,看过程很简单,如下图:
mosaic就是要把四张图片拼接到一张图片,大小是网络输入的大小,
第一步,需要生成一个空的模板图,用于承载输出的数据
out_img = np.zeros([self.cfg.h, self.cfg.w, 3])
第二步,生成中心的分割点(cut_x,cut_y)
在配置文件中,
if Cfg.mosaic and Cfg.cutmix:
Cfg.mixup = 4
elif Cfg.cutmix:
Cfg.mixup = 2
elif Cfg.mosaic:
Cfg.mixup = 3
我使用的是Cfg.mosaic = 1, Cfg.cutmix = 0, 所以Cfg.mixup = 3
use_mixup = 3, 需要4张图片拼成416x416,cut_x和cut_y将416x416分成四个格子,后面就是要提取四张图片放进进去,即mosaic
use_mixup = self.cfg.mixup
if random.randint(0, 1):
use_mixup = 0
if use_mixup == 3:
min_offset = 0.2
cut_x = random.randint(int(self.cfg.w * min_offset), int(self.cfg.w * (1 - min_offset)))
cut_y = random.randint(int(self.cfg.h * min_offset), int(self.cfg.h * (1 - min_offset)))
第三步,计算处理好的图像数据的分割点
left_shift = int(min(cut_x, max(0, (-int(pleft) * self.cfg.w / swidth))))
top_shift = int(min(cut_y, max(0, (-int(ptop) * self.cfg.h / sheight))))
right_shift = int(min((self.cfg.w - cut_x), max(0, (-int(pright) * self.cfg.w / swidth))))
bot_shift = int(min(self.cfg.h - cut_y, max(0, (-int(pbot) * self.cfg.h / sheight))))
第四步,开始依次读取四张图片,然后切割需要的数据,向out_img填充数据,这里重点说下,left_shift, top_shift, right_shift,bot_shift ,因为这个关系到如何在img上定位开始和结束坐标
首先,第一帧定位的是top_shift:top_shift + cut_y, left_shift:left_shift + cut_x, left_shift和top_shift是真正的图像数据开始的地方,
left_shift = min(left_shift, w - cut_x)
top_shift = min(top_shift, h - cut_y)
right_shift = min(right_shift, cut_x)
bot_shift = min(bot_shift, cut_y)
if i_mixup == 0:
bboxes = filter_truth(bboxes, left_shift, top_shift, cut_x, cut_y, 0, 0)
out_img[:cut_y, :cut_x] = img[top_shift:top_shift + cut_y, left_shift:left_shift + cut_x]
if i_mixup == 1:
bboxes = filter_truth(bboxes, cut_x - right_shift, top_shift, w - cut_x, cut_y, cut_x, 0)
out_img[:cut_y, cut_x:] = img[top_shift:top_shift + cut_y, cut_x - right_shift:w - right_shift]
if i_mixup == 2:
bboxes = filter_truth(bboxes, left_shift, cut_y - bot_shift, cut_x, h - cut_y, 0, cut_y)
out_img[cut_y:, :cut_x] = img[cut_y - bot_shift:h - bot_shift, left_shift:left_shift + cut_x]
if i_mixup == 3:
bboxes = filter_truth(bboxes, cut_x - right_shift, cut_y - bot_shift, w - cut_x, h - cut_y, cut_x, cut_y)
out_img[cut_y:, cut_x:] = img[cut_y - bot_shift:h - bot_shift, cut_x - right_shift:w
大致操作流程如下:
图片id | 原图 | 数据增强后 | mosaic切割出来的数据块 |
---|---|---|---|
1 | |||
2 | |||
3 | |||
4 |
最后mosaic的结果是:
两条直线的交接点就是(cut_x, cut_y), 第一块切割点是(left_shift, top_shift),长度是cut_x, cut_y, 如何图片够宽,够长,left_shift + cut_x < w, 就可以直接切,若是不够长, left_shift就向上移动,使得left_shift + cut_x < w, 高是同理。其他图片也是这样切,只不过第2帧切的是右上角,第3帧切的是左下角,第4帧切的是右下角。