前段时间参加了阿里天池的FashionAI服饰关键点定位比赛,为了做比赛,博主尝试用Detectron里面的Mask R-CNN去做关键点定位,取得了一定效果,也算是对Detectron的一些实践,特此做一些记录,希望对需要的朋友有所帮助。
Mask R-CNN是何凯明等人在Faster R-CNN基础上提出的一个优秀的目标实例分割模型。该模型能够有效地检测图像中的目标并为每个实例生成高质量的分割掩码。如下图所示,该模型通过在Faster R-CNN已存在的bbox识别分支旁并行地添加一个用于预测目标掩码的分支。掩码分支是一个应用到每个RoI上的小型FCN(全卷积网络),能够预测RoI中每个像素所属的类别,从而实现准确的实例分割。
Mask R-CNN的技术要点主要有三点:
关于Mask R-CNN的详细解读可以参见这篇博客。
Mask R-CNN训练简单,速度与Faster R-CNN相当,而且可以很容易的推广到其它与实例水平的识别相关的任务中,如目标检测和人体姿态估计。比如Mask R-CNN用于人体姿态估计时,主要是通过定位人体关键点来实现的,人体关键点在Mask R-CNN中可以被视作单个像素的mask。在第二部分,我将简单地介绍如何用Detectron的Mask R-CNN模型进行服饰关键点的定位。
与该系列上一篇博客中介绍的准备数据集的方法类似,不同之处是这次要自己做COCO json格式的关键点annotation文件了,COCO json数据格式详见官网。
COCO用于目标检测的annotation文件格式大致是下面这样的:
annotation{
"id" : int,
"image_id" : int,
"category_id" : int,
"segmentation" : RLE or [polygon],
"area" : float,
"bbox" : [x,y,width,height],
"iscrowd" : 0 or 1,
}
categories[{
"id" : int, "name" : str, "supercategory" : str,
}]
而关键点检测的annotation格式则是在上面格式的annotation字段和categories字段分别增加了一下内容:
annotation{
"keypoints" : [x1,y1,v1,...], #每个关键点以三元组x,y,v表示,v标识了关键点的存在性和可见性,为0表示不存在,为1表示不可见,为2表示可见
"num_keypoints" : int,
"[cloned]" : ...,
}
categories[{
"keypoints" : [str],
"skeleton" : [edge], #由关键点按一定的顺序连接而成
"[cloned]" : ...,
}]
"[cloned]": denotes fields copied from object detection annotations defined above.
再来看一看COCO数据集annotation文件到底是个啥样子
# person_keypoint_val2014.json内容
# 字典的5个字段
info
licenses
images
annotations
categories
#各类字段的长度和对应的第一个元素
#info
6 {u'description': u'COCO 2014 Dataset', u'url': u'http://cocodataset.org', u'version': u'1.0', u'year': 2014, u'contributor': u'COCO Consortium', u'date_created': u'2017/09/01'}
#license
8 {u'url': u'http://creativecommons.org/licenses/by-nc-sa/2.0/', u'id': 1, u'name': u'Attribution-NonCommercial-ShareAlike License'}
#images
40504 {u'license': 3, u'file_name': u'COCO_val2014_000000391895.jpg', u'coco_url': u'http://images.cocodataset.org/val2014/COCO_val2014_000000391895.jpg', u'height': 360, u'width': 640, u'date_captured': u'2013-11-14 11:18:45', u'flickr_url': u'http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg', u'id': 391895}
#categories
1 {u'supercategory': u'person', u'id': 1, u'name': u'person', u'keypoints': [u'nose', u'left_eye', u'right_eye', u'left_ear', u'right_ear', u'left_shoulder', u'right_shoulder', u'left_elbow', u'right_elbow', u'left_wrist', u'right_wrist', u'left_hip', u'right_hip', u'left_knee', u'right_knee', u'left_ankle', u'right_ankle'], u'skeleton': [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]}
#annotations
88153 {u'segmentation': [[267.03, 243.78, 314.59, 154.05, 357.84, 136.76, 374.05, 104.32, 410.81, 110.81, 429.19, 131.35, 420.54, 165.95, 451.89, 209.19, 464.86, 240.54, 480, 253.51, 484.32, 263.24, 496.22, 271.89, 484.32, 278.38, 438.92, 257.84, 401.08, 216.76, 370.81, 247.03, 414.05, 277.3, 433.51, 304.32, 443.24, 323.78, 400, 362.7, 376.22, 375.68, 400, 418.92, 394.59, 424.32, 337.3, 382.16, 337.3, 371.35, 388.11, 327.03, 341.62, 301.08, 311.35, 276.22, 304.86, 263.24, 294.05, 249.19]], u'num_keypoints': 8, u'area': 28292.08625, u'iscrowd': 0, u'keypoints': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 325, 160, 2, 398, 177, 2, 0, 0, 0, 437, 238, 2, 0, 0, 0, 477, 270, 2, 287, 255, 1, 339, 267, 2, 0, 0, 0, 423, 314, 2, 0, 0, 0, 355, 367, 2], u'image_id': 537548, u'bbox': [267.03, 104.32, 229.19, 320], u'category_id': 1, u'id': 183020}
知道了关键点的annotation文件格式和具体内容就可以参照着制作自己的annotation文件了,其中关键的三个字段是images,annotations和categories,另外两个字段并不是很要紧。不过要注意的一点是由于Detectron中目前只支持人体关键点检测,因此我们用自己的数据集做的annotation文件中的categories字段的supercategory和name的值都要是person,否则会报错。
下面是我做好的服饰关键点annotation文件的大致样子
info
licenses
images
annotations
categories
2 {u'url': u'https://tianchi.aliyun.com/competition/introduction.htm?spm=5176.100067.5678.1.510044a9T425c6&raceId=231648', u'description': u'FashionAI Dataset'}
1 {u'url': u'http://creativecommons.org/licenses/by-nc-nd/2.0/', u'id': 3, u'name': u'Attribution-NonCommercial-NoDerivs License'}
2292 {u'file_name': u'Images/skirt/bd969a01fe65b95f2736e22d76b214d7.jpg', u'height': 512, u'id': 2997, u'license': 3, u'width': 512}
1 {u'supercategory': u'person', u'id': 1, u'keypoints': [u'waistband_left', u'waistband_right', u'hemline_left', u'hemline_right'], u'name': u'person'}
2292 {u'segmentation': [[41, 137, 41, 288, 499, 288, 499, 137]], u'num_keypoints': 4, u'area': 69158, u'iscrowd': 0, u'keypoints': [205, 155, 2, 324, 142, 2, 46, 283, 2, 494, 266, 2], u'ignore': 0, u'image_id': 2997, u'bbox': [41, 137, 458, 151], u'category_id': 1, u'id': 2997}
可以参照Detectron/configs/12_2017_baselines中的e2e_keypoint_rcnn_R-50-FPN_1x.yaml配置文件进行相应的修改,主要是增加KRCNN的配置
KRCNN:
ROI_KEYPOINTS_HEAD: keypoint_rcnn_heads.add_roi_pose_head_v1convX
NUM_STACKED_CONVS: 8
NUM_KEYPOINTS: 4 # 把此处改成自己数据集中需要检测的关键点个数
USE_DECONV_OUTPUT: True
CONV_INIT: MSRAFill
CONV_HEAD_DIM: 512
UP_SCALE: 2
HEATMAP_SIZE: 56 # ROI_XFORM_RESOLUTION (14) * UP_SCALE (2) * USE_DECONV_OUTPUT (2)
ROI_XFORM_METHOD: RoIAlign
ROI_XFORM_RESOLUTION: 14
ROI_XFORM_SAMPLING_RATIO: 2
KEYPOINT_CONFIDENCE: bbox
self.keypoint_flip_map={u'waistband_left':u'waistband_right', u'hemline_left':u'hemline_right'}
keypoints = [u'waistband_left', u'waistband_right', u'hemline_left', u'hemline_right']
keypoint_flip_map={u'waistband_left':u'waistband_right', u'hemline_left':u'hemline_right'}
执行train_net.py文件训练模型(相关参数自己按需提供,比如–cfg,OUTPUT_DIR等)
这里因为我的关键点annotation文件没有提供skeleton,没办法直接用Detectron的infer_simple.py进行推断,就在infer_simple.py文件中仿照main函数自己写了个函数,将推断结果(关键点坐标)输出为csv格式的文件,主要就是利用main函数里面的infer_engine.im_detect_all()函数拿到推断的结果。
def write_infer_kpts(args, file_name):
"""write infer result of FashionAI keypoints to csv file"""
logger = logging.getLogger(__name__)
merge_cfg_from_file(args.cfg)
cfg.TEST.WEIGHTS = args.weights
cfg.NUM_GPUS = 1
assert_and_infer_cfg()
model = infer_engine.initialize_model_from_cfg()
bbox_infer = []
csv_head = ['image_id', 'image_category', 'xmin','ymin','xmax','ymax', 'waistband_left','waistband_right','hemline_left','hemline_right'
]
input_data = pd.read_csv(args.input_data).values
im_list = map(lambda x: args.im_or_folder+x, list(input_data[:,0]))
for i, im_name in enumerate(im_list):
logger.info('Processing {}'.format(im_name))
im = cv2.imread(im_name)
timers = defaultdict(Timer)
t = time.time()
with c2_utils.NamedCudaScope(0):
cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
model, im, None, timers=timers
)
logger.info('Inference time: {:.3f}s'.format(time.time() - t))
for k, v in timers.items():
logger.info(' | {}: {:.3f}s'.format(k, v.average_time))
if i == 0:
logger.info(
' \ Note: inference on the first image will be slower than the '
'rest (caches and auto-tuning need to warm up)'
)
im_id = im_name.split('/')[-3]+'/'+im_name.split('/')[-2]+'/'+im_name.split('/')[-1]
im_label = im_name.split('/')[-2]
cls_keyps = np.array(cls_keyps[1])
cls_boxes = np.array(cls_boxes[1])# cls_boxes[0] represents background
print(i, cls_keyps.shape, cls_boxes.shape)
# print('cls_keyps:',cls_keyps)
kpts_num = len(csv_head[6:])
# item_list = [im_id, im_label]
if len(cls_boxes[0])>0:
idx = np.argsort(cls_boxes[:,4])[-1]
xmin, ymin, xmax, ymax = map(lambda x: int(round(x)), cls_boxes[idx][:4])
item_list = [im_id, im_label, xmin, ymin, xmax, ymax]
if cls_keyps.shape[0]>0:
for i in range(kpts_num):
idx = np.argsort(cls_keyps[:,3,i])[-1]
kpt_x, kpt_y = map(lambda x: int(round(x)), list(cls_keyps[idx,:2,i]))
item_list.append(str(kpt_x)+'_'+str(kpt_y)+'_1')
bbox_infer.append(item_list)
df = pd.DataFrame(bbox_infer, columns = csv_head)
df.to_csv(args.output_dir+'/'+file_name, mode='w')
2.6 服饰关键点定位结果
同时比较准确的预测出了skirt的bounding box和四角的关键点。
用Mask R-CNN训练自己的数据集进行关键点定位大致过程就是这样,希望对需要的朋友有所帮助。