Faster-RCNN系列 二(Datasets代码,python)
数据集部分
数据集类通过继承Dataset来定义,需要自己实现__len__和__getitem__两个方法。
首先定义__init__方法,用于设置数据路径并完成初始化,主要涉及:图像目录、xml标注目录、train.txt/val.txt划分文件,以及类别映射json文件的路径
def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
    """Index the VOC2012 annotation files listed in a split file.

    Args:
        voc_root: Directory that contains the ``VOCdevkit`` folder.
        transforms: Stored as ``self.transforms``; ``__getitem__`` calls it as
            ``transforms(image, target)`` when it is not ``None``.
        txt_name: Name of the split file under ``ImageSets/Main``.

    Raises:
        AssertionError: If the split file is missing, lists no entries, or
            names an annotation file that does not exist.
        SystemExit: If ``./pascal_voc_classes.json`` cannot be opened or
            parsed (the error is printed first).
    """
    self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
    self.img_root = os.path.join(self.root, "JPEGImages")
    self.annotations_root = os.path.join(self.root, "Annotations")

    # Read train.txt or val.txt: every non-blank line, stripped, names one
    # annotation file without its ".xml" suffix.
    txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
    assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
    with open(txt_path) as read:
        # Skip blank lines (e.g. a trailing empty line); otherwise they turn
        # into a bogus "<Annotations>/.xml" entry and fail the check below.
        self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                         for line in read if line.strip()]

    # Fail early, before training starts, if anything listed is missing.
    assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
    for xml_path in self.xml_list:
        assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

    # Load the object-name -> label mapping. The path is relative to the
    # current working directory. ``with`` guarantees the file is closed.
    try:
        with open('./pascal_voc_classes.json', 'r') as json_file:
            self.class_dict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)

    self.transforms = transforms
定义__len__方法,返回数据集的样本数量:xml_list中每个xml文件对应一个样本,直接返回它的长度即可
def __len__(self):
    """Return the number of samples, i.e. the number of indexed XML files."""
    return len(self.xml_list)
定义__getitem__方法,返回图像数据及其标注好的target信息。此处需要用到parse_xml_to_dict方法,它把xml文件中的数据以字典形式存储起来
先取出单张图片对应的xml文件,再用etree读取xml文件中的所有数据
def __getitem__(self, idx):
    """Return the image and detection target for sample ``idx``.

    Args:
        idx: Position in ``self.xml_list``.

    Returns:
        ``(image, target)``. ``target`` is a dict with keys ``boxes``
        (float32, one ``[xmin, ymin, xmax, ymax]`` row per object),
        ``labels`` (int64), ``image_id`` (tensor holding ``idx``),
        ``area`` (box height * width) and ``iscrowd`` (int64, taken from
        each object's ``difficult`` field). Both values are passed through
        ``self.transforms`` when it is not ``None``.

    Raises:
        ValueError: If the opened image is not a JPEG.
        KeyError: If the annotation has no ``object`` entry or an object
            name is missing from ``self.class_dict``.
    """
    # Read the whole XML annotation and turn it into nested dicts.
    xml_path = self.xml_list[idx]
    with open(xml_path) as fid:
        xml_str = fid.read()
    xml = etree.fromstring(xml_str)
    data = self.parse_xml_to_dict(xml)["annotation"]

    img_path = os.path.join(self.img_root, data["filename"])
    image = Image.open(img_path)
    if image.format != "JPEG":
        raise ValueError("Image format not JPEG")

    boxes = []
    labels = []
    iscrowd = []
    for obj in data["object"]:
        bndbox = obj["bndbox"]
        xmin = float(bndbox["xmin"])
        xmax = float(bndbox["xmax"])
        ymin = float(bndbox["ymin"])
        ymax = float(bndbox["ymax"])
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(self.class_dict[obj["name"]])
        iscrowd.append(int(obj["difficult"]))

    # Convert everything into a torch.Tensor.
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])
    # Columns are xmin, ymin, xmax, ymax: (ymax - ymin) * (xmax - xmin).
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    target = {
        "boxes": boxes,
        "labels": labels,
        "image_id": image_id,
        "area": area,
        "iscrowd": iscrowd,
    }

    if self.transforms is not None:
        image, target = self.transforms(image, target)

    return image, target
实现parse_xml_to_dict方法,将xml文件的信息以字典的形式返回
def parse_xml_to_dict(self, xml):
    """Recursively convert an XML element into nested dictionaries.

    Args:
        xml: An element that supports ``len()``, iteration over its
            children, ``.tag`` and ``.text`` (e.g. an lxml element).

    Returns:
        ``{xml.tag: value}``. For an element without children ``value`` is
        its text. Otherwise it is a dict keyed by child tag, where every
        ``object`` child is collected into a list under ``"object"``.
    """
    if len(xml) == 0:  # Leaf element: return the text stored under its tag.
        return {xml.tag: xml.text}

    result = {}
    for child in xml:
        child_result = self.parse_xml_to_dict(child)  # Recurse into the child.
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            # An annotation may hold several objects, so keep them in a list.
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree
class VOC2012DataSet(Dataset):
    """Constructor-only draft of the PASCAL VOC2012 dataset class."""

    # The method must be named __init__; the original "__int__" was never
    # called on construction, so no attribute was ever set.
    def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
        """Index the VOC2012 annotation files listed in a split file.

        Args:
            voc_root: Directory that contains the ``VOCdevkit`` folder.
            transforms: Stored as ``self.transforms`` for later use.
            txt_name: Name of the split file under ``ImageSets/Main``.

        Raises:
            AssertionError: If the split file is missing, lists no entries,
                or names an annotation file that does not exist.
            SystemExit: If ``./pascal_voc_classes.json`` cannot be opened or
                parsed (the error is printed first).
        """
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        # Named "annotations_root" to agree with the other versions of this
        # class in the file.
        self.annotations_root = os.path.join(self.root, "Annotations")

        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
        with open(txt_path) as read:
            # strip() (not split(), which returns a list and cannot be
            # concatenated with ".xml") removes the trailing newline.
            # Blank lines are skipped so they do not become a bogus entry.
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + '.xml')
                             for line in read if line.strip()]

        assert len(self.xml_list) > 0, "in '{}' file does not found any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file".format(xml_path)

        # Same file name as the other versions of this class use; ``with``
        # guarantees the file is closed.
        try:
            with open('./pascal_voc_classes.json', 'r') as json_file:
                self.class_dict = json.load(json_file)
        except Exception as e:
            print(e)
            exit(-1)

        self.transforms = transforms
class VOC2012DataSet(Dataset):
    """Read and parse the PASCAL VOC2012 detection dataset."""

    def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
        """Index the VOC2012 annotation files listed in a split file.

        Args:
            voc_root: Directory that contains the ``VOCdevkit`` folder.
            transforms: Callable invoked as ``transforms(image, target)`` in
                ``__getitem__``; ``None`` disables it.
            txt_name: Name of the split file under ``ImageSets/Main``.

        Raises:
            AssertionError: If the split file is missing, lists no entries,
                or names an annotation file that does not exist.
            SystemExit: If ``./pascal_voc_classes.json`` cannot be opened or
                parsed (the error is printed first).
        """
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # Read train.txt or val.txt: every non-blank line, stripped, names
        # one annotation file without its ".xml" suffix.
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
        with open(txt_path) as read:
            # Skip blank lines (e.g. a trailing empty line); otherwise they
            # turn into a bogus "<Annotations>/.xml" entry.
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                             for line in read if line.strip()]

        # Fail early if anything listed in the split file is missing.
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

        # Load the object-name -> label mapping. The path is relative to the
        # current working directory. ``with`` guarantees the file is closed.
        try:
            with open('./pascal_voc_classes.json', 'r') as json_file:
                self.class_dict = json.load(json_file)
        except Exception as e:
            print(e)
            exit(-1)

        self.transforms = transforms

    def __len__(self):
        """Return the number of samples, i.e. the number of indexed XML files."""
        return len(self.xml_list)

    def _load_annotation(self, idx):
        """Read the XML file of sample ``idx`` and return its ``annotation`` dict."""
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        return self.parse_xml_to_dict(xml)["annotation"]

    def __getitem__(self, idx):
        """Return the image and detection target for sample ``idx``.

        Args:
            idx: Position in ``self.xml_list``.

        Returns:
            ``(image, target)``. ``target`` is a dict with keys ``boxes``
            (float32, one ``[xmin, ymin, xmax, ymax]`` row per object),
            ``labels`` (int64), ``image_id`` (tensor holding ``idx``),
            ``area`` (box height * width) and ``iscrowd`` (int64, taken
            from each object's ``difficult`` field). Both values are passed
            through ``self.transforms`` when it is not ``None``.

        Raises:
            ValueError: If the opened image is not a JPEG.
            KeyError: If the annotation has no ``object`` entry or an object
                name is missing from ``self.class_dict``.
        """
        data = self._load_annotation(idx)

        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image format not JPEG")

        boxes = []
        labels = []
        iscrowd = []
        for obj in data["object"]:
            bndbox = obj["bndbox"]
            xmin = float(bndbox["xmin"])
            xmax = float(bndbox["xmax"])
            ymin = float(bndbox["ymin"])
            ymax = float(bndbox["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            iscrowd.append(int(obj["difficult"]))

        # Convert everything into a torch.Tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        # Columns are xmin, ymin, xmax, ymax: (ymax - ymin) * (xmax - xmin).
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx`` as ints.

        The values come from the ``size`` entry of the XML annotation, so
        the image file itself is not opened.
        """
        data = self._load_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    def parse_xml_to_dict(self, xml):
        """Parse an XML tree into a dictionary.

        Modelled on TensorFlow's ``recursive_parse_xml_to_dict``.

        Args:
            xml: xml tree obtained by parsing XML file contents using
                lxml.etree.

        Returns:
            Python dictionary holding XML contents: ``{xml.tag: value}``,
            where ``value`` is the element text for a leaf, otherwise a dict
            keyed by child tag with every ``object`` child collected into a
            list under ``"object"``.
        """
        if len(xml) == 0:  # Leaf element: return the text stored under its tag.
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # Recurse into the child.
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                # An annotation may hold several objects, so keep them in a list.
                if child.tag not in result:
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}