因为刚接触two-stage表示方法以及实例分割算法,而且正好ODAI项目是个目标检测任务,所以就使用maskRCNN作为baseline。初步思路是将DOTA数据集转化为coco数据集的格式,扔入MaskRCNN中训练,感觉是一个很简单的过程,但是实际上在实践中就遇到了很多问题。
第一步是要把DOTA数据集的格式转化为MaskRCNN能识别的coco数据集格式。下面先看DOTA数据集里的格式示例
主要分为5种类型的数据:
1.imagesource 图片来源
2.gsd 相当于比例尺
3.8个坐标值 表示boundingbox(不使用x,y,w,h表示的原因是这个数据集里的bbox可能是斜的)
4.category 16个分类中的一个
5.Difficulty 是否难以识别
转数据集格式:
{
# coco数据集格式
"info": info, # 可省略
"licenses": [license], # 可省略
"images": [image], # 有前三个就够了,后面置空字符串
#image列表,每个image有file_name,height,width,license,coco_url,date_captured,flickr_url,id
"annotations": [annotation], # 都需要转换
#annotation列表,每个annotation有id,image id,category id,segmentation,area,bbox,iscrowd
"categories": [category] # supercategory可以置空字符串
#categories列表,每个category有id,name,supercategory
}
文末贴出转格式的代码,结构混乱。。求轻喷。。
我记得只需要改几个小地方就能运行了:
model.load_weights(model_path, by_name=True,
exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
"mrcnn_bbox", "mrcnn_mask"])
处理到这里就可以训练一下试试了
6. 好像还有个很玄学的问题就是在model.py里面有这么一段代码,好像会导致在windows下训练的时候训练卡在epoch1不动(多线程死锁)的问题,github上MaskRCNN项目的讨论里有些人把workers改成1,把use_multiprocessing改成False就好了。如果完成前几步后正常训练没问题的话这里就不要改了。
self.keras_model.fit_generator(
train_generator,
initial_epoch=self.epoch,
epochs=epochs,
steps_per_epoch=self.config.STEPS_PER_EPOCH,
callbacks=callbacks,
validation_data=val_generator,
validation_steps=self.config.VALIDATION_STEPS,
workers=workers,
use_multiprocessing=True,
)
贴一下转coco格式的代码(代码结构混乱…命名不清…大家将就着看…):
# -*- coding:utf-8 -*-
import os
import cv2
import json
import pprint
import numpy as np
from PIL import Image
# DOTA-v1.5 category name -> COCO category id (16 classes).
category_dict = {'plane': 0, 'ship': 1, 'storage-tank': 2, 'baseball-diamond': 3, 'tennis-court': 4, 'basketball-court': 5,
'ground-track-field': 6, 'harbor': 7, 'bridge': 8, 'small-vehicle': 9, 'large-vehicle': 10,
'helicopter': 11, 'roundabout': 12, 'soccer-ball-field': 13, 'swimming-pool': 14, 'container-crane': 15}
# Bounding-box statistics accumulators, filled as a side effect of get_anno()
# and printed at the end of the script (width, height, width/height ratio).
w_list=[]
h_list=[]
rate_list=[]
def read_json():
    """Debug helper: load a COCO annotation file and interactively inspect it.

    Prints the full 'annotations' entry, then each top-level key, and waits
    for Enter before returning.
    """
    with open("instances_val2014.json", 'r') as json_file:
        coco_dict = json.load(json_file)
        print(coco_dict['annotations'])
        for top_key in coco_dict:
            print(top_key)
        input()
def extract_seg_RLE(size, ori_seg):
    """Rasterize a 4-point (8-coordinate) polygon and encode it as COCO-style
    RLE run counts.

    Parameters
    ----------
    size : sequence of 2 ints
        Mask dimensions, used as (rows, cols).
        NOTE(review): the (commented-out) caller passed PIL's img.size, which
        is (width, height) -- confirm the intended row/col order before
        re-enabling this path.
    ori_seg : iterable of 8 floats
        Polygon vertices x1, y1, ..., x4, y4.

    Returns
    -------
    list of int
        Run lengths, starting with the count of zeros (0 if the very first
        pixel is inside the polygon, as COCO RLE requires).

    NOTE(review): the mask is flattened in row-major (C) order here, while
    pycocotools expects column-major (Fortran) order -- verify before feeding
    these counts to COCO tools.
    """
    mask = np.zeros((size[0], size[1]), np.uint8)
    # np.asfarray is deprecated (removed in NumPy 2.0); build the array
    # explicitly.  int32 cast truncates toward zero, as the original did.
    pts = np.array(list(ori_seg), dtype=np.float64).reshape(4, 2)
    pts = pts.astype(np.int32).reshape((-1, 1, 2))
    # Fill with 1 directly.  The original filled with 255 and then ran
    # `img /= 255`, which raises TypeError on modern NumPy (in-place true
    # division cannot write float results back into a uint8 array).
    cv2.fillPoly(mask, [pts], 1)
    # Simple run-length encoding.  Starting run_val at 0 with run_len 0
    # guarantees a (possibly zero) leading zero-count; the original stripped
    # the first element unconditionally, which dropped the required leading
    # 0 whenever the mask's first pixel was 1.
    counts = []
    run_val = 0
    run_len = 0
    for pixel in mask.flatten():
        if pixel == run_val:
            run_len += 1
        else:
            counts.append(run_len)
            run_val = pixel
            run_len = 1
    counts.append(run_len)
    return counts
def get_category():
    """Build the COCO 'categories' list from the module-level category_dict.

    Each entry has the keys supercategory (left empty), id, and name, in that
    order.
    """
    return [
        {'supercategory': '', 'id': cat_id, 'name': cat_name}
        for cat_name, cat_id in category_dict.items()
    ]
def get_images(pt):
    """Build the COCO 'images' list for split *pt* ('train' or 'val').

    Reads every image under coco/dataset/<pt>2019/ to obtain its dimensions.
    The numeric image id is parsed from the filename (e.g. P0003.png -> 3).
    """
    path = 'coco/dataset/' + pt + '2019/'
    path_list = os.listdir(path)
    path_list.sort()  # deterministic, sorted traversal order
    images_list = []
    for filename in path_list:
        # Use a context manager: the original left every Image.open handle
        # unclosed, leaking one file descriptor per image.
        with Image.open(path + filename) as img:
            width, height = img.size
        images_list.append({
            'license': 1,
            'file_name': filename,
            'coco_url': '',
            'width': int(width),
            'height': int(height),
            'date_captured': '',
            'flickr_url': '',
            'id': int(filename[1:-4]),  # strip leading 'P' and '.png'
        })
    return images_list
def get_anno(pt):
    """Parse DOTA-v1.5 label files for split *pt* and build the COCO
    'annotations' list.

    Each label file under <pt>/labelTxt-v1.5/DOTA-v1.5_<pt>/ has two header
    lines (imagesource, gsd) followed by one target per line: 8 polygon
    coordinates, a category name, and a difficulty flag.

    Side effect: appends every box's width/height/aspect ratio to the
    module-level w_list / h_list / rate_list used for the statistics printed
    by the main script.  Prints the number of label files processed.
    """
    # (Removed redundant function-local `import os` -- already imported at
    # module level.  Also removed an Image.open() whose result was unused
    # since the RLE segmentation path was commented out; it leaked one file
    # handle per label file.)
    all_annotation_id = 0
    path = pt + '/labelTxt-v1.5/DOTA-v1.5_' + pt + '/'
    path_list = os.listdir(path)
    path_list.sort()  # deterministic, sorted traversal order
    number = 0
    anno_list = []
    for filename in path_list:
        pic_id = int(filename[1:-4])  # e.g. P0003.txt -> image id 3
        with open(os.path.join(path, filename), 'r', encoding='utf-8') as label_file:
            label_file.readline()  # header line 1: imagesource (unused)
            label_file.readline()  # header line 2: gsd (unused)
            line = label_file.readline().strip()
            while line:
                fields = line.split()
                box_x = [float(v) for v in fields[0:8:2]]  # x1, x2, x3, x4
                box_y = [float(v) for v in fields[1:8:2]]  # y1, y2, y3, y4
                box_width = max(box_x) - min(box_x)
                box_height = max(box_y) - min(box_y)
                w_list.append(box_width)
                h_list.append(box_height)
                # NOTE(review): a degenerate zero-height box would raise
                # ZeroDivisionError here -- none observed in DOTA so far.
                rate_list.append(box_width / box_height)
                annotation = {
                    # Polygon segmentation: one list of 8 floats per object.
                    'segmentation': [[float(v) for v in fields[:8]]],
                    'area': box_width * box_height,
                    # COCO bbox format: [x_min, y_min, width, height].
                    'bbox': [min(box_x), min(box_y), box_width, box_height],
                    'iscrowd': 0,
                    'image_id': pic_id,
                    'id': all_annotation_id,
                    'category_id': category_dict[fields[8]],
                }
                all_annotation_id += 1
                anno_list.append(annotation)
                line = label_file.readline().strip()
        number += 1
    print(number)
    return anno_list
def _dump_split(pt, out_name):
    """Assemble the full COCO dict for split *pt* and write it to *out_name*.

    The info/licenses keys are optional and deliberately omitted.
    """
    cats = get_category()
    images = get_images(pt)
    annos = get_anno(pt)
    coco = {'images': images, 'annotations': annos, 'categories': cats}
    # Fix: the original passed a bare open() to json.dump, so the output
    # file handle was never closed (relying on GC to flush it).
    with open(out_name, 'w') as out_file:
        json.dump(coco, out_file)


if __name__ == '__main__':
    _dump_split('val', 'instances_val2019.json')
    # Bounding-box statistics accumulated by get_anno('val').
    print('wmax:', max(w_list))
    print('wmin:', min(w_list))
    print('hmax:', max(h_list))
    print('hmin:', min(h_list))
    print('ratemax:', max(rate_list))
    print('ratemin:', min(rate_list))
    input()  # pause so the stats can be read before the (slow) train split runs
    _dump_split('train', 'instances_train2019.json')