Faster-RCNN系列 二(Datasets代码,python)

Faster-RCNN系列 二(Datasets代码,python)
数据集部分
定义数据集时继承torch.utils.data.Dataset,主要需要自己实现__len__和__getitem__两个方法;
__init__方法的主要作用是设置数据路径并完成初始化,涉及的路径包括:图像目录(JPEGImages)、xml标注目录(Annotations)、划分文件(train.txt或val.txt),以及类别映射的json文件。

def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
    """Index one PASCAL VOC2012 split and load the class-name mapping.

    Args:
        voc_root: directory that contains ``VOCdevkit/VOC2012``.
        transforms: callable ``(image, target) -> (image, target)`` applied
            per sample, or ``None`` to skip it.
        txt_name: split file under ``ImageSets/Main`` (e.g. ``"train.txt"``
            or ``"val.txt"``) listing one image id per line.

    Raises:
        AssertionError: if the split file is missing, lists no ids, or
            refers to an xml file that does not exist.
        SystemExit: if ``./pascal_voc_classes.json`` cannot be read or parsed.
    """
    self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
    self.img_root = os.path.join(self.root, "JPEGImages")
    self.annotations_root = os.path.join(self.root, "Annotations")

    # read train.txt or val.txt file
    txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
    assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

    with open(txt_path) as read:
        # Skip blank lines (e.g. an extra trailing newline): they would
        # otherwise turn into a bogus "Annotations/.xml" path below.
        self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                         for line in read if line.strip()]

    # check file
    assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
    for xml_path in self.xml_list:
        assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

    # read class_indict: class name -> label index (looked up per object in __getitem__)
    try:
        # `with` guarantees the json file handle is closed.
        with open('./pascal_voc_classes.json', 'r') as json_file:
            self.class_dict = json.load(json_file)
    except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
        print(e)
        raise SystemExit(-1) from e

    self.transforms = transforms

__len__方法返回数据集的样本数量:每个样本对应xml_list中的一个xml标注文件,所以直接返回xml_list的长度即可。

  def __len__(self):
      """Number of samples: one per annotation path stored in ``self.xml_list``."""
      xml_paths = self.xml_list
      return len(xml_paths)

__getitem__方法根据索引返回图像数据以及标注好的target信息。这里需要先实现parse_xml_to_dict方法,把xml文件中的数据以字典形式存储起来:
根据索引取出单张图片对应的xml文件,用lxml的etree解析其中的全部数据。

def __getitem__(self, idx):
    """Load sample ``idx`` as an ``(image, target)`` pair.

    Args:
        idx: index into ``self.xml_list``.

    Returns:
        ``(image, target)``. ``image`` is the PIL image of the sample.
        ``target`` is a dict holding:
        - ``boxes``: float32 tensor of shape (N, 4), columns xmin/ymin/xmax/ymax.
        - ``labels``: int64 tensor (N,).
        - ``image_id``: int64 tensor ``[idx]``.
        - ``area``: float32 tensor (N,).
        - ``iscrowd``: int64 tensor (N,).
        Both values are passed through ``self.transforms`` when it is not None.

    Raises:
        ValueError: if the referenced image file is not a JPEG.
    """
    # read xml. Pass bytes to lxml: etree.fromstring rejects *str* input
    # that carries an XML encoding declaration, but accepts bytes.
    xml_path = self.xml_list[idx]
    with open(xml_path, "rb") as fid:
        xml = etree.fromstring(fid.read())
    data = self.parse_xml_to_dict(xml)["annotation"]

    img_path = os.path.join(self.img_root, data["filename"])
    image = Image.open(img_path)
    if image.format != "JPEG":
        image.close()  # don't leak the file handle on the error path
        raise ValueError("Image format not JPEG")

    boxes = []
    labels = []
    iscrowd = []
    # parse_xml_to_dict only creates the "object" key when at least one
    # <object> element exists, so default to an empty list.
    for obj in data.get("object", []):
        bndbox = obj["bndbox"]
        boxes.append([float(bndbox["xmin"]), float(bndbox["ymin"]),
                      float(bndbox["xmax"]), float(bndbox["ymax"])])
        labels.append(self.class_dict[obj["name"]])
        # VOC's "difficult" flag is reused as "iscrowd"; treat a missing tag as 0.
        iscrowd.append(int(obj.get("difficult", 0)))

    # convert everything into a torch.Tensor. reshape keeps boxes 2-D
    # (0 x 4) when there are no objects, so the column slicing below works.
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    target = {}
    target["boxes"] = boxes
    target["labels"] = labels
    target["image_id"] = image_id
    target["area"] = area
    target["iscrowd"] = iscrowd

    if self.transforms is not None:
        image, target = self.transforms(image, target)

    return image, target

实现parse_xml_to_dict方法,将xml文件的信息以(嵌套)字典的形式返回

def parse_xml_to_dict(self, xml):
    """Recursively convert an XML element into a nested dict.

    A leaf element (no children) maps its tag to its text. An element with
    children maps its tag to a dict of child tag -> converted child value.
    ``object`` children are gathered into a list, because one annotation may
    hold several objects. For any other repeated tag, the last occurrence wins.

    Args:
        xml: an element from a parsed XML tree (e.g. from ``lxml.etree``).

    Returns:
        A single-key dict ``{xml.tag: value}``.
    """
    tag = xml.tag
    if not len(xml):  # leaf node: its text is the value
        return {tag: xml.text}

    children = {}
    for node in xml:
        value = self.parse_xml_to_dict(node)[node.tag]
        if node.tag == 'object':
            # several <object> entries may exist, so collect them in a list
            children.setdefault('object', []).append(value)
        else:
            children[node.tag] = value
    return {tag: children}
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree

class VOC2012DataSet(Dataset):
    """Draft VOC2012 dataset that only defines the constructor.

    NOTE: a second ``class VOC2012DataSet`` follows in this file and shadows
    this definition at import time.
    """

    def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
        """Index a VOC2012 split and load the class-name mapping.

        Args:
            voc_root: directory that contains ``VOCdevkit/VOC2012``.
            transforms: per-sample transform (stored as-is), or ``None``.
            txt_name: split file under ``ImageSets/Main`` with one id per line.

        Raises:
            AssertionError: if the split file or a listed xml file is missing,
                or the split file lists no ids.
            SystemExit: if ``./pascal_voc_classes.json`` cannot be read/parsed.
        """
        # Must be spelled __init__ (the draft had __int__, which Python never
        # calls on construction).
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotation_root = os.path.join(self.root, "Annotations")

        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
        with open(txt_path) as read:
            # strip() yields the id string (split() returned a list, which
            # cannot be concatenated with '.xml'); blank lines are skipped.
            self.xml_list = [os.path.join(self.annotation_root, line.strip() + '.xml')
                             for line in read if line.strip()]
        assert len(self.xml_list) > 0,"in '{}' file does not found any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file".format(xml_path)
        try:
            # Same file name as the rest of this file; `with` closes the handle.
            with open('./pascal_voc_classes.json', 'r') as json_file:
                self.class_dict = json.load(json_file)
        except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
            print(e)
            raise SystemExit(-1) from e
        self.transforms = transforms
class VOC2012DataSet(Dataset):
    """Read and parse the PASCAL VOC2012 detection dataset.

    Each sample is one image id listed in ``ImageSets/Main/<txt_name>``.
    Its boxes and class names come from ``Annotations/<id>.xml``, and its
    pixels from ``JPEGImages/<filename>``, where ``filename`` is read from
    that xml.
    """

    def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):
        """Index one split and load the class-name mapping.

        Args:
            voc_root: directory that contains ``VOCdevkit/VOC2012``.
            transforms: callable ``(image, target) -> (image, target)`` applied
                in ``__getitem__``, or ``None`` to skip it.
            txt_name: split file under ``ImageSets/Main`` (e.g. ``"train.txt"``
                or ``"val.txt"``) listing one image id per line.

        Raises:
            AssertionError: if the split file is missing, lists no ids, or
                refers to an xml file that does not exist.
            SystemExit: if ``./pascal_voc_classes.json`` cannot be read or parsed.
        """
        self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # read train.txt or val.txt file
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            # Skip blank lines (e.g. an extra trailing newline): they would
            # otherwise turn into a bogus "Annotations/.xml" path below.
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                             for line in read if line.strip()]

        # check file
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

        # read class_indict: class name -> label index used for target labels
        try:
            # `with` guarantees the json file handle is closed.
            with open('./pascal_voc_classes.json', 'r') as json_file:
                self.class_dict = json.load(json_file)
        except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
            print(e)
            raise SystemExit(-1) from e

        self.transforms = transforms

    def __len__(self):
        """Return the number of samples (one per annotation xml file)."""
        return len(self.xml_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an ``(image, target)`` pair.

        Returns:
            ``(image, target)``. ``image`` is the PIL image, and ``target`` is
            the dict built by ``_build_target``. Both are passed through
            ``self.transforms`` when it is not None.

        Raises:
            ValueError: if the referenced image file is not a JPEG.
        """
        data = self._parse_annotation(idx)
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            image.close()  # don't leak the file handle on the error path
            raise ValueError("Image format not JPEG")

        # parse_xml_to_dict only creates the "object" key when at least one
        # <object> element exists, so default to an empty list.
        target = self._build_target(data.get("object", []), idx)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of sample ``idx``.

        The values come from the xml ``<size>`` tag, so the image file itself
        is not opened.
        """
        data = self._parse_annotation(idx)
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    def _parse_annotation(self, idx):
        """Parse the xml file of sample ``idx`` and return its ``annotation`` dict."""
        xml_path = self.xml_list[idx]
        # Pass bytes to lxml: etree.fromstring rejects *str* input that
        # carries an XML encoding declaration, but accepts bytes.
        with open(xml_path, "rb") as fid:
            xml = etree.fromstring(fid.read())
        return self.parse_xml_to_dict(xml)["annotation"]

    def _build_target(self, objects, idx):
        """Convert parsed ``<object>`` dicts into the tensor target dict.

        Args:
            objects: list of dicts from ``parse_xml_to_dict``, each with
                ``name``, ``bndbox`` (xmin/ymin/xmax/ymax strings) and,
                optionally, ``difficult``.
            idx: sample index, stored as ``image_id``.

        Returns:
            A dict with these entries:
            - ``boxes``: float32 tensor (N, 4).
            - ``labels``: int64 tensor (N,).
            - ``image_id``: int64 tensor ``[idx]``.
            - ``area``: float32 tensor (N,).
            - ``iscrowd``: int64 tensor (N,).
            N may be 0.
        """
        boxes = []
        labels = []
        iscrowd = []
        for obj in objects:
            bndbox = obj["bndbox"]
            boxes.append([float(bndbox["xmin"]), float(bndbox["ymin"]),
                          float(bndbox["xmax"]), float(bndbox["ymax"])])
            labels.append(self.class_dict[obj["name"]])
            # VOC's "difficult" flag is reused as "iscrowd"; treat a missing tag as 0.
            iscrowd.append(int(obj.get("difficult", 0)))

        # convert everything into a torch.Tensor. reshape keeps boxes 2-D
        # (0 x 4) when there are no objects, so the column slicing works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

    def parse_xml_to_dict(self, xml):
        """
        Parse an xml element into a (nested) dict, modelled on TensorFlow's
        recursive_parse_xml_to_dict.

        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # reached a leaf: return the tag's text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into child tags
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several objects, so keep them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

你可能感兴趣的:(深度学习,python)