Object Detection: Writing a Custom Dataset Class

This post shows how to write your own Dataset class, which can then be used to read PASCAL VOC data or data from your own dataset.

1. Generating train.txt and val.txt from the dataset

This part is straightforward: running the script split_data.py generates train.txt and val.txt.

import os
import random


def main():
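    random.seed(0)  # fix the random seed so the split is reproducible
    # root directory of the .xml annotation files (or of the saved .jpg images)
    files_path = "./VOCdevkit/VOC2012/Annotations"
    assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path)
    # fraction of the files assigned to the validation set
    val_rate = 0.5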

    files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
    files_num = len(files_name)
    val_index = random.sample(range(0, files_num), k=int(files_num*val_rate))
    train_files = []
    val_files = []
    for index, file_name in enumerate(files_name):
        if index in val_index:
            val_files.append(file_name)
        else:
            train_files.append(file_name)

    try:
        # mode "x" refuses to overwrite an existing split; "with" closes both files
        with open("train.txt", "x") as train_f, open("val.txt", "x") as eval_f:
            train_f.write("\n".join(train_files))
            eval_f.write("\n".join(val_files))
    except FileExistsError as e:
        print(e)
        exit(1)


if __name__ == '__main__':
    main()

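Here, files_path is the root directory of the .xml annotation files (or of the saved .jpg images).
The Annotations folder contains all of the .xml annotation files, laid out as shown in the figure.
[Figure 1: the Annotations folder]
Running split_data.py produces the train.txt and val.txt splits. The script writes both files to the current working directory, while the VOCDataSet class below reads them from ImageSets/Main under the VOC root, so they need to be placed there.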
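
The split logic can also be isolated in a small pure function, which makes it easy to check on its own. This is only a minimal sketch under stated assumptions: the name `split_names` and its parameters are hypothetical, and it works on an in-memory list of names instead of listing a directory.

```python
import random


def split_names(names, val_rate=0.5, seed=0):
    """Split sample names into (train, val) lists.

    Hypothetical helper for illustration; it mirrors the sampling logic of
    split_data.py but takes the names as an argument.
    """
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    names = sorted(names)      # sort first so the result does not depend on input order
    val_index = set(rng.sample(range(len(names)), k=int(len(names) * val_rate)))
    train = [n for i, n in enumerate(names) if i not in val_index]
    val = [n for i, n in enumerate(names) if i in val_index]
    return train, val


train_names, val_names = split_names([f"img_{i:03d}" for i in range(10)], val_rate=0.2)
print(len(train_names), len(val_names))
```

Storing the sampled indices in a set keeps the membership test cheap when the dataset has many files.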

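2. Building your own Dataset class

The code follows the official PyTorch documentation.
[Figure 2]
A custom dataset must inherit from the standard torch.utils.data.Dataset class and implement two methods: __len__ and __getitem__. __len__ returns the number of samples, and __getitem__ returns an image together with its associated information.

[Figure 3]
The documentation also mentions get_height_and_width, a method that returns an image's height and width. It is needed when running the multi-GPU training script: without it, every image has to be loaded just to compute its height and width, which is slow and memory-hungry. Implementing get_height_and_width up front avoids iterating over the whole dataset.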
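
The three-method protocol can be sketched without PyTorch installed. The class below is a toy stand-in, not a real torch.utils.data.Dataset subclass: `TinyDetectionDataset` and the record keys "image", "height" and "width" are hypothetical names chosen for illustration.

```python
class TinyDetectionDataset:
    """Toy stand-in for a Dataset subclass (plain Python, no PyTorch needed)."""

    def __init__(self, records):
        # Each record is a dict with the hypothetical keys "image", "height", "width".
        self.records = records

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.records)

    def __getitem__(self, idx):
        # Return the image and a dict of information about it.
        record = self.records[idx]
        return record["image"], {"image_id": idx}

    def get_height_and_width(self, idx):
        # Read the size from stored metadata instead of decoding the image.
        record = self.records[idx]
        return record["height"], record["width"]


dataset = TinyDetectionDataset([
    {"image": "pixels-0", "height": 335, "width": 500},
    {"image": "pixels-1", "height": 375, "width": 500},
])
sizes = [dataset.get_height_and_width(i) for i in range(len(dataset))]
```

The point of the third method is visible here: collecting all sizes never touches the "image" field.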

2.1 Inheriting from the Dataset class

First, inherit from the Dataset class in torch.utils.data:

class VOCDataSet(Dataset):

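2.2 Initialization in __init__

After inheriting from Dataset, initialize the class. The __init__ method takes four parameters: voc_root, year, transforms and txt_name.

Here:

  • voc_root is the root directory that holds the dataset. If the dataset sits in the project root, then voc_root="./".
  • year selects the VOC release and must be "2007" or "2012".
  • transforms is the preprocessing applied to each sample.
  • txt_name names the split file: "train.txt" returns all of the training samples, and "val.txt" returns all of the validation samples.

The initializer derives the dataset root self.root, the image directory self.img_root and the annotation directory self.annotations_root. It then builds self.xml_list, the list of annotation .xml files in the chosen split, skipping files that are missing or contain no objects.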
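
The way the directories are derived from voc_root can be sketched on its own. This is an illustrative helper only: `resolve_voc_paths` is a hypothetical name, and the function just builds path strings without touching the disk.

```python
import os


def resolve_voc_paths(voc_root, year="2012"):
    """Return (root, img_root, annotations_root) for a VOC directory layout.

    Hypothetical helper; it accepts voc_root with or without the
    VOCdevkit folder already appended.
    """
    if year not in ("2007", "2012"):
        raise ValueError("year must be '2007' or '2012'")
    if "VOCdevkit" in voc_root:
        root = os.path.join(voc_root, f"VOC{year}")
    else:
        root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
    return root, os.path.join(root, "JPEGImages"), os.path.join(root, "Annotations")


root_a, _, _ = resolve_voc_paths("data")
root_b, img_b, ann_b = resolve_voc_paths(os.path.join("data", "VOCdevkit"))
```

Both calls resolve to the same dataset root, which is the tolerance the substring check is meant to provide.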

Code
The example uses the VOC dataset; for a custom dataset, only the corresponding parts of the code need to change.

class VOCDataSet(Dataset):
    """Reads and parses the PASCAL VOC2007/2012 dataset."""

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # accept voc_root with or without the VOCdevkit folder
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        # check file
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            # check for targets
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # read class_indict
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

Here ./pascal_voc_classes.json holds the class information. For PASCAL VOC it looks like this:

{
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20
}
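
A short sketch of how such a mapping is typically used: loaded once with json, and inverted when integer predictions need to be turned back into class names. The shortened JSON string and the name `index_to_name` are assumptions made for the example. The indices start at 1, presumably because index 0 is left for the background class.

```python
import json

# Shortened stand-in for pascal_voc_classes.json (three of the 20 classes).
class_json = '{"aeroplane": 1, "bicycle": 2, "bird": 3}'
class_dict = json.loads(class_json)

# Invert the mapping so integer labels can be mapped back to class names.
index_to_name = {index: name for name, index in class_dict.items()}

label = class_dict["bird"]   # name -> integer label, as used when building targets
name = index_to_name[label]  # integer label -> name, as used when reading predictions
```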

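Besides inheriting from the standard Dataset class, a custom dataset has to implement the __len__ and __getitem__ methods. I additionally implemented get_height_and_width here.

2.3 The __len__ method

Code:

    def __len__(self):
        return len(self.xml_list)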

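2.4 The __getitem__ method

__getitem__ takes an index idx and returns the image at that index together with its annotation information.

Steps:

  • 1. Use idx to look up the corresponding .xml file.
  • 2. Use the etree library to parse the .xml file into a dictionary (modeled on TensorFlow's recursive_parse_xml_to_dict).

    def __getitem__(self, idx):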
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # Extra sanity check: some annotations contain boxes with zero width or height, which makes the regression loss NaN
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

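__getitem__ returns the image and its annotation target. target is a dictionary holding the bounding boxes boxes, the class labels labels, the image ID image_id, the area of each box area, and iscrowd, which this code fills from the VOC difficult flag (how hard the object is to recognize).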
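
The box filtering and the area computation can be checked in plain Python, without tensors. This is a sketch only: `collect_boxes` is a hypothetical helper, and the second object below is a made-up degenerate box added to show the skip.

```python
def collect_boxes(objects):
    """Return (boxes, areas) for valid objects, skipping degenerate boxes.

    `objects` is a list of dicts shaped like the parsed VOC "object" entries,
    where the coordinates are still strings.
    """
    boxes, areas = [], []
    for obj in objects:
        bndbox = obj["bndbox"]
        xmin, ymin = float(bndbox["xmin"]), float(bndbox["ymin"])
        xmax, ymax = float(bndbox["xmax"]), float(bndbox["ymax"])
        if xmax <= xmin or ymax <= ymin:
            continue  # zero or negative width/height
        boxes.append([xmin, ymin, xmax, ymax])
        areas.append((ymax - ymin) * (xmax - xmin))  # height * width
    return boxes, areas


objects = [
    {"name": "train", "bndbox": {"xmin": "263", "ymin": "32", "xmax": "500", "ymax": "295"}},
    {"name": "train", "bndbox": {"xmin": "10", "ymin": "20", "xmax": "10", "ymax": "40"}},
]
boxes, areas = collect_boxes(objects)
```

Only the first box survives, with width 237 and height 263.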

An annotated .xml file looks like this:

<annotation>
	<folder>VOC2012</folder>
	<filename>2007_000042.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
	</source>
	<size>
		<width>500</width>
		<height>335</height>
		<depth>3</depth>
	</size>
	<segmented>1</segmented>
	<object>
		<name>train</name>
		<pose>Unspecified</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>263</xmin>
			<ymin>32</ymin>
			<xmax>500</xmax>
			<ymax>295</ymax>
		</bndbox>
	</object>
	<object>
		<name>train</name>
		<pose>Unspecified</pose>
		<truncated>1</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>1</xmin>
			<ymin>36</ymin>
			<xmax>235</xmax>
			<ymax>299</ymax>
		</bndbox>
	</object>
</annotation>

The parse_xml_to_dict method parses the .xml file into a Python dict.
It is modeled on TensorFlow's recursive_parse_xml_to_dict.

    def parse_xml_to_dict(self, xml):
        """
        Parse an xml tree into a dict, modeled on TensorFlow's recursive_parse_xml_to_dict.
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # leaf element reached: return the tag's text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into the child tag
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several objects, so collect them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}
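
To try the recursion without the dataset, here is a standalone sketch. It swaps lxml for the standard library's xml.etree.ElementTree so it runs without extra dependencies; that swap and the trimmed sample annotation are assumptions made for the example, not the code used above.

```python
import xml.etree.ElementTree as ET


def parse_xml_to_dict(node):
    """Recursively convert an XML element to nested dicts; 'object' becomes a list."""
    if len(node) == 0:  # leaf element: no children
        return {node.tag: node.text}
    result = {}
    for child in node:
        child_result = parse_xml_to_dict(child)
        if child.tag != "object":
            result[child.tag] = child_result[child.tag]
        else:
            # Several <object> tags may appear, so gather them in a list.
            result.setdefault(child.tag, []).append(child_result[child.tag])
    return {node.tag: result}


xml_str = """
<annotation>
    <filename>2007_000042.jpg</filename>
    <size><width>500</width><height>335</height></size>
    <object><name>train</name><bndbox><xmin>263</xmin></bndbox></object>
    <object><name>train</name><bndbox><xmin>1</xmin></bndbox></object>
</annotation>
"""
data = parse_xml_to_dict(ET.fromstring(xml_str.strip()))["annotation"]
```

Every leaf value comes back as a string, which is why the dataset code converts coordinates with float() and sizes with int().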

If transforms is not None, the image is transformed and its target is transformed along with it. The code lives in transforms.py and mainly uses the ToTensor and RandomHorizontalFlip classes:

import random
from torchvision.transforms import functional as F


class Compose(object):
    """Chains several transforms together."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor(object):
    """Converts a PIL image to a Tensor."""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target


class RandomHorizontalFlip(object):
    """Randomly flips the image and its bboxes horizontally."""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        # image must already be a Tensor here, so apply ToTensor first
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # flip the image horizontally
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip the bbox x-coordinates to match
            target["boxes"] = bbox
        return image, target

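ToTensor is simple: it converts the image with the function that torchvision provides. The part worth explaining is the random flip, RandomHorizontalFlip.

When the transform is called, it receives the image and its target. It draws a random number, and if that number is below the configured probability (0.5, for example), the image is flipped. After the flip, the bounding box coordinates have to be flipped as well:

            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip the bbox x-coordinates to match
            target["boxes"] = bbox

The figure illustrates how the transformation works.
[Figure 4: a bounding box before and after the horizontal flip]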

A horizontal flip mirrors the image left to right. After the flip, the box's original top-left corner $(X_{min}, Y_{min})$ ends up at the box's top-right corner, and its original bottom-right corner $(X_{max}, Y_{max})$ ends up at the bottom-left corner.
For the new top-left corner $(X'_{tl}, Y'_{tl})$, the value $Y'_{tl} = Y_{min}$ is the same as before the flip; only the x-coordinate changes, and $X'_{tl} = width - X_{max}$.

Likewise, for the new bottom-right corner $(X'_{br}, Y'_{br})$, the value $Y'_{br} = Y_{max}$ is the same as before the flip; only the x-coordinate changes, and $X'_{br} = width - X_{min}$.
This gives the following code:

            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip the bbox x-coordinates to match
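
The same coordinate swap can be verified with plain lists. This is a minimal sketch, assuming boxes in [xmin, ymin, xmax, ymax] order; `flip_boxes_horizontally` is a hypothetical helper, not part of transforms.py, and the two boxes are the ones from the sample annotation with width 500.

```python
def flip_boxes_horizontally(boxes, width):
    """Mirror [xmin, ymin, xmax, ymax] boxes left to right in an image of the given width."""
    flipped = []
    for xmin, ymin, xmax, ymax in boxes:
        # The old right edge becomes the new left edge, and vice versa;
        # the y-coordinates are untouched.
        flipped.append([width - xmax, ymin, width - xmin, ymax])
    return flipped


boxes = [[263, 32, 500, 295], [1, 36, 235, 299]]
flipped = flip_boxes_horizontally(boxes, width=500)
```

Swapping the two x columns while subtracting from width is what keeps xmin smaller than xmax after the flip. Flipping twice gives back the original boxes.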

After preprocessing, the image and its target are returned.

2.5 The get_height_and_width method

    def get_height_and_width(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

This one is simple: parse the file and read height and width directly.

That completes the custom dataset class. The full code is as follows:

import json
import os

import torch
from lxml import etree
from PIL import Image
from torch.utils.data import Dataset


class VOCDataSet(Dataset):
    """Reads and parses the PASCAL VOC2007/2012 dataset."""

    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # accept voc_root with or without the VOCdevkit folder
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        # check file
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            # check for targets
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # read class_indict
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    def __len__(self):
        return len(self.xml_list)

    def __getitem__(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # Extra sanity check: some annotations contain boxes with zero width or height, which makes the regression loss NaN
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    def parse_xml_to_dict(self, xml):
        """
        Parse an xml tree into a dict, modeled on TensorFlow's recursive_parse_xml_to_dict.
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # leaf element reached: return the tag's text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into the child tag
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several objects, so collect them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}
