Faster-RCNN代码解读3:制作自己的数据加载器

Faster-RCNN代码解读3:制作自己的数据加载器

前言

​ 因为最近打算尝试一下Faster-RCNN的复现,不要多想,我还没有厉害到可以一个人复现所有代码。所以,是参考别人的代码,进行自己的解读。

代码来自于B站的UP主(大佬666),其把代码都放到了GitHub上了,我把链接都放到下面了(应该不算侵权吧,毕竟代码都开源了_):

b站链接:https://www.bilibili.com/video/BV1of4y1m7nj/?vd_source=afeab8b555e5eb1bfa1e7f267262cbf2

GitHub链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

目的

​ 其实UP主已经做了很好的视频讲解了他的代码,只是有时候我还是喜欢阅读博客来学习,另外视频很长,6个小时,我看的时候容易睡着_,所以才打算写博客记录一下学习笔记。

目前完成的内容

第一篇:VOC数据集详细介绍

第二篇:Faster-RCNN代码解读2:快速上手使用

第三篇:Faster-RCNN代码解读3:制作自己的数据加载器(本文)

目录结构

文章目录

    • Faster-RCNN代码解读3:制作自己的数据加载器
      • 1. 前言:
      • 2. my_dataset.py文件解读:
        • 2.1 init方法:
        • 2.2 len方法:
        • 2.3 getitem方法:
        • 2.4 辅助方法:get_height_and_width
        • 2.5 辅助方法:parse_xml_to_dict
        • 2.6 辅助方法:coco_index
      • 3. 总结:

1. 前言:

​ 其实这个部分还是比较简单的(如果你看过我前面的图像分类加载器实现或者自己实现过),就是定义一个dataset类。

2. my_dataset.py文件解读:

​ 我们知道,想要定义自己的dataset类,首先需要继承于torch的Dataset类,并且至少需要定义三个方法,即__init____len____getitem__

​ 那么,可以写出大体框架:

class VOCDataSet(Dataset):
    """读取解析PASCAL VOC2007/2012数据集"""

    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
		pass

​ 好的,下面我们来一一实现。

2.1 init方法:

​ 首先,需要定义我们的输入参数,这里如果是自己从头实现的话,估计需要想到什么参数用参数。但是,我们解读的话,就直接看作者定义了哪些参数:

  • voc_root: 数据集所在的根目录
  • year: 指定读取2007还是2012的数据集,默认为2012
  • transforms: 预处理方法,默认为None
  • txt_name: 指定加载训练集还是测试集,默认为训练集,即train.txt

​ 接下来,第一步,增加一下代码的容错能力,就是判断一下传入的参数正不正确,并拼接出需要的路径:

# 判断是不是2007或2012,否则报错
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# 增加容错能力
if "VOCdevkit" in voc_root:
    # 如果传入的参数为:.\VOCdevkit,那么直接拼接为.\VOCdevkit\VOC2012
    self.root = os.path.join(voc_root, f"VOC{year}")
else:
    # 如果传入的参数为:. ,那么直接拼接为.\VOCdevkit\VOC2012
    self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
# 拼接路径,即图片路径和注释路径
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")

​ 第二步,读取数据集.\VOCdevkit\VOC2012\ImageSets\Main里面的训练集或测试集txt文件(如果你不知道这里面为什么的话,可以看第一篇文章,VOC数据集介绍),并将里面的值和后缀xml拼接为训练集或测试集的注释文件:

# 读取train或者val文件
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
# 然后,将文件名(2007_000027)和后缀拼接在一起,这样才是真实的文件
with open(txt_path) as read:
    xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

​ 第三步,需要一一读取xml文件,并将里面的内容转为字典值,主要目的是检查一下xml文件是否有问题:

# 定义真正的xml列表
self.xml_list = []
# 检测所有xml文件是否存在并读取内容
for xml_path in xml_list:
    if os.path.exists(xml_path) is False:
        print(f"Warning: not found '{xml_path}', skip this annotation file.")
        continue
    # 如果xml文件存在,继续下面的代码
    # check for targets
    # 读取xml文件
    with open(xml_path) as fid:
    	xml_str = fid.read()
    # 构建xml对象
    xml = etree.fromstring(xml_str)
    # 获取节点的内容,并转为字典值
    data = self.parse_xml_to_dict(xml)["annotation"] # 获取annotation节点下的所有内容
    if "object" not in data: # 判断object节点是否存在,如果不存在说明xml文件其实有问题,所以需要跳过
        print(f"INFO: no objects in {xml_path}, skip this annotation file.")
        continue
    # 添加
    self.xml_list.append(xml_path)

​ 第四步,加载类别json文件,并读取里面的内容:

# 读取类别文件,一共20个类,从1开始是因为0留给背景
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
	self.class_dict = json.load(f)

​ 最后,将预处理函数放入一个变量中:

self.transforms = transforms

​ **总结一下:**经过上面的处理,我们得到了几个主要的变量:

  • self.xml_list:里面的值为一个个训练集或测试集的xml文件,里面的值为文件路径值
  • self.transforms:里面为我们的预处理方法
  • self.class_dict:为我们的类别字典,里面的值为{‘preson’:2}这样的形式

​ 给大家看看,debug下的值的内容:

在这里插入图片描述

2.2 len方法:

​ len方法,这个是最简单的方法,其作用就是返回长度值:

def __len__(self):
    # len函数就是返回长度
    return len(self.xml_list)

2.3 getitem方法:

​ 这个方法和init方法一样十分重要,其作用就是获取图像和图像对应的标签等信息。

def __getitem__(self, idx):
	pass

​ 其中idx是这个方法必备的一个参数,其是随机返回一个索引值,来方便你取你之前在init方法定义的变量里的值。

​ 那么,首先,获取一个xml文件,并打开它获取根节点里面的内容:

# 随机读取一个xml文件
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
	xml_str = fid.read()
# 创建xml对象
xml = etree.fromstring(xml_str)
# 获取根节点,转为字典值
data = self.parse_xml_to_dict(xml)["annotation"]

​ 这里解释一下上面的data值为啥。其实就是xml文件annotation节点里的所有内容,如下图框出来的内容:

Faster-RCNN代码解读3:制作自己的数据加载器_第1张图片

​ 当然,同样用debug看看里面真实情况下的值:

在这里插入图片描述

​ 然后,**我们知道xml文件名和图片名是对应的,**因此通过xml文件获取图片名字并打开这个图像:

# 获取xml文件对应的图像路径
img_path = os.path.join(self.img_root, data["filename"])
# 打开图像
image = Image.open(img_path)
# 判断图像是否为jpeg格式,主要作者防止别人插入了其它的文件
if image.format != "JPEG":
	raise ValueError("Image '{}' format not JPEG".format(img_path))

​ 接着,初始化一些变量:

# 初始化一些变量
boxes = []		# 边界框
labels = []		# 标签值
iscrowd = []	# 是否为难以识别的图像

下面开始是最重要的内容

​ 首先,迭代读取xml文件object节点下的内容:

# 读取xml文件中object节点下的内容
for obj in data["object"]:

​ 其中的,obj为下图中的值:

在这里插入图片描述

​ 或者可以从xml文件中对应查看:

Faster-RCNN代码解读3:制作自己的数据加载器_第2张图片

​ 接着,获取对象的真实边界框的坐标值(左上角,右下角):(ps:下面的代码都是放在上面的for循环里面的)

# 获取bbox框的坐标
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])

​ 检测一下,边界框是否有问题:

# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
    print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
    continue

​ 然后,把坐标值加入boxes变量中,把标签加入labels变量中,并判断图像是否为难以识别的,然后加入iscrowd变量中:

boxes.append([xmin, ymin, xmax, ymax])
# 添加标签  obj["name"]=person,  self.class_dict[obj["name"]] = 15
labels.append(self.class_dict[obj["name"]])
# 判断是否为difficult类型
if "difficult" in obj:
    iscrowd.append(int(obj["difficult"]))
    else:
        iscrowd.append(0)

​ 然后,把所有的变量类型都转为tensor格式(此时已经结束了循环):

# 将所有的类型转为tensor类型
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])

​ 接着,根据边框框的四个坐标,计算一下边界框的面积,主要方便后期计算IOU:

#  boxes =[[,,,],[,,,],。。。。。。]
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# (ymax - ymin) * (xmax - xmin) ,即框的面积

​ 最后,把上面的所有值放入一个字典变量中即可:

# 把这些东西放入一个字典中
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

​ 然后,对图像进行预处理并返回图像和其对应的值即可:

# 变换,此时为自己实现的方法,不是官方的方法
if self.transforms is not None:
	image, target = self.transforms(image, target)
return image, target

​ 最后,我们在debug下看看变量的值:

在这里插入图片描述

2.4 辅助方法:get_height_and_width

​ 作用:获取图像的宽和高。

​ 这个十分简单,就是通过xml文件来获取的,还不需要我们自己通过坐标计算:

def get_height_and_width(self, idx):
    # 获取图像的宽和高
    # 读取xml
    xml_path = self.xml_list[idx]
    with open(xml_path) as fid:
		xml_str = fid.read()
    # 构建xml对象
    xml = etree.fromstring(xml_str)
    # 获取根节点
    data = self.parse_xml_to_dict(xml)["annotation"]
    # 获取宽和高
    data_height = int(data["size"]["height"])
    data_width = int(data["size"]["width"])
    return data_height, data_width

2.5 辅助方法:parse_xml_to_dict

​ 主要作用:将xml格式的数据解析为字典格式,即将节点-----节点的值,转为{‘节点’:‘节点的值’}。

​ 这个方法是通过递归来实现的,这个没什么好说的,如果你想搞清楚如何运行的,可以自己一步一步的推导:

def parse_xml_to_dict(self, xml):
    """
    将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
    """

    if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
        # xml.tag节点名字
        # xml.text里面的值
        return {xml.tag: xml.text}

    result = {}
    # 对于每个xml中的子节点
    for child in xml:
        child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
        if child.tag != 'object':
	        result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

2.6 辅助方法:coco_index

这个方法与getitem方法是相同的作用,只是不读取图片,流程都是一样的,我就不细说了。

3. 总结:

​ my_dataset.py文件主要实现了数据加载器的类,实现思路很简单,但是代码量还是比较大的。

​ 另外,作者在该文件的末尾展示了一下这个类的使用示例代码,大家可以直接把注释取消运行看看结果:

Faster-RCNN代码解读3:制作自己的数据加载器_第3张图片

你可能感兴趣的:(Faster-RCNN代码复现,深度学习,python,人工智能)