以下代码取自YOLOX中mosaicdetection.py文件
class MosaicDetection(Dataset):
    """Detection dataset wrapper that applies mosaic (and optionally mixup) augmentation.

    Each sample is built by pasting four images of the wrapped dataset around a
    random centre on a canvas twice the input size, transforming the canvas
    with ``random_perspective`` and, if enabled, blending in another sample
    through ``self.mixup``.

    NOTE(review): ``self.mixup`` and ``self.input_dim`` are used below but are
    not defined in this excerpt; they are presumably provided by the rest of
    the class / the ``Dataset`` base class -- confirm.
    """

    def __init__(
        self, dataset, img_size, mosaic=True, preproc=None,
        degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
        shear=2.0, perspective=0.0, enable_mixup=True, *args
    ):
        """
        Args:
            dataset(Dataset) : Pytorch dataset object being wrapped; it must
                expose ``input_dim`` and ``pull_item(index)``.
            img_size (tuple): forwarded to the base ``Dataset`` constructor.
            mosaic (bool): enable mosaic augmentation or not.
            preproc (func): called as ``preproc(img, labels, input_dim)`` on
                the final image and labels of every sample.
            degrees (float): forwarded to ``random_perspective``.
            translate (float): forwarded to ``random_perspective``.
            scale (tuple): forwarded to ``random_perspective``.
            mscale (tuple): stored as ``self.mixup_scale`` (presumably the
                scale range used by mixup -- not used in the visible code).
            shear (float): forwarded to ``random_perspective``.
            perspective (float): forwarded to ``random_perspective``.
            enable_mixup (bool): call ``self.mixup`` after the mosaic step.
            *args(tuple) : Additional arguments for mixup random sampler
                (accepted but not used by the visible code).
        """
        super().__init__(img_size, mosaic=mosaic)
        self._dataset = dataset
        self.preproc = preproc
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.perspective = perspective
        self.mixup_scale = mscale
        self.enable_mosaic = mosaic
        self.enable_mixup = enable_mixup

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self._dataset)

    @Dataset.resize_getitem
    def __getitem__(self, idx):
        """Return one training sample for image index ``idx``.

        With mosaic enabled the result is
        ``(mix_img, padded_labels, img_info, np.array([idx]))``; otherwise the
        single preprocessed image is returned as ``(img, label, img_info, id_)``.
        """
        if self.enable_mosaic:
            mosaic_labels = []
            input_dim = self._dataset.input_dim
            input_h, input_w = input_dim[0], input_dim[1]  # e.g. 800 x 1440

            # Mosaic centre: the point where the four tiles meet on the
            # (2 * input_h, 2 * input_w) canvas.
            yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
            xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))

            # The requested image plus 3 additional random image indices.
            indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]

            for i_mosaic, index in enumerate(indices):
                img, _labels, _, _ = self._dataset.pull_item(index)
                h0, w0 = img.shape[:2]  # original height / width
                # Resize so the image fits inside the target size while
                # keeping its aspect ratio.
                scale = min(1. * input_h / h0, 1. * input_w / w0)
                img = cv2.resize(
                    img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
                )
                # generate output mosaic image
                (h, w, c) = img.shape[:3]  # size after resizing
                if i_mosaic == 0:
                    # First tile: allocate the canvas, filled with grey (114).
                    mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)

                # suffix l means large image, while s means small image in mosaic aug.
                (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(
                    mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
                )

                # Copy the selected crop of the tile onto the canvas.
                mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
                padw, padh = l_x1 - s_x1, l_y1 - s_y1

                labels = _labels.copy()
                # Scale the box corners by the resize factor and shift them by
                # the paste offset (columns 0-3 are treated as x1, y1, x2, y2).
                if _labels.size > 0:
                    labels[:, [0, 2]] = scale * _labels[:, [0, 2]] + padw
                    labels[:, [1, 3]] = scale * _labels[:, [1, 3]] + padh
                mosaic_labels.append(labels)

            if mosaic_labels:
                mosaic_labels = np.concatenate(mosaic_labels, 0)
                # An earlier variant clipped all four coordinates to the canvas
                # with np.clip; this version instead drops every box that lies
                # entirely outside the canvas and leaves the others unclipped.
                on_canvas = (
                    (mosaic_labels[:, 0] < 2 * input_w)
                    & (mosaic_labels[:, 2] > 0)
                    & (mosaic_labels[:, 1] < 2 * input_h)
                    & (mosaic_labels[:, 3] > 0)
                )
                mosaic_labels = mosaic_labels[on_canvas]

            # Geometric transform of the canvas, then optional mixup.
            # NOTE: augment_hsv(mosaic_img) is disabled in this version.
            mosaic_img, mosaic_labels = random_perspective(
                mosaic_img,
                mosaic_labels,
                degrees=self.degrees,
                translate=self.translate,
                scale=self.scale,
                shear=self.shear,
                perspective=self.perspective,
                border=[-input_h // 2, -input_w // 2],
            )  # border to remove

            # -----------------------------------------------------------------
            # CopyPaste: https://arxiv.org/abs/2012.07177
            # -----------------------------------------------------------------
            # len() rather than truthiness: mosaic_labels is an array here.
            if self.enable_mixup and len(mosaic_labels) != 0:
                mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)
            mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
            img_info = (mix_img.shape[1], mix_img.shape[0])

            return mix_img, padded_labels, img_info, np.array([idx])

        else:
            # No mosaic: hand the current input size to the wrapped dataset and
            # return its single preprocessed image.
            self._dataset._input_dim = self.input_dim
            img, label, img_info, id_ = self._dataset.pull_item(idx)
            img, label = self.preproc(img, label, self.input_dim)
            return img, label, img_info, id_
pull_item函数:
def pull_item(self, index):
    """Load the raw image and annotation stored at position ``index``.

    Args:
        index: position produced by the sampler; used to index both
            ``self.ids`` and ``self.annotations``.

    Returns:
        tuple: ``(img, res_copy, img_info, np.array([id_]))`` where ``img`` is
        the image exactly as returned by ``cv2.imread`` (no resizing is done
        here) and ``res_copy`` is a copy of the stored annotation.

    Raises:
        AssertionError: if the image file cannot be read.
    """
    id_ = self.ids[index]
    res, img_info, file_name = self.annotations[index]

    # Load the image from disk.
    # NOTE (translated author note): "size scaling?" -- see video 5,
    # data-loading code walkthrough, 20:18.
    img_file = os.path.join(self.data_dir, self.name, file_name)
    img = cv2.imread(img_file)

    # cv2.imread signals failure by returning None instead of raising.
    # Raise explicitly (keeping the AssertionError type of the original
    # ``assert``) so the check survives ``python -O`` and names the file.
    if img is None:
        raise AssertionError(f"failed to read image file: {img_file}")

    return img, res.copy(), img_info, np.array([id_])
get_mosaic_coordinate函数:
def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
    """Compute where one tile goes on the mosaic canvas and which part of it to copy.

    The canvas is ``2 * input_h`` high and ``2 * input_w`` wide. The four
    tiles meet at the mosaic centre ``(xc, yc)``; each tile is anchored to the
    centre with one corner and cropped where it would leave the canvas.

    Args:
        mosaic_image: unused; kept so existing callers keep working.
        mosaic_index (int): tile position, 0 = top left, 1 = top right,
            2 = bottom left, 3 = bottom right.
        xc, yc: mosaic centre on the canvas.
        w, h: width and height of the (already resized) tile image.
        input_h, input_w: target height and width; the canvas is twice this.

    Returns:
        tuple: ``((x1, y1, x2, y2), (sx1, sy1, sx2, sy2))`` -- the destination
        rectangle on the canvas and the source rectangle inside the tile.

    Raises:
        ValueError: if ``mosaic_index`` is not one of 0, 1, 2, 3.
    """
    # index0 to top left part of image
    if mosaic_index == 0:
        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
    # index1 to top right part of image
    elif mosaic_index == 1:
        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
    # index2 to bottom left part of image
    elif mosaic_index == 2:
        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
    # index3 to bottom right part of image
    elif mosaic_index == 3:
        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
    else:
        # Fail with a clear message instead of an UnboundLocalError at the
        # return statement below.
        raise ValueError(
            f"mosaic_index must be 0, 1, 2 or 3, got {mosaic_index!r}"
        )
    return (x1, y1, x2, y2), small_coord