使用MaskRCNN训练自己的ODAI数据集的思路 遇到的问题及解决方案

1.baseline选择

因为刚接触two-stage表示方法以及实例分割算法,而且正好ODAI项目是个目标检测任务,所以就使用maskRCNN作为baseline。初步思路是将DOTA数据集转化为coco数据集的格式,扔入MaskRCNN中训练,感觉是一个很简单的过程,但是实际上在实践中就遇到了很多问题。

2.数据集格式转化问题

第一步是要把DOTA数据集的格式转化为MaskRCNN能识别的coco数据集格式。下面先看DOTA数据集里的格式示例
使用MaskRCNN训练自己的ODAI数据集的思路 遇到的问题及解决方案_第1张图片
主要分为5种类型的数据:
1.imagesource 图片来源
2.gsd 相当于比例尺
3.8个坐标值 表示boundingbox(不使用x,y,w,h表示的原因是这个数据集里的bbox可能是斜的)
4.category 16个分类中的一个
5.Difficulty 是否难以识别

转数据集格式:

{
    # coco数据集格式
    "info": info, # 可省略
    "licenses": [license], # 可省略
    "images": [image], # 有前三个就够了,后面置空字符串
    #image列表,每个image有file_name,height,width,license,coco_url,date_captured,flickr_urlid
    "annotations": [annotation], # 都需要转换
    #annotation列表,每个annotation有id,image id,category id,segmentation,area,bbox,iscrowd
    "categories": [category] # supercategory可以置空字符串
    #categories列表,每个category有id,name,supercategory
}
  1. 因为ODAI中只有bbox(还是倾斜的),并没有segmentation的概念,所以我干脆将DOTA-bbox扩展为coco-bbox和coco-segmentation两个概念,想法很简单,就是将原始的DOTA-bbox直接看作coco-segmentation,用一个横平竖直的框套住DOTA-bbox作为coco-bbox。category就用一个字典。键值是它的标号,作为coco-category id
  2. coco-segmentation有两种格式,我搞了半天才弄懂。
    [polygen] iscrowd为0时的格式,每个点用xy坐标表示,相当于用好多点把一个segmentation从图上划分出来了
    [RLE] iscrowd为1时的格式,相当于把图片看作一副同样大小的,用0/1表示的坐标矩阵,为1时就是此处有segmentation,就是在图上作掩码。把这个01矩阵展开成一维,按计数的方法表示为RLE格式
    eg: [0,0,0,1,0,0,1,1]->[3个0,1个1,2个0,2个1]->[3,0,1,1,2,0,2,1](好像是这样,我最终没采用这个方法测试)
  3. iscrowd有两种格式,我也搞了半天才弄懂
    iscrowd=1 表示这个segmentation里面有不止一个目标物体 但是分不开(俩人如胶似漆的抱在一起)
    iscrowd=0 表示这个segmentation里面就一个目标物体(一只单身狗)
    (好像就这么简单。。我咋看了那么久才理解呢。。。)

文末贴出转格式的代码,结构混乱。。求轻喷。。

3.把数据集扔到MaskRCNN里训练

我记得只需要改几个小地方就能运行了:

  1. 自己的数据集目录以及命名方式要和coco一致(train2014/instances_train2014.json),图片文件名不一致没关系
  2. log文件的路径,coco预训练好的权重路径自己定义一下(源代码里也有,看需要改)
  3. class CocoConfig里重定义config的一些配置:
    NUM_CLASSES=1+你的数据有多少类
    STEPS_PER_EPOCH=MaskRCNN设的1000,自己测试的时候可以开小一点以便检查错误,真正跑的时候再改回来
    VALIDATION_STEPS计算loss的,开大了容易让每个epoch的最后一步变得贼慢
    BACKBONE据说如果电脑算力不行的话用resnet50(好像),MaskRCNN原始设的"resnet101",按需改吧
  4. 还有就是比较trick的地方是RPN_ANCHOR_SCALES要根据你图片的特征改一下,DOTA数据集里的目标太小了,所以我设的就是(8, 16, 32, 64,128),同样比较trick的就是anchor的比例那个config
  5. 很重要的一点,如果数据集的类别个数和coco的类别个数不一样的时候,使用coco预训练的权重mask_rcnn_coco.h5会报错(忘了什么错了反正运行不了)要指定load_weights的参数,忽略不相容的部分
model.load_weights(model_path, by_name=True,
                           exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                                    "mrcnn_bbox", "mrcnn_mask"])

处理到这里就可以训练一下试试了
6. 好像还有个很玄学的问题就是在model.py里面有这么一段代码,好像会导致在windows下训练的时候训练卡在epoch1不动(多线程死锁)的问题,github上MaskRCNN项目的讨论里有些人把workers改成1,把use_multiprocessing改成False就好了。如果完成前几步后正常训练没问题的话这里就不要改了。

self.keras_model.fit_generator(
            train_generator,
            initial_epoch=self.epoch,
            epochs=epochs,
            steps_per_epoch=self.config.STEPS_PER_EPOCH,
            callbacks=callbacks,
            validation_data=val_generator,
            validation_steps=self.config.VALIDATION_STEPS, 
            workers=workers,
            use_multiprocessing=True,
        )

4.训练过程中的问题

  • 卡在epoch1/40死活动不了。我尝试了3.6说的方法,完全没有改进,在这个问题上愁了整整两天,把git上的类似解决方案(修改上面那个函数,更改keras版本,更改tensorflow版本,调整image里的缩放参数,在windows上跑,在服务器(Ubuntu 16.04.3 LTS,GTX 1080Ti)上跑)试了一个遍都没用。
  • 后来找了做视觉的老师,描述了一下数据集的大小格式处理方式,老师说我送到网络里的图片太大了(1000多*2000多像素
    甚至更大),每张图都有3/4兆,每张图中目标个数平均有400个,有的甚至能达到上千。如果不经过预处理就放上去网络性能变差那是必然的。
  • 于是我想起来ODAI有个工具包DOTA_devkit-master可以对图片进行预处理,发现果然有切割,就改了改代码,适应了不同图片的长宽比以及目标个数进行滑动切分(crop)。
  • 又根据切分出来的结果 改了改转格式的代码转成coco的json格式,满怀期待的扔到服务器上训练
  • 搞!定!啦!

5.现在很棘手的问题(求助)

  1. 使用MaskRCNN的evaluate在经过切分后的validation数据集上测试,速度太慢(500张图跑了快1700秒)
  2. 最可怕的是预测出来AP全都是0…
    使用MaskRCNN训练自己的ODAI数据集的思路 遇到的问题及解决方案_第2张图片
    刚开始学习,有写的不对的地方还望大家多多指正
    如果大佬有解决以上两个问题的思路请速速联系我QAQ

贴一下转coco格式的代码(代码结构混乱…命名不清…大家将就着看…):

# -*- coding:utf-8 -*-
import os
import cv2
import json
import pprint
import numpy as np
from PIL import Image

category_dict = {'plane': 0, 'ship': 1, 'storage-tank': 2, 'baseball-diamond': 3, 'tennis-court': 4, 'basketball-court': 5,
            'ground-track-field': 6, 'harbor': 7, 'bridge': 8, 'small-vehicle': 9, 'large-vehicle': 10,
            'helicopter': 11, 'roundabout': 12, 'soccer-ball-field': 13, 'swimming-pool': 14, 'container-crane': 15}

w_list=[]
h_list=[]
rate_list=[]
def read_json():
    with open("instances_val2014.json", 'r') as load_f:
        load_dict = json.load(load_f)
        print(load_dict['annotations'])
        for i in load_dict:
            print(i)
            input()


def extract_seg_RLE(size, ori_seg):  # 8个点的坐标
    img = np.zeros((size[0], size[1]), np.uint8)
    ori_seg = np.asfarray(ori_seg).reshape(4, 2)
    pts = np.array([ori_seg], np.int32)
    pts = pts.reshape((-1, 1, 2))
    # print(pts)
    cv2.fillPoly(img, [pts], 255)  # 为什么前几个1k以上的坐标数据填充不上???
    # cv2.imshow('line', img)
    # cv2.waitKey()
    img /= 255
    img = list(img.flatten())
    img.append(2)
    rle_out = []
    count0 = 0
    flag0 = False
    flag1 = False
    count1 = 0
    for i in img:
        if i == 0 and flag0 == True:
            count0 += 1
        elif i == 0 and (flag1 == True or flag0 == False):
            rle_out.append(count1)
            count1 = 0
            flag1 = False
            flag0 = True
            count0 += 1
        elif i == 1 and flag1 == True:
            count1 += 1
        elif i == 1 and (flag0 == True or flag1 == False):
            rle_out.append(count0)
            count0 = 0
            flag1 = True
            flag0 = False
            count1 += 1
        elif i == 2:
            if count0 > 0:
                rle_out.append(count0)
            else:
                rle_out.append(count1)
    rle_out = rle_out[1:]
    return rle_out


def get_category():
    cate = []
    for i in category_dict:
        c = {}
        c.update(supercategory='')
        c.update(id=category_dict[i])
        c.update(name=i)
        cate.append(c)
    return cate


def get_images(pt):
    path = 'coco/dataset/'+pt+'2019/'
    path_list = os.listdir(path)
    path_list.sort()  # 对读取的路径进行排序
    images_list = []
    for filename in path_list:
        lis = {}
        img = Image.open(path + filename)
        lis.update(license=1)
        lis.update(file_name=filename)
        lis.update(coco_url='')
        lis.update(width=int(img.size[0]))
        lis.update(height=int(img.size[1]))
        lis.update(date_captured='')
        lis.update(flickr_url='')
        lis.update(id=int(filename[1:-4]))
        images_list.append(lis)
    return images_list


def get_anno(pt):
    import os
    all_annotation_id = 0
    #path = pt+"/labelTxt-v1.5/DOTA-v1.5_"+pt  # 待读取的文件夹
    path=pt+'/labelTxt-v1.5/DOTA-v1.5_'+pt+'/'
    pic_path = 'coco/dataset/'+pt+'2019/'
    path_list = os.listdir(path)
    path_list.sort()  # 对读取的路径进行排序
    number = 0
    anno_list = []
    for filename in path_list:
        pic_id = int(filename[1:-4])
        # print(filename)
        with open(os.path.join(path, filename), 'r',encoding='utf-8') as file_to_read:  # 一张图片中所有target的描述
            image_source = file_to_read.readline().strip()  # image source 第一行
            gsd = file_to_read.readline().strip()  # gsd 第二行
            lines = file_to_read.readline().strip()  # target物体
            pic_name = filename[:-4] + '.png'
            img = Image.open(pic_path + pic_name)

            while lines:
                lines = lines.split()
                box_x = list(map(float, [lines[0], lines[2], lines[4], lines[6]]))
                box_y = list(map(float, [lines[1], lines[3], lines[5], lines[7]]))
                box_height = float(max(box_y)) - float(min(box_y))
                box_width = float(max(box_x)) - float(min(box_x))
                w_list.append(box_width)
                h_list.append(box_height)
                rate_list.append(box_width/box_height)
                # print('h:',box_height,' w:',box_width,' w/h:',box_width/box_height)
                box_area = box_width * box_height
                help = {}
                # seg = {}
                # seg.update(counts=extract_seg_RLE(img.size, map(float, lines[:8])))
                # seg.update(size=img.size)
                # help.update(segmentation=seg)
                help.update(segmentation=[list(map(float, lines[:8]))])
                help.update(area=box_area)
                help.update(bbox=[min(box_x), min(box_y), box_width, box_height])
                help.update(iscrowd=0)
                help.update(image_id=pic_id)
                help.update(id=all_annotation_id)
                category_id = category_dict[lines[8]]
                help.update(category_id=category_id)
                all_annotation_id += 1
                lines = file_to_read.readline().strip()  # 整行读取数据
                anno_list.append(help)
        number += 1
        print(number)
    return anno_list
	
if __name__ == '__main__':
	cat=get_category()
	images=get_images('val')
	anno=get_anno('val')
	coco={}
	# info licenses
	coco.update(images=images)
	coco.update(annotations=anno)
	coco.update(categories=cat)
	json.dump(coco, open('instances_val2019' + ".json", 'w'))

	print('wmax:',max(w_list))
	print('wmin:',min(w_list))
	print('hmax:',max(h_list))
	print('hmin:',min(h_list))
	print('ratemax:',max(rate_list))
	print('ratemin:',min(rate_list))
	input()
	cat2=get_category()
	images2=get_images('train')
	anno2=get_anno('train')
	coco2={}
	# info licenses
	coco2.update(images=images2)
	coco2.update(annotations=anno2)
	coco2.update(categories=cat2)
	json.dump(coco2, open('instances_train2019' + ".json", 'w'))

你可能感兴趣的:(实习,目标检测,MaskRCNN,Objection,Dection,of,Arial,Images,目标检测)