代码链接:
https://github.com/xingyizhou/CenterNet/tree/master/readme
这个工程的环境配置起来很费劲。按照教程配好环境(Python 3.6、PyTorch 0.4.1)后,运行时报错:
ImportError:/home/shiep/CenterNet/src/lib/models/networks/DCNs/_ext/dcn_v2/dcn_v2.so: undefined symbol: __cudaRegisterFatBinaryEnd
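原因:`__cudaRegisterFatBinaryEnd` 这个符号是 CUDA 10.1 才引入的。出现这个错误,通常说明编译 DCNv2 时用的 CUDA(nvcc)版本高于 PyTorch 0.4.1 所用的 CUDA 运行时版本(网上常说是"驱动/CUDA 版本太高")。重新安装驱动和 CUDA 太麻烦,于是我找到了另一个工程: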
https://github.com/shenyi0220/centernet-cp-cluster
这个在我之前配好的环境里正常运行。
一、处理数据集
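该工程使用 COCO 格式的数据集,因此需要把标注转换成 COCO 的 json 文件。注意:下面的脚本读取的是 YOLO 格式的 txt 标注(每行为"类别索引 x_center y_center w h",坐标按图像宽高归一化),而不是 VOC 的 xml。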
# coding:utf-8
# 运行前请先做以下工作:
# pip install opencv-python numpy
# 将 YOLO 格式的 txt 标注文件放到 xml_dir 指定的文件夹下,并将对应的图片(.jpg/.png/.jpeg)放到 img_path 指定的文件夹下
#
import os
import glob
import json
import shutil
import cv2
import numpy as np
import xml.etree.ElementTree as ET
# Id assigned to the first annotation that convert() writes; it counts up
# by one for every further box.
START_BOUNDING_BOX_ID = 1

# Class names, indexed by the integer class id in column 0 of a label line.
names = [
    "aa",
    "bb",
    "cc",
]

# Directory containing the images (keep the trailing "/").
img_path = "images/"

# Not referenced by the code shown in this script.
save_path = "."
txt_path = "labels/"
def get(root, name):
    """Return every element under ``root`` that matches the path ``name``.

    Thin wrapper around ``Element.findall``; an empty list means no match.
    """
    matches = root.findall(name)
    return matches
def get_and_check(root, name, length):
    """Find the elements matching ``name`` under ``root`` and validate the count.

    Args:
        root: element to search (anything providing ``findall`` and ``tag``).
        name: ElementTree path passed to ``findall``.
        length: expected number of matches; a falsy value (0/None) disables
            the count check.

    Returns:
        The single matching element when ``length == 1``, otherwise the list
        of matches.

    Raises:
        NotImplementedError: if nothing matches, or the number of matches
            differs from ``length``. (The exception type is kept as-is for
            callers that may catch it.)
    """
    # Named ``found`` rather than ``vars`` so the builtin is not shadowed.
    found = root.findall(name)
    if not found:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length and len(found) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(found)))
    if length == 1:
        return found[0]
    return found
def _find_image(stem):
    """Locate the image belonging to a label file and return its size.

    Tries ``stem + ext`` for ``.jpg``, ``.png`` and ``.jpeg`` (in that order)
    inside the module-level ``img_path`` directory.

    Args:
        stem: label file name without directory and extension.

    Returns:
        ``(filename, height, width)`` of the first candidate that can be read.

    Raises:
        FileNotFoundError: if none of the candidate files can be read.
    """
    for ext in (".jpg", ".png", ".jpeg"):
        filename = stem + ext
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising, so the result must be checked explicitly.
        img = cv2.imread(os.path.join(img_path, filename))
        if img is not None:
            height, width = img.shape[:2]
            return filename, height, width
    raise FileNotFoundError(
        "No .jpg/.png/.jpeg image named '%s' found in '%s'" % (stem, img_path))


def convert(xml_list, json_file, category_map=None):
    """Convert YOLO-format label files into a single COCO-style JSON file.

    Despite the parameter name, ``xml_list`` holds ``.txt`` label files in
    YOLO format: one object per line, ``class_idx x_center y_center w h``,
    with coordinates normalised by the image width/height. ``class_idx``
    indexes the module-level ``names`` list.

    Args:
        xml_list: iterable of label-file paths. The matching image is looked
            up in ``img_path`` by file stem.
        json_file: path of the JSON file to write.
        category_map: ``{class_name: category_id}``. Defaults to the
            module-level ``pre_define_categories`` built in ``__main__``.
            Objects whose class is missing from the map are counted in the
            printed summary but are not written as annotations.

    Raises:
        FileNotFoundError: if no readable image exists for a label file.
    """
    if category_map is None:
        category_map = pre_define_categories
    categories = dict(category_map)
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    for index, label_file in enumerate(xml_list):
        print("Processing %s" % (label_file))
        stem = os.path.splitext(os.path.basename(label_file))[0]
        # Ids restart on every call, so they are unique only within one file.
        image_id = 20190000001 + index
        filename, height, width = _find_image(stem)
        json_dict['images'].append(
            {'file_name': filename, 'height': height, 'width': width, 'id': image_id})
        with open(label_file) as label_fp:
            for line in label_fp:
                fields = line.split()
                if not fields:  # tolerate blank / trailing empty lines
                    continue
                category = names[int(fields[0])]
                all_categories[category] = all_categories.get(category, 0) + 1
                if category not in categories:
                    continue
                # Use the id published in json_dict["categories"]; the raw
                # 0-based class index would be off by one relative to it.
                category_id = categories[category]
                x_c, y_c, box_w, box_h = (float(v) for v in fields[1:5])
                # Denormalise centre/size into pixel corner coordinates.
                xmin = int(x_c * width - box_w * width / 2)
                ymin = int(y_c * height - box_h * height / 2)
                xmax = int(x_c * width + box_w * width / 2)
                ymax = int(y_c * height + box_h * height / 2)
                if xmax < xmin or ymax < ymin:
                    continue
                o_width = xmax - xmin
                o_height = ymax - ymin
                # COCO bbox format is [x, y, width, height].
                json_dict['annotations'].append({
                    'area': o_width * o_height, 'iscrowd': 0, 'image_id': image_id,
                    'bbox': [xmin, ymin, o_width, o_height],
                    'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                    'segmentation': []})
                bnd_id += 1
    for cate, cid in categories.items():
        json_dict['categories'].append({'supercategory': 'ball', 'id': cid, 'name': cate})
    with open(json_file, 'w') as json_fp:
        json_fp.write(json.dumps(json_dict))
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(
        len(all_categories), all_categories.keys(), len(category_map), category_map.keys()))
    print("category: id --> {}".format(categories))
    print(categories.keys())
    print(categories.values())
if __name__ == '__main__':
    # Class name -> 1-based category id, in the order given by ``names``.
    pre_define_categories = {cls: i + 1 for i, cls in enumerate(names)}
    # Custom ids are also possible: replace the line above with e.g.
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}

    # Output JSON files.
    save_json_train = 'train_ship.json'
    save_json_val = 'val_ship.json'
    save_json_test = 'test_ship.json'

    # Directory holding the YOLO-format .txt label files; adjust for your machine.
    xml_dir = "/media/zhanglu/0bb0d537-0b35-45bf-94ec-9def4a6dd599/zhanglu/yolov5-fishi/data/ship/data_0519/labels"
    xml_list = glob.glob(os.path.join(xml_dir, "*.txt"))
    if not xml_list:
        # Fail loudly instead of silently writing three empty JSON files.
        raise SystemExit("No .txt label files found in %s" % xml_dir)
    # Sort before the seeded shuffle so the split is reproducible regardless
    # of the order in which glob() returns files.
    xml_list = np.sort(xml_list)
    np.random.seed(100)
    np.random.shuffle(xml_list)

    # Split the shuffled list 80% / 10% / 10% into train / val / test.
    train_ratio = 0.8
    val_ratio = 0.1
    train_num = int(len(xml_list) * train_ratio)
    val_num = int(len(xml_list) * val_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num:train_num + val_num]
    xml_list_test = xml_list[train_num + val_num:]

    # Write one COCO-style JSON file per split (train/val/test).
    convert(xml_list_train, save_json_train)
    convert(xml_list_val, save_json_val)
    convert(xml_list_test, save_json_test)
    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))
    print("test number:", len(xml_list_test))
二、训练自己的数据集
参考:
https://blog.csdn.net/qq_41613251/article/details/114446107
(1)在 src/lib/datasets/dataset/ 下复制 coco.py,新建一个 c.py 文件,需要修改以下几处:
13行改为class c(data.Dataset): c为自己的类名
14行改为自己的类别数(不包括背景类)
15行改为自己的输入图像大小
16-19行可改为自己的均值方差,也可默认,求均值方差的脚本
22行修改为super(c,self)
23行改为 self.data_dir = os.path.join(opt.data_dir, 'c')
24行把 '{}2017' 改为 'images'
25行把 'test' 改为 'val',因为只转换了 train.json 和 test.json,没有 val.json
28行把 'instances_extreme_{}2017.json' 改为 'test.json'
37行把 'instances_{}2017.json' 改为 'train.json'
39行改为自己的类别 ['background', 'c']
41行self._valid_ids = [1]
(2)修改src/lib/datasets/dataset_factory.py
22行添加自己的类别 'c': c
(3)修改 src/lib/opts.py
这两处必须改
15 self.parser.add_argument('--dataset', default='c',
16 help='coco | kitti | coco_hp | pascal | c')
338 'ctdet': {'default_resolution': [320, 320], 'num_classes': 1,
339 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],
340 'dataset': 'c'}
下面的学习率、backbone 等参数可以按需修改,也可以保持默认。
(4)修改src/lib/utils/debugger.py
48 elif num_classes == 1 or dataset == 'c':
49 self.names = c_class_name
444 c_class_name = ["c"]
使用 soft-nms 时报错:在 external 目录下编译完成后,仍然找不到该模块:
from external.nms import soft_nms
ModuleNotFoundError: No module named 'external.nms'
解决办法:
把编译生成的 nms.cpython-36m-x86_64-linux-gnu.so 重命名为 nms.so。
三、nms和softnms的区别:
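两者的区别:NMS 会直接删除与当前最高分框 IoU 超过阈值的框;Soft-NMS 不删除这些框,而是根据 IoU(线性或高斯方式)降低它们的分数,因此相互重叠的真实目标不容易被误删。普通 NMS 可以直接调用 torchvision 的接口(torchvision.ops.nms):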
import torch
import torchvision
# Four candidate boxes in (x1, y1, x2, y2) corner format, one score each.
box = torch.tensor([
    [2, 3.1, 7, 5],
    [3, 4, 8, 4.8],
    [4, 4, 5.6, 7],
    [0.1, 0, 8, 1],
])
score = torch.tensor([0.5, 0.3, 0.2, 0.4])

# Indices of the boxes that survive NMS at IoU threshold 0.3.
output = torchvision.ops.nms(box, score, 0.3)

# Pairwise IoU matrix, printed so the suppression can be checked by hand.
iou = torchvision.ops.box_iou(box, box)
print('IOU of bboxes:')
print(iou)
print(output)
四、后处理解析
# Run the detector on one image and keep its post-processed detections,
# keyed by image id. (Plain ASCII quotes: the curly “results” quotes in the
# original are a SyntaxError in Python.)
ret = detector.run(img_name)
results[img_id] = ret["results"]
ret["results"] 是一个包含 9 个键的 dict(每个类别一个键),每个值是形状为 (N, 5) 的数组,在 ctdet 中每行为 [x1, y1, x2, y2, score]。各个值的形状如下:
(9,5)
(18,5)
(5,5)
(30,5)
(16,5)
(4,5)
(10,5)
(4,5)
(4,5)