【目标检测】自定义Dataset方法(VOC数据集)——pytorch实现

推荐参考:TORCHVISION OBJECT DETECTION FINETUNING TUTORIAL
以VOC2007数据集为例:

import os
import torch
from torch.utils.data import Dataset
from PIL import Image

from lxml import etree
from xml.etree import ElementTree

"""
VOC2007数据集格式:
└──VOCdevkit
    └──VOC2007
        └──JPEGImages
            └──0.jpg
            └──1.jpg
            └──2.jpg
            └──...
        └──Annotations
            └──0.xml
            └──1.xml
            └──2.xml
            └──...
        └──ImageSets
            └──Main
                └──train.txt
                └──val.txt
                └──trainval.txt
                └──test.txt
      
"""

'''
xml文件信息(例):

    JPEGImages
    0.jpg
    X:/.../.../VOCdevkit/VOC2007/JPEGImages/0.jpg
    
        Unknown
    
    
        600
        800
        3
    
    0
    
        peach
        Unspecified
        0
        0
        
            100
            250
            150
            260
        
    
    
        cat
        Unspecified
        0
        0
        
            200
            50
            550
            370
        
    

'''


class CustomDataset(Dataset):  # 自定义数据集
    def __init__(self, root, transforms=None, dataset_property="train"):  # 初始化方法
        self.root = root  # 数据路径,应指向".../.../VOCdevkit"
        self.transforms = transforms  # 预处理方法,一般来说需要传入,注意区分训练数据和验证数据的预处理方法
        self.images_dir = os.path.join(self.root, "VOC2007/JPEGImages")  # 图像文件存储路径,默认VOC2007/JPEGImages
        self.annotations_dir = os.path.join(self.root, "VOC2007/Annotations")  # 标注文件存储路径,默认VOC2007/Annotations
        self.imagesets_dir = os.path.join(self.root, "VOC2007/ImageSets/Main")  # 数据集划分文件存储路径,VOC2007/ImageSets/Main

        assert dataset_property in ["train", "val", "trainval", "test"]  # 数据集划分文件名应该为train/val/trainval/test,以txt形式存储
        with open(os.path.join(self.imagesets_dir, f"{dataset_property}.txt")) as f:  # 打开对应txt文件
            self.data_names = [i.strip() for i in f.readlines()]  # 读取数据,存放于列表(每张图片的名称,不包含.jpg后缀)

        self.label_dict = {  # 语义标签与int的对应关系,一般从1开始,0表示背景(一般通过读取json或txt文件来获得对应关系)
            "cat": 1,
            "dog": 2,
            "peach": 3
        }

    def __len__(self):  # 获取数据集长度方法  **该方法必须定义**
        return len(self.data_names)  # 返回数据列表的长度

    def __getitem__(self, index):  # 采样方法,根据索引index取得对应的图像image和标签target  **该方法必须定义**
        image_path = os.path.join(self.images_dir, self.data_names[index] + ".jpg")  # 图片路径,要求图片为jpg格式
        image = Image.open(image_path).convert("RGB")  # PIL图像

        annotation_path = os.path.join(self.annotations_dir, self.data_names[index] + ".xml")  # xml路径
        label_names, difficults, boxes, areas = self.read_xml(annotation_path)  # 解析xml文件,获取标注信息:标签名称、难识别目标标签、边界框、面积
        labels = [self.label_dict[f"{label}"] for label in label_names]  # 按照对应关系,将语义标签转化为int类型

        labels, difficults, boxes, areas = map(lambda t: torch.as_tensor(t),
                                               [labels, difficults, boxes, areas])  # 将标注信息转化为tensor形式

        target = {  # 构建最后返回的target,一般还包括image_id(图像id)、masks(分割掩码)、iscrowd(是否为多目标)
            "boxes": boxes,  # 边界框
            "labels": labels,  # 标签
            "area": areas,  # 面积
            "isdifficult": difficults  # 难识别目标标签
        }

        if self.transforms is not None:  # 如果使用预处理方法
            image, target = self.transforms(image, target)  # 进行image和target的预处理,transforms函数/类需要复写,以满足对target的变换

        return image, target  # 返回image,target

    def read_xml(self, annotation_path):  # 读取xml信息方法
        objnames = []  # 用于存放目标名称
        difficults = []  # 用于存放难识别目标标签
        objboxes = []  # 用于存放边界框
        objareas = []  # 用于存放面积

        parser = etree.XMLParser(encoding="utf-8")  # xml文件解析器
        xmlroot = ElementTree.parse(annotation_path, parser=parser).getroot()  # 解析xml文件并获得root节点
        for object in xmlroot.findall("object"):  # 寻找所有object节点并遍历
            objnames.append(object.find("name").text)  # 获取name节点数据,填入列表
            difficults.append(int(object.find("difficult").text))  # 获得difficult节点数据,转换为int,填入列表
            objxmin = float(object.find("bndbox/xmin").text)  # 获得bndbox/xmin节点数据,转换为float
            objymin = float(object.find("bndbox/ymin").text)  # 获得bndbox/ymin节点数据,转换为float
            objxmax = float(object.find("bndbox/xmax").text)  # 获得bndbox/xmax节点数据,转换为float
            objymax = float(object.find("bndbox/ymax").text)  # 获得bndbox/ymax节点数据,转换为float
            assert objxmax > objxmin and objymax > objymin  # 检查边界框的长宽是否为正
            objboxes.append([objxmin, objymin, objxmax, objymax])  # 边界框[xmin,ymin,xmax,ymax],填入列表
            objareas.append((objxmax - objxmin) * (objymax - objymin))  # 面积,填入列表

        return objnames, difficults, objboxes, objareas  # 返回目标名称、难识别目标标签、边界框、面积

你可能感兴趣的:(目标检测,pytorch,深度学习)