# 根据坐标把框画到图上
import xml.etree.ElementTree as ET
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import os
def read_xml(xml_path):
    """Collect every annotated box from a VOC-style XML file.

    :param xml_path: path of the XML annotation file
    :return: list of (xmin, ymin, xmax, ymax, class_label) tuples, one per
        ``object`` element, with integer corner coordinates
    """
    corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    annotation = ET.parse(xml_path).getroot()
    boxes = []
    for node in annotation.findall('object'):
        bnd = node.find('bndbox')
        corners = tuple(int(bnd.find(tag).text) for tag in corner_tags)
        # The class label is stored next to the coordinates.
        boxes.append(corners + (node.find('name').text,))
    return boxes
def visualize_boxes(image_path, boxes):
    """Show an image with its bounding boxes and class labels drawn on top.

    :param image_path: path of the image file to display
    :param boxes: iterable of (xmin, ymin, xmax, ymax, class_label) tuples
    :raises FileNotFoundError: if OpenCV cannot read ``image_path``
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    if image is None:
        raise FileNotFoundError('Cannot read image: {}'.format(image_path))
    # OpenCV loads BGR; matplotlib expects RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    _fig, ax = plt.subplots(1)
    ax.imshow(image_rgb)

    for xmin, ymin, xmax, ymax, class_label in boxes:
        rect = patches.Rectangle(
            (xmin, ymin), xmax - xmin, ymax - ymin,
            linewidth=1, edgecolor='g', facecolor='none',
        )
        ax.add_patch(rect)
        # Label is anchored at the box's (xmin, ymin) corner.
        ax.text(xmin, ymin, class_label, color='r', fontsize=8,
                bbox=dict(facecolor='white', alpha=0.7))

    # Title is the image file name without its extension.
    ax.set_title(os.path.splitext(os.path.basename(image_path))[0])
    plt.show()
if __name__ == "__main__":
    # Locations of the annotation files and the images they describe.
    xml_folder = r"D:\work\data\insects\train\annotations\xmls"
    image_folder = r"D:\work\data\insects\train\images"

    # File name of the image to visualise; its annotation shares the stem.
    image_file_name = "1.jpeg"
    stem, _ext = os.path.splitext(image_file_name)

    xml_file = os.path.join(xml_folder, stem + ".xml")
    image_path = os.path.join(image_folder, image_file_name)

    visualize_boxes(image_path, read_xml(xml_file))
继承`torch.utils.data.Dataset`类来读取数据集,在`__getitem__`方法中返回图片、框坐标、框类别,主要分为以下步骤:
定义数据集的路径、类别
# Root folder of the dataset and the class names it contains.
DATA_ROOT = r'D:\work\data\insects'
CATEGORY_NAMES = [
    'Boerner',
    'Leconte',
    'Linnaeus',
    'acuminatus',
    'armandi',
    'coleoptera',
    'linnaeus',
]


def get_insect_names():
    """Return a dict mapping each class name to its index in CATEGORY_NAMES."""
    return {name: index for index, name in enumerate(CATEGORY_NAMES)}


CATEGORY_NAME_ID = get_insect_names()
NUM_CLASSES = len(CATEGORY_NAMES)
解析xml文件,获取框的位置、类别
框坐标从xyxy(左上角、右下角)改成了xywh(中心点坐标、宽、高)
import xml.etree.ElementTree as ET
import os
import numpy as np
def read_xml(xml_path):
    """Parse a VOC-style XML annotation file into a record of numpy arrays.

    Boxes are converted from corner form (xmin, ymin, xmax, ymax) to centre
    form (cx, cy, w, h) in pixels; width and height include the ``+ 1``
    inclusive-pixel offset.

    :param xml_path: path of the XML annotation file
    :return: dict with
        'fname': file name of ``xml_path`` without its extension,
        'gt_bbox': float32 array of shape (N, 4) holding (cx, cy, w, h),
        'gt_class': int32 array of shape (N,) holding class ids,
        'difficult': int32 array of shape (N,) holding the difficult flags
    :raises KeyError: if an object's name is not in CATEGORY_NAME_ID
    """
    root = ET.parse(xml_path).getroot()
    fname = os.path.splitext(os.path.basename(xml_path))[0]
    objs = root.findall('object')

    # One row per annotated object.
    gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
    gt_class = np.zeros((len(objs),), dtype=np.int32)
    difficult = np.zeros((len(objs),), dtype=np.int32)

    for i, obj in enumerate(objs):
        bbox = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (
            int(bbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        # Width comes from the x extent and height from the y extent.
        gt_bbox[i] = [
            (xmin + xmax) / 2.0,
            (ymin + ymax) / 2.0,
            xmax - xmin + 1.,
            ymax - ymin + 1.,
        ]
        gt_class[i] = CATEGORY_NAME_ID[obj.find('name').text]
        # A missing or empty <difficult> tag is treated as "not difficult".
        difficult[i] = int(obj.findtext('difficult') or 0)

    return {
        'fname': fname,
        'gt_bbox': gt_bbox,
        'gt_class': gt_class,
        'difficult': difficult,
    }
继承`torch.utils.data.Dataset`,定义InsectDataset类,包含`__init__`/`__getitem__`/`__len__`和`get_annotations`四个方法
returns: image, gt_boxes, labels
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class InsectDataset(Dataset):
    """Detection dataset built from a folder of images and XML annotations.

    Expected layout under ``<datadir>/<mode>``: ``annotations/xmls/*.xml`` and
    ``images/<same stem>.jpeg``.

    Each item is ``(img, gt_boxes, labels)``:
        img: image array with the channel axis moved to the front
        gt_boxes: boxes as (cx, cy, w, h) relative to the image size
        labels: class id of every box
    """

    def __init__(self, datadir, mode='train', transforms=None):
        """
        :param datadir: dataset root folder
        :param mode: sub-folder of ``datadir`` to read (default 'train')
        :param transforms: optional callable invoked as
            ``transforms(image=..., bboxes=..., class_labels=...)`` that
            returns a dict with keys 'image', 'bboxes' and 'class_labels'
        """
        super(InsectDataset, self).__init__()
        self.datadir = os.path.join(datadir, mode)
        self.records = self.get_annotations()
        self.transforms = transforms

    def __getitem__(self, idx):
        """Load image ``idx`` and return it with its relative boxes and labels."""
        record = self.records[idx]
        labels = record['gt_class']

        # Close the file handle promptly and force three channels so that
        # grayscale files still produce an (H, W, 3) array.
        with Image.open(record['im_file']) as img:
            image = np.array(img.convert('RGB'))

        # numpy image arrays are (height, width, channels).
        height, width = image.shape[:2]
        # Work on a copy: the cached record must keep its pixel coordinates,
        # otherwise boxes would shrink again on every repeated access.
        gt_boxes = self._to_relative(record['gt_bbox'], height, width)

        if self.transforms:
            transformed = self.transforms(image=image, bboxes=gt_boxes, class_labels=labels)
            image = transformed['image']
            gt_boxes = np.array(transformed['bboxes'])
            labels = np.array(transformed['class_labels'])

        # NOTE(review): this yields (c, w, h); PyTorch models usually expect
        # (c, h, w), i.e. transpose((2, 0, 1)) - confirm before training.
        image = image.transpose((2, 1, 0))  # h,w,c -> c,w,h
        return image, gt_boxes, labels

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.records)

    @staticmethod
    def _to_relative(boxes, height, width):
        """Return a new array with pixel (cx, cy, w, h) boxes divided by image size."""
        scale = np.array([width, height, width, height], dtype=np.float32)
        return np.asarray(boxes, dtype=np.float32) / scale

    def get_annotations(self):
        """Read every XML annotation under ``<datadir>/annotations/xmls``.

        Files without an ``.xml`` extension are ignored, and files are
        processed in sorted order so that indices are reproducible.

        :return: list of records as produced by ``read_xml``, each extended
            with 'im_file', the path of the matching ``.jpeg`` image
        """
        xml_dir = os.path.join(self.datadir, 'annotations', 'xmls')
        image_dir = os.path.join(self.datadir, 'images')
        records = []
        for fname in sorted(os.listdir(xml_dir)):
            # splitext keeps stems that themselves contain dots intact.
            fid, ext = os.path.splitext(fname)
            if ext.lower() != '.xml':
                continue
            record = read_xml(os.path.join(xml_dir, fname))
            record['im_file'] = os.path.join(image_dir, fid + '.jpeg')
            records.append(record)
        return records
这里采用albumentations进行数据增强,参考官网的目标检测数据增强教程即可,这里加入normalize、resize以及一些常见的数据增强策略,后续完善
import albumentations as A
# Augmentation pipeline; boxes are passed in YOLO format (relative cx, cy, w, h)
# and their class ids travel in the `class_labels` field.
transforms = A.Compose([
    # A.RandomCrop(width=450, height=450),
    A.Resize(width=640, height=640),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    # Normalize runs last so that the photometric augmentation above operates
    # on raw pixel values rather than on mean/std-standardised floats.
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),
], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
在调用的时候注意框坐标的format,这里统一用yolo格式(xywh相对坐标)
# Excerpt of the transform call made inside InsectDataset.__getitem__ (not
# standalone code): boxes go in as `bboxes` and class ids as `class_labels`,
# the label field named in the Compose pipeline's BboxParams.
if self.transforms:
transformed = self.transforms(image=image, bboxes=gt_boxes, class_labels=labels)
image = transformed['image']
gt_boxes = np.array(transformed['bboxes'])
labels = np.array(transformed['class_labels'])
由于不同图片的框数量不同,在用DataLoader加载数据的时候,`__getitem__`返回值的shape不同会报错,因此自定义collate_fn,把框和标签各用一个list包裹起来
def dataset_collate(batch):
    """Collate (img, boxes, labels) samples into one batch.

    Images are stacked into a single tensor; boxes and labels stay as plain
    lists with one entry per sample, because their sizes differ per image.

    :param batch: sequence of (img, box, label) triples
    :return: (images tensor, list of boxes, list of labels)
    """
    samples = list(batch)
    images = [img for img, _box, _label in samples]
    bboxes = [box for _img, box, _label in samples]
    labels = [label for _img, _box, label in samples]
    stacked = torch.tensor(np.array(images))
    return stacked, bboxes, labels
if __name__ == '__main__':
    # Smoke test: build the dataset, inspect one sample, then walk the loader.
    dataset = InsectDataset(DATA_ROOT, transforms=transforms)
    print(len(dataset))
    print('image_shape: ', dataset[1][0].shape)

    batch_size = 4
    print()
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=dataset_collate,
    )
    for images, gt_boxes, gt_labels in train_loader:
        print('img_shape:', images.shape)
        print('gt_boxes:', gt_boxes)
        print('gt_labels:', gt_labels)
读取VOC格式的数据集,主要有以下三点需要注意:
把画框的代码单独放在一个文件里,但其中read_xml的方法跟dataset中类似,框架搭好之后进一步优化一下
如果是anchor-based的模型,后续还需要根据锚框进行处理,得到每个锚框的objectness和坐标