在网上有很多可用的公开的数据集,根据自己的需要,下载相应的数据集,可以用来训练网络,测试网络模型的精度。
[数据集转载来源] 深度学习中的遥感影像数据集
Pascal VOC网址:http://host.robots.ox.ac.uk/pascal/VOC/
转载的一篇包含了比较多的数据集的一篇博文,可以参考一下。
但有些时候,我们需要根据我们自己的需求,根据自己的研究方向和类型,设置自己的数据集,以下,简单的阐述了设置数据集的一些步骤。
在pytorch中,官方文档简单的介绍了创建数据集的简单步骤。
# ================================================================== #
# 5. Input pipeline for custom dataset #
# ================================================================== #
# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# TODO
# 1. Initialize file paths or a list of file names.
# 设置文件和标签的路径,或者文件名list,最关键的就是设置好数据集的路径,以及初始化一些数据集的属性
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
# 通过上述的数据集路径,读取文件,并且对文件进行预处理操作,返回真实的文件数据,比如image and label
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
# 比较简单,只是设置数据集的长度,返回一个值
return 0
# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
batch_size=64,
shuffle=True)
所以说,最关键的就是初始化文件路径和读取文件,以及文件的预处理。
其他的一些需要用到的属性和方法,在需要的时候加上就行。比如如何进行数据读取、如何进行预处理等。
在实际应用中,创建数据集的基本步骤也大致如此,只需要把相应的方法写全即可,下面以目标检测的数据集为例。
1.数据准备
首先,我们拿到目标检测的遥感图像,放到一个总的文件夹中。再使用标签工具labelImg进行标注,将标注好的xml标签文件同样放到同一个标签文件夹中。(下图仅为部分数据的截图)
这里有个小问题,就是使用不同的标注工具,得到的bonding box的格式会有不同,在后期读取的时候,可能会报错。
以下是图像和标签数据的截图实例:
再创建一个类别文件,设置不同的分类的地物名称,以及一个类别对应的JSON文件,不同类别对应不同的key和value。
将上述文件都放在同一个文件夹中,再将这些数据随机分成训练集和测试集,代码如下。
import os
import random
def train_val_txt(files_path,val_rate,output_train_path,output_val_path):
'''
:param files_path: 保存的所有图片文件的目录
:param val_rate: 选择测试集相对于总体的比率
:param output_train_path: 输出的train的filename的txt目录
:param output_val_path: 输出的val的filename的txt目录
'''
if not os.path.exists(files_path):
print("文件夹不存在")
exit(1)
# 获取文件目录下的所有文件名,返回列表格式
files_name = sorted([file.split('.')[0] for file in os.listdir(files_path)])
files_num = len(files_name)
# 设置采样的序号,从[0,files_num] 中随机抽取k个数
val_index = random.sample(range(0, files_num), k=int(files_num * val_rate))
train_files = []
val_files = []
for index, file_name in enumerate(files_name):
if index in val_index:
val_files.append(file_name)
else:
train_files.append(file_name)
try:
with open(output_train_path,'x') as f:
f.write('\n'.join(train_files))
with open(output_val_path, 'x') as f:
f.write('\n'.join(val_files))
except Exception as e:
print(e)
exit(1)
根据注释,设置路径和分类比,运行后可以得到train.txt和val.txt文件。
文本文件中保存着训练集或测试集的样本名称,在后续操作中,直接读取不同的样本名称,就可以加载不同的数据。
最终效果如下:
这样数据就准备好了。
2.设置数据集
按照官方文档的框架,自定义数据集。
在init中,主要是初始化用户数据集的目录,包括设置标签目录,遥感影像目录,以及预处理。
def __init__(self, data_root, transforms, train=True):
#设置不同的路径,分别设置成图片路径和标签路径
self.root = os.path.join(data_root, "data")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
"""读取训练集/测试集,txt_list是路径"""
if train:
txt_list = os.path.join(self.root, "ImageSets", "Main", "train_1.txt")
else:
txt_list = os.path.join(self.root, "ImageSets", "Main", "val_1.txt")
with open(txt_list) as read:
self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines()]
# 读取分类索引
try:
json_file = open('./data/classes.json', 'r')
self.class_dict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# 定义预处理方式
self.transforms = transforms
len方法主要是返回数据集的个数,即有多少张图像(图像和标签是对应的)。该方法比较简单,直接返回即可。
def __len__(self):
"""返回训练集/测试集中图片的个数"""
return len(self.xml_list)
在getitem中,传入index,即对不同index的图像和标签进行处理,返回一个image和target(包含boxes、label、image_id等信息)。
对于不同的需求,设置不同的方法,这里只是以目标检测为例,故需要返回image、label和boxes边界框等信息。
def __getitem__(self, idx):
# read xml
xml_path = self.xml_list[idx] # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str
with open(xml_path) as fid:
xml_str = fid.read()
# xml = etree.fromstring(xml_str)
xml = etree.fromstring(xml_str.encode('utf-8')) # 读取xml文件的内容
data = self.parse_xml_to_dict(xml)["annotation"]
img_path = os.path.join(self.img_root, data["filename"]) # 从xml文件中得到img文件路径
image = Image.open(img_path)
if image.format != "JPEG":
raise ValueError("Image format not JPEG")
boxes = []
labels = []
iscrowd = [] # 是否难检测,crowd为0表示单目标
for obj in data["object"]:
"""得到训练集边框坐标,分类和难易程度"""
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
iscrowd.append(int(obj["difficult"]))
# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx]) # 当前数据对应的索引值
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 框的面积:长*宽
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
除了以上三个方法外,我们可以根据自己的需求,增加不同的方法,在数据集设置阶段对数据的处理上,会比之后计算得出的要快一些。
这里增加一个标签索引值的处理方法。
#官方的方法:将标签的索引值存储为字典
def parse_xml_to_dict(self, xml):
if len(xml) == 0: # 说明已经遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归 遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
最终的显示效果如下:
# read class_indict
category_index = {}
try:
json_file = open('./data/classes.json', 'r')
class_dict = json.load(json_file)
category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
print(e)
exit(-1)
data_transform = {
"train": transforms.Compose([transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)]),
"val": transforms.Compose([transforms.ToTensor()])
}
# load train data set
train_data_set = SelfDataSet(os.getcwd(), data_transform["train"], True)
print(len(train_data_set))
测试后,可以加载出图像和train_data_set,即数据集创建成功。
完整的代码示例:
这个案例是以Skysat数据为例设置的数据集,只需要修改图像和标签的路径即可。
from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree
#设置数据集
class SelfDataSet(Dataset):
# 根目录,预处理方式,训练集/验证集
def __init__(self, data_root, transforms, train=True):
#设置不同的路径,分别设置成图片路径和标签路径
self.root = os.path.join(data_root, "SkysatData")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
"""读取训练集/测试集,txt_list是路径"""
if train:
txt_list = os.path.join(self.root, "ImageSets", "Main", "train.txt")
else:
txt_list = os.path.join(self.root, "ImageSets", "Main", "val.txt")
with open(txt_list) as read:
self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
for line in read.readlines()]
# 读取分类索引
try:
json_file = open('./SkysatData/classex.json', 'r')
self.class_dict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# 定义预处理方式
self.transforms = transforms
def __len__(self):
"""返回训练集/测试集中图片的个数"""
return len(self.xml_list)
def __getitem__(self, idx):
# read xml
xml_path = self.xml_list[idx] # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str
with open(xml_path) as fid:
xml_str = fid.read()
# xml = etree.fromstring(xml_str)
xml = etree.fromstring(xml_str.encode('utf-8')) # 读取xml文件的内容
data = self.parse_xml_to_dict(xml)["annotation"]
img_path = os.path.join(self.img_root, data["filename"]) # 从xml文件中得到img文件路径
image = Image.open(img_path)
if image.format != "JPEG":
raise ValueError("Image format not JPEG")
boxes = []
labels = []
iscrowd = [] # 是否难检测,crowd为0表示单目标
for obj in data["object"]:
"""得到训练集边框坐标,分类和难易程度"""
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
boxes.append([xmin, ymin, xmax, ymax])
labels.append(self.class_dict[obj["name"]])
iscrowd.append(int(obj["difficult"]))
# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx]) # 当前数据对应的索引值
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 框的面积:长*宽
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
def get_height_and_width(self, idx):
# read xml,每个xml
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = self.parse_xml_to_dict(xml)["annotation"]
data_height = int(data["size"]["height"])
data_width = int(data["size"]["width"])
return data_height, data_width
#官方的方法:将标签的索引值存储为字典
def parse_xml_to_dict(self, xml):
if len(xml) == 0: # 说明已经遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归 遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
@staticmethod
def collate_fn(batch):
return tuple(zip(*batch))
import transforms
from draw_box_utils import draw_box
from PIL import Image
import json
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import random
# read class_indict
category_index = {}
try:
json_file = open('./SkysatData/classex.json', 'r')
class_dict = json.load(json_file)
category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
print(e)
exit(-1)
data_transform = {
"train": transforms.Compose([transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)]),
"val": transforms.Compose([transforms.ToTensor()])
}
# load train data set
train_data_set = SelfDataSet(os.getcwd(), data_transform["train"], True)
print(len(train_data_set))
# index = 40
for index in random.sample(range(0, len(train_data_set)), k=5):
img, target = train_data_set[index]
img = ts.ToPILImage()(img)
draw_box(img,
target["boxes"].numpy(),
target["labels"].numpy(),
[1 for i in range(len(target["labels"].numpy()))],
category_index,
thresh=0.5,
line_thickness=1)
Image._show(img)
效果如下:
直接上代码了,根据YOLOv5的数据集设置,提取核心的数据集设置代码,代码如下:
import glob
import os
from pathlib import Path
import cv2
import numpy as np
import torch
class SkysatDataset(torch.utils.data.Dataset):
# 设置基本的文件路径
def __init__(self, path, imgsz, prefix=''):
self.path = path
self.imgsz = imgsz
# set the file path
try:
f = []
for p in path if isinstance(path, list) else [path]:
p = Path(p)
if p.is_dir():
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
self.img_files = sorted([x.replace('/', os.sep) for x in f])
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}')
self.label_files = img2label_paths(self.img_files) # labels
self.n = len(self.img_files)
def __len__(self):
return self.n
# 通过getitem获得img和label
def __getitem__(self, index):
img_path, label_path = self.img_files[index], self.label_files[index]
img = cv2.imread(img_path)
label = []
with open(label_path, 'r') as f:
for each in f.readlines():
cls, x, y, w, h = each.replace('\n', '').split(' ')
label.append([cls,x,y,w,h])
label = np.array(label).astype(np.float32)
label = xywh2xyxy(label[:,1:5])*self.imgsz
return img, label
# 通过img路径得到label路径
def img2label_paths(img_paths):
# Define label paths as a function of image paths
sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
# 可视化,将坐标改变格式
def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
# 可视化操作
def vis(img, boxes):
for i in range(len(boxes)):
box = boxes[i]
x0 = int(box[0])
y0 = int(box[1])
x1 = int(box[2])
y1 = int(box[3])
cv2.rectangle(img, (x0, y0), (x1, y1), (0, 255, 0), 1)
return img
if __name__ == '__main__':
dataset = SkysatDataset(path=r'D:\DATA\Models\customize\YOLOv5-6.0-St\dataset\skysat\images\train', imgsz=512)
img, label = dataset[2]
img = vis(img, label)
cv2.imshow('img', img)
cv2.waitKey(0)
cv2.destroyWindow()
图片文件和标签存放格式如下:
label的存储格式如下:(cls, x, y, w, h)并且对x, y, w, h进行了归一化处理
按照这个方式存放文件,可以得到如下的效果图:
这个核心代码比较简洁,可以直接使用制作自定义数据集。
本文主要为读书笔记,根据学习资料中的案例,使用自己的例子进行数据集创建,读者仅作参考,如有错误或补充,还请评论批评指正,谢谢!
当然,这只是自定义的一种方式,一般的Github都会有自己的数据集设置方式,按照项目中的修改即可。