本文以 Pascal VOC 2012 数据集为例,讲解如何自定义一个可以用于目标检测的数据集
参考 Pytorch 官方提供的样例:Tutorial
训练对象检测、实例分割和人物关键点检测的参考脚本可以轻松支持添加新的自定义数据集。数据集应该继承于标准的 torch.utils.data.Dataset 类,并实现 __len__
和 __getitem__
方法
__len__
:图片的数量__getitem__
:图片及其对应的信息get_height_and_width
:获取图像高度和宽度的方法
定义一个继承 Dataset 的类
from torch.utils.data import Dataset
class VOC2012DataSet(Dataset):
def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
voc_root
:训练集所在的根目录transforms
:图像预处理txt_name
:在后面判断是调用训练集数据还是验证集数据 assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# 增加容错能力
if "VOCdevkit" in voc_root:
self.root = os.path.join(voc_root, f"VOC{year}")
else:
self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
通过 os.path.join
方法将各个目录与根目录拼接到一起(os.path
可用理解为是路径中的斜杠,可以根据不同的操作系统适应不同方向的斜杠)
# read train.txt or val.txt file
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
with open(txt_path) as read:
xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines() if len(line.strip()) > 0]
根据传入的信息判断对应读取的是训练集还是验证集,并通过 os.path.join
构建对应的文件路径
其中train.txt
和 val.txt
中包含的对应的文件的名称,且每一行最后都有一个换行符。通过 for
循环遍历所有文件名,使用 line.strip()
方法,去掉最后的换行符,并给每一个文件名后面添加 .xml
的后缀,得到每一个文件对应的 xml 文件,并将所有 xml 文件保存到一个列表中
# read class_indict
json_file = '/od/model/faster_rcnn/pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
self.class_dict = json.load(f)
self.transforms = transforms
载入之前设定好的类别名称及其对应的索引信息,并存给 self.class_dict
找个变量
json 文件内容如下所示:
{
"aeroplane": 1,
"bicycle": 2,
"bird": 3,
"boat": 4,
"bottle": 5,
"bus": 6,
"car": 7,
"cat": 8,
"chair": 9,
"cow": 10,
"diningtable": 11,
"dog": 12,
"horse": 13,
"motorbike": 14,
"person": 15,
"pottedplant": 16,
"sheep": 17,
"sofa": 18,
"train": 19,
"tvmonitor": 20
}
def __len__(self):
return len(self.xml_list)
通过 len()
方法获取文件的个数,变量 xml_list 中存储着所有的 .xml 文件,而一个 .xml 文件就对应着一张图片,所以 len()
方法返回的就是数据集的个数
def __getitem__(self, idx):
idx
:索引值,通过索引值返回图片及图片对应的信息 # read xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
img_path = os.path.join(self.img_root, data["filename"])
image = Image.open(img_path)
if image.format != "JPEG":
raise ValueError("Image '{}' format not JPEG".format(img_path))
找到 idx 对应的 xml 文件并打开,通过 etree.fromstring()
读取 xml 文件并将 xml 文件中的信息传入方法 parse_xml_to_dict
(将 xml 文件解析成字典形式进行存储)
在使用之前需要判断图片文件的类型是否是 .jpeg
格式(如果使用的是 VOC 数据集的话将不会有什么影响)
boxes = []
labels = []
iscrowd = []
assert "object" in data, "{} lack of object information.".format(xml_path)
for obj in data["object"]:
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
continue
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
if "difficult" in obj:
iscrowd.append(int(obj["difficult"]))
else:
iscrowd.append(0)
boxes
:保存每一个目标的 bounding box 的信息
labels
:存储对应目标的索引值(.json
文件中设定的)
iscrowd
:存在于 COCO 数据集中(判断目标是否重合)
# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
将 xml 文件解析成字典形式,参考 tensorflow 的 recursive_parse_xml_to_dict
简单来说,xml 就是嵌套的字典
def parse_xml_to_dict(self, xml):
if len(xml) == 0: # 遍历到底层,直接返回 tag 对应的信息
return {xml.tag: xml.text}
result = {]
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为 object 可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
深度优先遍历
获取图像的高度和宽度
def get_height_and_width(self, idx):
# read xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
data_height = int(data["size"]["height"])
data_width = int(data["size"]["width"])
return data_height, data_width
通过以上文件代码,我们可以在训练文件中,可以通过 torch.utils.data.DataLoader
传入数据并进行训练