处理VOC格式的数据集

VOC格式是常用的分割数据集,主要存储在xml文件中
torch也实现了读取Voc文件的函数datasets.VOCDetection
Voc格式 是需要在 打标签的文件夹中: cmd输入 labelImg [图像路径] [标签路径(txt文件)]

代码 全部参考b战的霹雳吧啦Wz导师

1.初始化

class VOCDataSet(Dataset):
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        # 增加容错能力
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages") #JPEGImages是所有图片的目录
        self.annotations_root = os.path.join(self.root, "Annotations")

        # 自定义的话txt_name的文件可以通过split_data.py生成
        #读取train.txt文件
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)


        with open(txt_path) as read:
            #line.strip()去掉换行符  形成 train训练的 xml_list文件路径
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []

        # 解析xml文件  并且过滤错误的xml文件路径
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            with open(xml_path) as fid:
                xml_str = fid.read()

            xml = etree.fromstring(xml_str)#xml文件转成字符串读取

            #解析xml文件返回成字典
            data = self.parse_xml_to_dict(xml)["annotation"]

            if "object" not in data:
                print("INFO: no objects in {0}, skip this annotation file.".format(xml_path))
                continue

            self.xml_list.append(xml_path)


        # 读取类别文件的路径
        json_file = './pascal_voc_classes.json'

        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transfor

2.解析xml 函数parse_xml_to_dict

    def parse_xml_to_dict(self, xml):
        """
        将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """

        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]

            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

3.getitem(返回idx指定的图像 和 target)

    def __getitem__(self, idx):
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        # 使用etree解析xml文件该方法是将xml格式转化为Element 对象,
        # Element 对象代表 XML 文档中的一个元素。元素可以包含属性、其他元素或文本。如果一个元素包含文本,则在文本节点中表示该文本
        xml = etree.fromstring(xml_str)

        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)

        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []

        #遍历多个object 可能有多个分割的物体
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # 把列表转成tensor torch.as_tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
        
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target#taget 是一个字典 里面含有 坐标,标签,面积等信息

4.collate_fn函数处理

分类网络返回的是图片和tensor 目标检测的话是图片和tagert一起组成的元组
因为taget是一个字典 里面含有 坐标,标签,面积等信息,所以不能单独的stack 要先

    def collate_fn(batch):
        return tuple(zip(*batch))

5.验证 dataset

try:
    json_file = open('./pascal_voc_classes.json', 'r')
    class_dict = json.load(json_file)
    category_index = {str(v): str(k) for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}

# load train data set
train_data_set = VOCDataSet(os.getcwd(), "2012", data_transform["train"], "train.txt")
print(len(train_data_set))

for index in random.sample(range(0, len(train_data_set)), k=5):
    img, target = train_data_set[index]
    img = t.ToPILImage()(img)#等同于tt = t.ToPILImage(),tt(img)
    # img = ts.ToPILImage()(img)

    plot_img = draw_objs(img,
                         target["boxes"].numpy(),
                         target["labels"].numpy(),
                         np.ones(target["labels"].shape[0]),
                         category_index=category_index,
                         box_thresh=0.5,
                         line_thickness=3,
                         font='arial.ttf',
                         font_size=20)
    plt.imshow(plot_img)
    plt.show()

6.边界框的绘画

def draw_objs(image: Image,
              boxes: np.ndarray = None,
              classes: np.ndarray = None,
              scores: np.ndarray = None,
              masks: np.ndarray = None,
              category_index: dict = None,
              box_thresh: float = 0.1,
              mask_thresh: float = 0.5,
              line_thickness: int = 8,
              font: str = 'arial.ttf',
              font_size: int = 24,
              draw_boxes_on_image: bool = True,
              draw_masks_on_image: bool = False):
    """
    将目标边界框信息,类别信息,mask信息绘制在图片上
    Args:
        image: 需要绘制的图片
        boxes: 目标边界框信息
        classes: 目标类别信息
        scores: 目标概率信息
        masks: 目标mask信息
        category_index: 类别与名称字典
        box_thresh: 过滤的概率阈值
        mask_thresh:
        line_thickness: 边界框宽度
        font: 字体类型
        font_size: 字体大小
        draw_boxes_on_image:
        draw_masks_on_image:

    Returns:

    """

    # 过滤掉低概率的目标
    idxs = np.greater(scores, box_thresh)
    boxes = boxes[idxs]
    classes = classes[idxs]
    scores = scores[idxs]
    if masks is not None:
        masks = masks[idxs]
    if len(boxes) == 0:
        return image

    colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]

    if draw_boxes_on_image:
        # Draw all boxes onto image.
        draw = ImageDraw.Draw(image)
        for box, cls, score, color in zip(boxes, classes, scores, colors):
            left, top, right, bottom = box
            # 绘制目标边界框
            draw.line([(left, top), (left, bottom), (right, bottom),
                       (right, top), (left, top)], width=line_thickness, fill=color)
            # 绘制类别和概率信息
            draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)

    if draw_masks_on_image and (masks is not None):
        # Draw all mask onto image.
        image = draw_masks(image, masks, colors, mask_thresh)

    return image

你可能感兴趣的:(框架,python,计算机视觉,人工智能)