玩转Facebook的maskrcnn-benchmark项目 2

maskrcnn-benchmark是Facebook开源的基准(benchmark)算法工程,其中包含检测分割人体关键点等算法。

本系列包含两篇:

  • 第一篇 搭建环境;
  • 第二篇 训练和验证;

训练

使用maskrcnn-benchmark训练模型,可以参考。

数据集:

  • 下载完整的COCO数据集:annotations、test2014、train2014、val2014;
  • 下载FAIR提供的COCO小型验证集:minival和valminusminival;

选择训练模板:e2e_mask_rcnn_R_50_FPN_1x.yaml,其中:

WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"  # 预训练权重
DATASETS:  # 数据集
  TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
  TEST: ("coco_2014_minival",)
MAX_ITER: 90000  # 最大训练轮次
复制代码

其他参数的设置位置:maskrcnn_benchmark/config/defaults.py

如:

  • _C.SOLVER.CHECKPOINT_PERIOD = 2500,保存轮次;
  • _C.SOLVER.IMS_PER_BATCH = 16,训练的batch_size
  • _C.OUTPUT_DIR = "./models",模型输出路径;

指定GPU的数量:

export NGPUS=4
复制代码

训练模型:

python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"

nohup python -u -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" &
复制代码

输出的模型位于./models中,最后一个模型是model_0090000.pth


测试

将配置文件e2e_mask_rcnn_R_50_FPN_1x.my.yaml的WEIGHT修改为模型路径,如WEIGHT: "model_0090000.pth"。

测试逻辑如下:

def main():
    img_path = os.path.join(DATA_DIR, 'aoa-mina.jpeg')

    img = cv2.imread(img_path)
    print('[Info] img size: {}'.format(img.shape))
    
    config_file = "../configs/e2e_mask_rcnn_R_50_FPN_1x.my.yaml"

    cfg.merge_from_file(config_file)  # 设置配置文件
    cfg.merge_from_list(["MODEL.MASK_ON", True])
    cfg.merge_from_list(["MODEL.DEVICE", "cpu"])  # 指定为CPU

    coco_demo = COCODemo(  # 创建模型文件
        cfg,
        # show_mask_heatmaps=True,
        min_image_size=800,
        confidence_threshold=0.7,
    )

    predictions = coco_demo.compute_prediction(img)
    top_predictions = coco_demo.select_top_predictions(predictions)
    show_mask(img, top_predictions, coco_demo)

    print('执行完成!')
复制代码

show_mask显示分割的效果,以matplotlib的图像为基础,绘制多边形框:

def show_mask(img, predictions, coco_demo):
    """
    显示分割效果
    :param img: numpy的图像 
    :param predictions: 分割结果
    :param coco_demo: 函数集
    :return: 显示分割效果
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    pylab.rcParams['figure.figsize'] = (8.0, 10.0)  # 图片尺寸
    plt.imshow(img)  # 需要提前填充图像
    # plt.show()

    extra_fields = predictions.extra_fields
    masks = extra_fields['mask']
    labels = extra_fields['labels']
    name_list = [coco_demo.CATEGORIES[l] for l in labels]

    seg_list = []
    for mask in masks:
        mask = torch.squeeze(mask)
        segmentation = binary_mask_to_polygon(mask, tolerance=2)[0]
        seg_list.append(segmentation)

    ax = plt.gca()
    ax.set_autoscale_on(False)

    polygons, color = [], []

    np.random.seed(37)

    for name, seg in zip(name_list, seg_list):
        c = (np.random.random((1, 3)) * 0.8 + 0.2).tolist()[0]

        poly = np.array(seg).reshape((int(len(seg) / 2), 2))
        c_x, c_y = get_center_of_polygon(poly)  # 计算多边形的中心点

        # 0~26是大类别, 其余是小类别 同时 每个标签只绘制一次
        tc = c - np.array([0.5, 0.5, 0.5])  # 降低颜色
        tc = np.maximum(tc, 0.0)  # 最小值0

        plt.text(c_x, c_y, name, ha='left', wrap=True, color=tc,
                 bbox=dict(facecolor='white', alpha=0.5))  # 绘制标签

        polygons.append(pylab.Polygon(poly))  # 绘制多边形
        color.append(c)

    p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)  # 添加多边形
    ax.add_collection(p)
    p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)  # 添加多边形的框
    ax.add_collection(p)

    plt.axis('off')
    ax.get_xaxis().set_visible(False)  # this removes the ticks and numbers for x axis
    ax.get_yaxis().set_visible(False)  # this removes the ticks and numbers for y axis

    out_folder = os.path.join(ROOT_DIR, 'demo', 'out')
    mkdir_if_not_exist(out_folder)
    out_file = os.path.join(out_folder, 'test.png'.format())
    plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=200)

    plt.close()  # 避免所有图像绘制在一起
复制代码

效果如下:


OK, that's all!

转载于:https://juejin.im/post/5cde298af265da7e603903b0

你可能感兴趣的:(玩转Facebook的maskrcnn-benchmark项目 2)