SSD模型:https://github.com/amdegroot/ssd.pytorch
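原因:原代码是基于较旧的 PyTorch 版本编写的,而当前安装的版本较新(高于 1.3 时会出现该问题)。原代码中自定义的 torch.autograd.Function(检测层 Detect)采用旧式写法,forward 是实例方法;新版本 PyTorch 要求 autograd.Function 的 forward 必须是静态方法(@staticmethod),因此需要修改原代码。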
解决方法:参考1,用下述仓库中已更新的代码文件替换原项目中的对应文件:https://gitcode.net/mirrors/sayakbanerjee1999/Single-Shot-Object-Detection-Updated(该地址为 GitHub 镜像,源项目为 https://github.com/sayakbanerjee1999/Single-Shot-Object-Detection-Updated)。
如何使用自己的数据集进行训练:参考2 参考3
修改 ssd.py 中 forward 方法里硬编码的类别数【21】(VOC 的 20 个类别 + 1 个背景类),将其改为【self.num_classes】,使其与自己数据集的类别数一致。(注:这里的 ssd.py 指按问题1的方法替换后的版本。)
原因:部分标注文件(annotation,即 xml 文件)中没有任何目标(object)。另外,如果某个文件中的目标全部被标记为 difficult,它们在加载时会被过滤掉,结果同样是空标注。
解决:使用下述代码,检查哪些xml没有标注目标,然后重新打标签
import argparse
import sys
import cv2
import os
import os.path as osp
import numpy as np
if sys.version_info[0] != 2:
    # Python 3: the standard module picks up the C accelerator by itself
    # (xml.etree.cElementTree was deprecated and later removed).
    import xml.etree.ElementTree as ET
else:
    # Python 2: the C-accelerated implementation is a separate module.
    import xml.etree.cElementTree as ET
# Command-line options. `args` is read by the __main__ driver below.
parser = argparse.ArgumentParser(
    description='Check Pascal VOC annotation files for images without usable objects')
parser.add_argument(
    '--root',
    # The previous default called get_dataset_path(), which is not defined or
    # imported anywhere in this script and raised NameError on load.
    # Default to a path relative to the working directory; use --root to
    # point at the actual dataset.
    default=os.path.join('data', 'VOCdevkit', 'VOC2021'),
    help='Dataset root directory path (must contain Annotations/ and JPEGImages/)')
args = parser.parse_args()
# Dataset class names. A class's label index is its position in this tuple.
CLASSES = (
    'thrips',
    'scarab_beetles',
    'red_spiders',
    'spodoptera_litura_fabricius',
    'epicauta_ruficeps',
)

# Path templates, filled as `template % (dataset_root, file_stem)`.
annopath = osp.join('%s', 'Annotations', '%s.xml')
imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')

# Image ids (dataset_root, file_stem) whose annotation produced no usable
# objects; appended to by voc_checker and printed at the end of the run.
problem = []
def voc_checker(image_id, width, height, keep_difficult=False):
    """Parse one VOC XML annotation and flag it if it yields no objects.

    Each kept <object> becomes ``[xmin, ymin, xmax, ymax, label_idx]``. Every
    coordinate is shifted by -1; x values are divided by ``width`` and y
    values by ``height``. If no object survives (the file has no <object>
    tags, or every object is difficult and ``keep_difficult`` is False),
    ``image_id`` is appended to the module-level ``problem`` list.

    Args:
        image_id: ``(dataset_root, file_stem)`` tuple used to fill the
            two ``%s`` slots of the ``annopath`` template.
        width: Image width in pixels; x coordinates are divided by it.
        height: Image height in pixels; y coordinates are divided by it.
        keep_difficult: If False (default), objects whose <difficult> value
            is 1 are skipped.

    Returns:
        List of ``[xmin, ymin, xmax, ymax, label_idx]`` lists; empty when the
        annotation contains no usable object.

    Raises:
        KeyError: An object's name (lower-cased, stripped) is not in CLASSES.
        AttributeError: An object lacks a <difficult>, <name> or <bndbox> tag.
        OSError / ET.ParseError: The XML file is missing or malformed.
    """
    target = ET.parse(annopath % image_id).getroot()
    # Build the name -> label-index map once per call, not once per object.
    class_to_ind = dict(zip(CLASSES, range(len(CLASSES))))
    res = []
    for obj in target.iter('object'):
        difficult = int(obj.find('difficult').text) == 1
        if not keep_difficult and difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        bndbox = []
        for i, pt in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
            # The -1 presumably converts VOC's 1-based pixel coordinates to
            # 0-based ones -- TODO confirm against the training-side loader.
            cur_pt = int(bbox.find(pt).text) - 1
            # Even positions (xmin, xmax) are x; odd positions (ymin, ymax) are y.
            cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height
            bndbox.append(cur_pt)
        print("name:{}".format(name))
        bndbox.append(class_to_ind[name])
        res.append(bndbox)  # [xmin, ymin, xmax, ymax, label_ind]
    print("res:{}".format(res))
    if res:
        boxes = np.array(res)
        print(boxes[:, 4])
        print(boxes[:, :4])
    else:
        # Empty annotations are exactly what this script looks for. Test for
        # them directly instead of relying on NumPy raising IndexError when a
        # 1-D empty array is indexed with [:, 4].
        problem.append(image_id)
        print("\nINDEX ERROR HERE !\n")
    return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]
if __name__ == '__main__':
    # Scan every annotation under <root>/Annotations and report the ones that
    # cannot be used: no usable objects, missing image, or malformed XML.
    anno_dir = osp.join(args.root, 'Annotations')
    # annopath always appends '.xml', so only files with that exact suffix can
    # be resolved; skip anything else (e.g. .DS_Store, editor backups).
    xml_files = sorted(f for f in os.listdir(anno_dir) if f.endswith('.xml'))
    for xml_file in xml_files:
        # splitext keeps stems that themselves contain dots ("a.b.xml" -> "a.b")
        # intact; split('.')[0] would truncate them to "a".
        stem = osp.splitext(xml_file)[0]
        image_id = (args.root, stem)
        img = cv2.imread(imgpath % image_id)
        if img is None:
            # cv2.imread returns None (it does not raise) for missing/unreadable files.
            print("image missing or unreadable: {}".format(imgpath % image_id))
            problem.append(image_id)
            continue
        height, width = img.shape[:2]
        print("path : {}".format(annopath % image_id))
        try:
            # Pass width before height to match voc_checker(image_id, width, height).
            voc_checker(image_id, width, height)
        except ET.ParseError as err:
            # One malformed file should be reported, not abort the whole scan.
            print("malformed xml {}: {}".format(annopath % image_id, err))
            problem.append(image_id)
    print("Total of annotations : {}".format(len(xml_files)))
    print(problem)
解决:删除 /data/VOCdevkit/annotations_cache/ 目录下的缓存文件 annots.pkl。该文件缓存的是旧的标注内容;删除后,评估时会根据当前的 xml 标注重新生成。