faster rcnn 源码解读(二)

解读data_augment.py文件

def augment(img_data, config, augment=True):
	assert 'filepath' in img_data
	assert 'bboxes' in img_data
	assert 'width' in img_data
	assert 'height' in img_data
    augment是一个函数,该函数的输入是img_data是一个字典,也就是我们解读一里面返回值all_imgs里面的一个元素,之前说过,它里面的值其实就是一个字典,包括该图像里面的信息,目标和目标框。
assert是断言,也就是如果发现这个图像的信息不完全,没有这些元素那么就报错,触发异常。

img_data_aug = copy.deepcopy(img_data)

img = cv2.imread(img_data_aug['filepath'])

以上代码第一句是保证不改变原数据的情况下复制,第二句是使用opencv读取数据,

	if augment:
		rows, cols = img.shape[:2]

		if config.use_horizontal_flips and np.random.randint(0, 2) == 0:
			img = cv2.flip(img, 1)
			for bbox in img_data_aug['bboxes']:
				x1 = bbox['x1']
				x2 = bbox['x2']
				bbox['x2'] = cols - x1
				bbox['x1'] = cols - x2

		if config.use_vertical_flips and np.random.randint(0, 2) == 0:
			img = cv2.flip(img, 0)
			for bbox in img_data_aug['bboxes']:
				y1 = bbox['y1']
				y2 = bbox['y2']
				bbox['y2'] = rows - y1
				bbox['y1'] = rows - y2

		if config.rot_90:
			angle = np.random.choice([0,90,180,270],1)[0]
			if angle == 270:
				img = np.transpose(img, (1,0,2))
				img = cv2.flip(img, 0)
			elif angle == 180:
				img = cv2.flip(img, -1)
			elif angle == 90:
				img = np.transpose(img, (1,0,2))
				img = cv2.flip(img, 1)
			elif angle == 0:
				pass

			for bbox in img_data_aug['bboxes']:
				x1 = bbox['x1']
				x2 = bbox['x2']
				y1 = bbox['y1']
				y2 = bbox['y2']
				if angle == 270:
					bbox['x1'] = y1
					bbox['x2'] = y2
					bbox['y1'] = cols - x2
					bbox['y2'] = cols - x1
				elif angle == 180:
					bbox['x2'] = cols - x1
					bbox['x1'] = cols - x2
					bbox['y2'] = rows - y1
					bbox['y1'] = rows - y2
				elif angle == 90:
					bbox['x1'] = rows - y2
					bbox['x2'] = rows - y1
					bbox['y1'] = x1
					bbox['y2'] = x2        
				elif angle == 0:
					pass

    以上的代码是对数据进行增强,水平移动,垂直移动,90度和180度旋转等,同时他们的框也是需要跟着变得,

for bbox in img_data_aug['bboxes']  这个循环的原因是因为每一个图像里面不止一个框,有很多框,这个是为了遍历所有的框,因为bboxes为key,他的框的信息是一个list,list有存有不同的框的字典。所以这个循环就是对图像中所有的框都进行一个变换,保持和图像同步,这个地方的变换值得学习。

img_data_aug['width'] = img.shape[1]
	img_data_aug['height'] = img.shape[0]
	return img_data_aug, img

这李就是记录图像的长和宽,因为图像经过变换了,反回的是变换后的图像的信息,以及原始图像。


总结,这个函数就是对数据进行增强,包括平移,旋转等。代码值得学习


你可能感兴趣的:(深度学习)