介绍写一个属于自己的DataSet,通过这个数据集我们就可以去读取PASCAL VOC的数据,或者读取自己数据集的数据。
其实这个比较简单,通过运行脚本split_data.py
可以生成train.txt和val.txt
import os
import random
def main():
    """Split the annotation files into a training list and a validation list.

    The base names (final extension removed) of the entries in ``files_path``
    are collected and sorted, a fraction ``val_rate`` of them is drawn at
    random for validation, and the two name lists are written to ``train.txt``
    and ``val.txt`` in the current working directory, one name per line.

    The script stops with exit status 1 when an output file already exists,
    so an earlier split is never overwritten.
    """
    random.seed(0)  # fixed seed so the random split is reproducible

    # Root directory of the .xml annotation files (or of the saved .jpg images).
    files_path = "./VOCdevkit/VOC2012/Annotations"
    assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path)

    # Fraction of the samples assigned to the validation set.
    val_rate = 0.5

    # splitext strips only the last extension, so names that contain extra
    # dots are kept whole; hidden entries (e.g. ".DS_Store") are not samples.
    files_name = sorted(os.path.splitext(file)[0]
                        for file in os.listdir(files_path)
                        if not file.startswith("."))
    files_num = len(files_name)

    # A set gives O(1) membership tests while partitioning below.
    val_index = set(random.sample(range(0, files_num), k=int(files_num * val_rate)))

    train_files = [name for index, name in enumerate(files_name) if index not in val_index]
    val_files = [name for index, name in enumerate(files_name) if index in val_index]

    try:
        # Mode "x" refuses to overwrite; the with-statement closes both files.
        with open("train.txt", "x") as train_f, open("val.txt", "x") as eval_f:
            train_f.write("\n".join(train_files))
            eval_f.write("\n".join(val_files))
    except FileExistsError as e:
        print(e)
        raise SystemExit(1)


if __name__ == '__main__':
    main()
其中,files_path
为.xml格式的标注文件的根目录或者保存的.jpg图片的根目录;
对于标注文件夹Annotations
下包含了所有.xml标注文件,格式如下
运行split_data.py
文件,就可以生成划分好的train.txt
和val.txt
代码参考pytorch官方文档
自定义数据集需要继承标准的torch.utils.data.Dataset
类,并且需要实现两个方法:
__len__
和__getitem__
,通过__len__
方法可以获取数据样本的数量,通过__getitem__
方法可以返回图片和其所对应的信息。
后面也提到一个获取图片宽度和高度的方法get_height_and_width
,在调用多GPU训练的脚本时需要用到这个方法;如果不实现这个方法,就会载入所有图片去计算它们的高和宽,这样比较耗时和占内存。因此我们需要提前实现get_height_and_width
,这样就不会去遍历整个数据集的图片了。
首先需要继承torch.utils
中的Dataset类
class VOCDataSet(Dataset):
__init__
方法初始化继承Dataset类,并进行初始化。在初始化 __init__
方法中,传入了四个参数,分别为voc_root
、year
、transforms
、txt_name
。
其中:
voc_root
表示数据集所在的根目录,假设数据集(VOCdevkit文件夹)放在项目的根目录,则voc_root="./"
;year
为数据集的年份,取值为"2007"或"2012";transforms
为对图片及其标注信息进行的预处理;txt_name
为划分文件的文件名,如果txt_name="train.txt"
则会返回train.txt
对应的所有数据,如果为"val.txt"
则对应为验证集,返回val.txt
对应的所有数据。
初始化方法中,获取了图片、标注文件的根目录:self.img_root
和self.annotations_root
,并获取返回数据集的xml文件列表self.xml_list
。
代码
以VOC数据集为例,对于自定义数据集,只需要修改对应的代码即可
class VOCDataSet(Dataset):
    """Read and parse the PASCAL VOC2007/2012 dataset."""

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        """Collect the usable annotation files and load the class mapping.

        Args:
            voc_root: directory that contains ``VOCdevkit``, or a path that
                already includes ``VOCdevkit``.
            year: dataset year, "2007" or "2012".
            transforms: stored on ``self.transforms``; may be None.
            txt_name: split file under ``ImageSets/Main`` listing the sample
                names, one per line.

        Raises:
            AssertionError: if ``year`` is invalid, the split file or
                ``./pascal_voc_classes.json`` is missing, or no annotation
                file with objects is left.
        """
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # Be tolerant: accept a root with or without the VOCdevkit component.
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        # check file
        for xml_path in xml_list:
            if not os.path.exists(xml_path):
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            # check for targets
            # Read bytes so the XML parser handles the file's own encoding
            # instead of depending on the platform's default text encoding.
            with open(xml_path, "rb") as fid:
                xml = etree.fromstring(fid.read())
            data = self.parse_xml_to_dict(xml)["annotation"]
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # read class_indict
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms
其中./pascal_voc_classes.json
为类别信息,对应pascal voc对应的类别信息如下:
{
"aeroplane": 1,
"bicycle": 2,
"bird": 3,
"boat": 4,
"bottle": 5,
"bus": 6,
"car": 7,
"cat": 8,
"chair": 9,
"cow": 10,
"diningtable": 11,
"dog": 12,
"horse": 13,
"motorbike": 14,
"person": 15,
"pottedplant": 16,
"sheep": 17,
"sofa": 18,
"train": 19,
"tvmonitor": 20
}
自定义数据集,除了基础Dataset标准类,还需要实现__len__
和__getitem__
方法,我这里还额外实现了get_height_and_width
.
__len__
方法代码
def __len__(self):
return len(self.xml_list)
__getitem__
方法__getitem__
方法需要传入索引值idx
参数,根据idx
返回索引对应的图片及图片信息。
流程:
根据idx
获取对应的xml文件,并将其解析为字典,解析方法参考tensorflow官方文档的recursive_parse_xml_to_dict
def __getitem__(self, idx):
    """Return the image and the annotation target for sample ``idx``.

    Args:
        idx: index into ``self.xml_list``.

    Returns:
        ``(image, target)``. ``target`` is a dict with the keys ``boxes``
        (float32, one ``[xmin, ymin, xmax, ymax]`` row per kept object),
        ``labels`` (int64), ``image_id``, ``area`` and ``iscrowd`` (int64,
        taken from the annotation's ``difficult`` field, 0 when absent).
        If ``self.transforms`` is set, both values are passed through it.

    Raises:
        ValueError: if the image file is not in JPEG format.
        AssertionError: if the annotation has no ``object`` entry.
    """
    # read xml
    xml_path = self.xml_list[idx]
    # Read bytes so the XML parser handles the file's own encoding instead
    # of depending on the platform's default text encoding.
    with open(xml_path, "rb") as fid:
        xml = etree.fromstring(fid.read())
    data = self.parse_xml_to_dict(xml)["annotation"]
    img_path = os.path.join(self.img_root, data["filename"])
    image = Image.open(img_path)
    if image.format != "JPEG":
        raise ValueError("Image '{}' format not JPEG".format(img_path))

    boxes = []
    labels = []
    iscrowd = []
    assert "object" in data, "{} lack of object information.".format(xml_path)
    for obj in data["object"]:
        xmin = float(obj["bndbox"]["xmin"])
        xmax = float(obj["bndbox"]["xmax"])
        ymin = float(obj["bndbox"]["ymin"])
        ymax = float(obj["bndbox"]["ymax"])

        # Some annotations contain boxes whose width or height is 0; such
        # boxes would make the regression loss nan, so they are skipped.
        if xmax <= xmin or ymax <= ymin:
            print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
            continue

        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(self.class_dict[obj["name"]])
        if "difficult" in obj:
            iscrowd.append(int(obj["difficult"]))
        else:
            iscrowd.append(0)

    # convert everything into a torch.Tensor
    # reshape keeps the boxes two-dimensional even when every box was
    # skipped, so the column indexing below cannot fail on an empty list.
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    target = {}
    target["boxes"] = boxes
    target["labels"] = labels
    target["image_id"] = image_id
    target["area"] = area
    target["iscrowd"] = iscrowd

    if self.transforms is not None:
        image, target = self.transforms(image, target)

    return image, target
__getitem__
方法,返回对应的image以及图片信息target
,其中target
为一个字典,包含目标的边界框boxes、目标的类别标签labels
,图片的id image_id
,各目标边界框的面积area
以及目标的识别难易程度iscrowd(取自标注中的difficult字段)
信息。
标注好的xml文件如下
<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000042.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>500</width>
        <height>335</height>
        <depth>3</depth>
    </size>
    <segmented>1</segmented>
    <object>
        <name>train</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>263</xmin>
            <ymin>32</ymin>
            <xmax>500</xmax>
            <ymax>295</ymax>
        </bndbox>
    </object>
    <object>
        <name>train</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>1</xmin>
            <ymin>36</ymin>
            <xmax>235</xmax>
            <ymax>299</ymax>
        </bndbox>
    </object>
</annotation>
其中利用parse_xml_to_dict
函数将xml文件解析为字典dict
,参考tensorflow官方文档的recursive_parse_xml_to_dict
def parse_xml_to_dict(self, xml):
"""
将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
Args:
xml: xml tree obtained by parsing XML file contents using lxml.etree
Returns:
Python dictionary holding XML contents.
"""
if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
如果transforms is not None
,则需要对图片image
进行变换,同样图片的信息target
也会跟着变换。对应代码在transforms.py
文件中,主要调用了ToTensor
和RandomHorizontalFlip
类,代码如下:
import random
from torchvision.transforms import functional as F
class Compose(object):
    """Chain several transform callables into a single callable."""

    def __init__(self, transforms):
        """Store the sequence of transforms to be applied in order."""
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed ``(image, target)`` through every transform in turn.

        Each transform is called as ``transform(image, target)`` and must
        return the updated pair.
        """
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
class ToTensor(object):
    """Convert a PIL image into a Tensor."""

    def __call__(self, image, target):
        """Return the image converted by ``F.to_tensor`` with ``target`` unchanged."""
        tensor_image = F.to_tensor(image)
        return tensor_image, target
class RandomHorizontalFlip(object):
    """Randomly flip the image and its bounding boxes horizontally."""

    def __init__(self, prob=0.5):
        """Store ``prob``, the threshold the random draw is compared against."""
        self.prob = prob

    def __call__(self, image, target):
        """Flip ``image`` and ``target["boxes"]`` when the random draw is below ``prob``."""
        # Guard clause: leave the sample untouched when the draw is not below prob.
        if not random.random() < self.prob:
            return image, target

        _, img_width = image.shape[-2:]
        image = image.flip(-1)  # mirror along the last (width) axis
        boxes = target["boxes"]
        # boxes: xmin, ymin, xmax, ymax
        # The mirrored xmin comes from the old xmax and vice versa.
        boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return image, target
其中ToTensor
类比较简单,直接用pytorch提供的方法进行转换。我这边主要讲下随机翻转RandomHorizontalFlip
。
在调用该变换时,会输入image
和图片信息target
,并生成一个随机数,如果该随机数小于设置的概率(如0.5)则对图像进行翻转。其中翻转后,bounding box
的坐标信息也需要跟着翻转,变换如下:
# bbox: xmin, ymin, xmax, ymax
bbox[:, [0, 2]] = width - bbox[:, [2, 0]] # 翻转对应bbox坐标信息
target["boxes"] = bbox
水平翻转即图片在水平方向进行调换,经过翻转后原始bounding box左上角的 $(X_{min},Y_{min})$ 就跑到它的右上角去了,原始bounding box右下角的 $(X_{max},Y_{max})$ 就跑到它的左下角这个地方。
对于翻转之后新的左上角点 $(X_{左上}',Y_{左上}')$,它的 $Y_{左上}'$ 值和变换前是一样的,变的只是 $X_{左上}'$,此时 $X_{左上}'=width-X_{max}$。
同理,对于翻转之后新的右下角点 $(X_{右下}',Y_{右下}')$,它的 $Y_{右下}'$ 值和变换前是一样的,变的只是 $X_{右下}'$,此时 $X_{右下}'=width-X_{min}$。
因此有如下的变换代码
bbox[:, [0, 2]] = width - bbox[:, [2, 0]] # 翻转对应bbox坐标信息
预处理之后就可以返回图像image
和图像信息target
了
get_height_and_width
def get_height_and_width(self, idx):
    """Return ``(height, width)`` of sample ``idx`` as recorded in its annotation.

    The values come from the ``size`` element of the XML file, so the
    image itself is not opened.

    Args:
        idx: index into ``self.xml_list``.

    Returns:
        Tuple ``(height, width)`` of ints.
    """
    # read xml
    xml_path = self.xml_list[idx]
    # Read bytes so the XML parser handles the file's own encoding instead
    # of depending on the platform's default text encoding.
    with open(xml_path, "rb") as fid:
        xml = etree.fromstring(fid.read())
    size = self.parse_xml_to_dict(xml)["annotation"]["size"]
    data_height = int(size["height"])
    data_width = int(size["width"])
    return data_height, data_width
这个比较简单,解析后直接获取height
和width
。
以上就完成了自定义dataset类编写,完整的代码如下:
import numpy as np
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree
class VOCDataSet(Dataset):
    """Read and parse the PASCAL VOC2007/2012 dataset.

    Each sample is described by one XML annotation file. ``__getitem__``
    returns the image together with a target dict built from the annotation.
    """

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        """Collect the usable annotation files and load the class mapping.

        Args:
            voc_root: directory that contains ``VOCdevkit``, or a path that
                already includes ``VOCdevkit``.
            year: dataset year, "2007" or "2012".
            transforms: optional callable invoked as
                ``transforms(image, target)`` in ``__getitem__``.
            txt_name: split file under ``ImageSets/Main`` listing the sample
                names, one per line.

        Raises:
            AssertionError: if ``year`` is invalid, the split file or
                ``./pascal_voc_classes.json`` is missing, or no annotation
                file with objects is left.
        """
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # Be tolerant: accept a root with or without the VOCdevkit component.
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        # check file
        for xml_path in xml_list:
            if not os.path.exists(xml_path):
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            # check for targets
            data = self._load_annotation(xml_path)
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # read class_indict
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    def __len__(self):
        """Return the number of annotation files held in ``self.xml_list``."""
        return len(self.xml_list)

    def __getitem__(self, idx):
        """Return the image and the annotation target for sample ``idx``.

        Args:
            idx: index into ``self.xml_list``.

        Returns:
            ``(image, target)``; see ``_build_target`` for the keys of
            ``target``. If ``self.transforms`` is set, both values are
            passed through it first.

        Raises:
            ValueError: if the image file is not in JPEG format.
            AssertionError: if the annotation has no ``object`` entry.
        """
        xml_path = self.xml_list[idx]
        data = self._load_annotation(xml_path)
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        target = self._build_target(data, idx, xml_path)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx`` as recorded in its annotation.

        The values come from the ``size`` element of the XML file, so the
        image itself is not opened.
        """
        size = self._load_annotation(self.xml_list[idx])["size"]
        data_height = int(size["height"])
        data_width = int(size["width"])
        return data_height, data_width

    def _load_annotation(self, xml_path):
        """Parse one annotation file and return the dict under its ``annotation`` root."""
        # Read bytes so the XML parser handles the file's own encoding
        # instead of depending on the platform's default text encoding.
        with open(xml_path, "rb") as fid:
            xml = etree.fromstring(fid.read())
        return self.parse_xml_to_dict(xml)["annotation"]

    def _build_target(self, data, idx, xml_path):
        """Build the target dict of tensors from a parsed annotation.

        Args:
            data: annotation dict as returned by ``_load_annotation``.
            idx: sample index, stored as ``image_id``.
            xml_path: annotation path, used only in messages.

        Returns:
            Dict with the keys ``boxes`` (float32, one
            ``[xmin, ymin, xmax, ymax]`` row per kept object), ``labels``
            (int64), ``image_id``, ``area`` and ``iscrowd`` (int64, taken
            from the ``difficult`` field, 0 when absent).

        Raises:
            AssertionError: if the annotation has no ``object`` entry.
        """
        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # Some annotations contain boxes whose width or height is 0;
            # such boxes would make the regression loss nan, so skip them.
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into a torch.Tensor
        # reshape keeps the boxes two-dimensional even when every box was
        # skipped, so the column indexing below cannot fail on an empty list.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

    def parse_xml_to_dict(self, xml):
        """
        Convert an XML element tree into nested dictionaries, following
        tensorflow's recursive_parse_xml_to_dict.

        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """
        # A node without children is a leaf: map its tag straight to its text.
        if len(xml) == 0:
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_value = self.parse_xml_to_dict(child)[child.tag]  # recurse into the child
            if child.tag == 'object':
                # An annotation may hold several objects, so gather them in a list.
                result.setdefault(child.tag, []).append(child_value)
            else:
                result[child.tag] = child_value
        return {xml.tag: result}