FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集

读取解析PASCAL VOC2012数据集


文章目录

  • 读取解析PASCAL VOC2012数据集
  • 前言
  • 一、认识数据集
    • 1. Annotations文件夹
    • 2.JPEGImages文件夹
    • 3.ImageSets文件夹
    • 4. pascal_voc_classes.json文件
  • 二、代码解析
    • 1.`__init__`
    • 2.`__len__`
    • 3.`__getitem__`
    • 4.`get_height_and_width`
    • 5.`parse_xml_to_dict`
    • 6.`collate_fn`
  • 三、应用代码
  • 总结


前言

在我们训练网络的过程中,读取解析我们的数据集往往是第一步,由于本人刚开始学习深度学习机器学习不久,初次阅读该部分的代码里理解可能有些偏差,特将此记录下来,以便有新的理解和感悟。
本人将视频Faster RCNN源码解析(pytorch)p2进行一个总结整理,为以后的读取自己的数据集打下基础

一、认识数据集

在我们VOC2012文件夹中,主要由如下几个主要的文件夹,共1.83 GB
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第1张图片

1. Annotations文件夹

其中Annotations文件夹中包含:17125个xml文件
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第2张图片
我们打开其中一个xml文件,其主要信息有1.所属文件夹 2.文件名 3.文件源 4.图片大小 5.图片是否完整 6.图片中包含那几个类别(其中包括类别的名字,box的坐标等等)
我们需要将其中的有效信息提取出来,并与图片对应起来。
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第3张图片

2.JPEGImages文件夹

该文件夹下就是我们上述xml文件所对应的图片信息。

FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第4张图片

3.ImageSets文件夹

其主要存储的是各个类别的文件名,训练集的文件名等
如在其main文件夹下的train.txt文件
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第5张图片

4. pascal_voc_classes.json文件

该文件与读取解析PASCAL VOC2012数据集的代码在同一文件夹下,其目的在于使20个类别转化为数字类别,由于PASCAL VOC2012数据集有二十个类别,因此类别aeroplane为类别1,bicycle为2,以此类推。。。
之所以不从0开始是因为在网络分类的过程中还有背景的类别,默认为类别0。
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第6张图片

二、代码解析

定义VOC2012DataSet类后,有5个方法,接下来依次介绍

1.__init__

其主要功能为

  1. 获取JPEGImagesAnnotations文件家的路径
  2. 获取VOCdevkit/VOC2012/ImageSets/Main/train.txt下的所有文件名
  3. 合成xml_list(其中代表的是train.txt里的文件名对应在Annotations的xml文件)
    FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第7张图片
  4. 读取pascal_voc_classes.json文件

代码如下:

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree


class VOC2012DataSet(Dataset):
    """读取解析PASCAL VOC2012数据集"""

    def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
        # voc_root:训练集所在的根目录, transforms:预处理方法, txt_name: str = "train.txt":返回train.txt中的所有数据
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                             for line in read.readlines()]

        # read class_indict
        try:
            json_file = open('./pascal_voc_classes.json', 'r')
            self.class_dict = json.load(json_file)
        except Exception as e:
            print(e)
            exit(-1)

        self.transforms = transforms

2.__len__

获取 变量xml_list的长度,即训练集样本的数量
代码如下:

    def __len__(self):
        return len(self.xml_list)

3.__getitem__

获取 样本图片和标签

  1. 获取路径和标签信息
    通过传入idx即索引,来获取该索引下的对应图片的xml文件所对应的路径。(其目的是为了获取标签数据)
    采用parse_xml_to_dict方法对xml文件进行读取,存在变量data中,并以字典的形式进行存储,
  2. 获取图片和打包标签
    变量data中有图片的名称,以及各个目标框的位置即标签(这里需要注意的是从字典中读取的信息都是以字符形式存储的,我们需要将其转化为浮点型,并将其转化为tensor格式

我们最后返回的是imagetarget(已经进行了数据预处理)
在这里插入图片描述
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第8张图片

代码如下:

      def __getitem__(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str) # etree包 读取xml文件
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image format not JPEG")
        boxes = []
        labels = []
        iscrowd = []
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"]) # 字符型变量转化为浮点型变量
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            iscrowd.append(int(obj["difficult"]))

        # convert everything into a torch.Tensor  将这几个列表转化为tensor格式
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

4.get_height_and_width

获取 传入对应索引值的图片的高和宽,其操作步骤和上一个方法一样,同样需要将字符型转化为数字类型
返回图片的data_height, data_width
在 多GPU训练时需要这个方法,若没有这个方法,则会自动载入图片去计算图片的高和宽,会耗时也会占内存。
代码如下:

    def get_height_and_width(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"]) # 将字符型转化为数字类型
        data_width = int(data["size"]["width"])
        return data_height, data_width

5.parse_xml_to_dict

将xml文件解析成字典形式
代码如下:

    def parse_xml_to_dict(self, xml):
        """
        将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

6.collate_fn

这个是我们自己传入的分类方法,若不传入,其默认通过torch.stack()进行简单的拼接,因为这里并不是tensor格式,而是 元组tuple类型,该函数的目的是将 img和target 各自打包放在一起

    def collate_fn(batch):
        return tuple(zip(*batch))

batch原类型的数据为
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第9张图片
经过collate_fn函数之后返回
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第10张图片


三、应用代码

这一部分的内容主要是对前面定义的 VOC2012DataSet类 进行一个应用
我们先定义一个名为category_index的空字典。

然后打开pascal_voc_classes.json文件,并将索引作为键,类别名作为值传入category_index空字典中

import transforms
from draw_box_utils import draw_box
from PIL import Image
import json
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import random

# read class_indict
category_index = {}
try:
    json_file = open('./pascal_voc_classes.json', 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)

FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第11张图片
定义 数据的预处理过程()

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}

实例化VOC2012DataSet类 对象

# load train data set
train_data_set = VOC2012DataSet(os.getcwd(), data_transform["train"])
a = len(train_data_set)

FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第12张图片
然后随机采样5张图片,通过VOC2012DataSet类__getitem__方法返回img,target
然后通过ts.ToPILImage()(img)方法由tensor格式返回为PILImage格式
然后显示图像

for index in random.sample(range(0, len(train_data_set)), k=5):
    img, target = train_data_set[index]
    img = ts.ToPILImage()(img) # 由tensor格式返回为PILImage格式
    draw_box(img,
             target["boxes"].numpy(),
             target["labels"].numpy(),
             [1 for i in range(len(target["labels"].numpy()))],
             category_index,
             thresh=0.5,
             line_thickness=1)
    plt.imshow(img)
    plt.show()

最后输出为
FasterRCNN源码解析(二)——读取解析PASCAL VOC2012数据集_第13张图片

。。。
共5张图片

总结

对于读取解析数据的以上几个方法,有了清晰明了的一个认识,至于去如何使用还需要与fasterRCNN的源码配合使用,接下来的这几天我将会对fasterRCNN的源码进行学习,争取吃透,能够举一反三。

你可能感兴趣的:(计算机视觉,python,pytorch,神经网络)