无论是在作分类任务或者是目标检测任务都需要数据集的处理,一种是txt文件保存标签的信息,另一种只有图片如下图的形式,这一步也是学会faster-rcnn的关键点
分为训练和验证的照片
|
每个分类的类别
一种是猫的照片,另一种是狗的照片,这种是自己的数据集,其实官方的数据集也是这样放置的,比如CIFAR10,其中的是有10个文件夹,每个文件夹下是很多张一种数字的照片,正常情况下我们引进官方数据集的写法如下
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 在小型数据集上,通过随机水平翻转来实现数据增强
transforms.RandomGrayscale(), # 将图像以一定的概率转换为灰度图像
transforms.ToTensor(), # 数据集加载时,默认的图片格式是 numpy,所以通过 transforms 转换成 Tensor,图像范围[0, 255] -> [0.0,1.0]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5,
0.5))])
trainset = torchvision.datasets.CIFAR10('data/cifar-10', train=True,
download=True,
transform=transform)
testset = torchvision.datasets.CIFAR10('data/cifar-10', train=False,
download=True,
transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
shuffle=True,
num_workers=0) # 加载数据的时候使用几个子进程
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
shuffle=True,
num_workers=0)
重点是trainset和trainloder的形式,trainset就是把所有的照片全部加载进来了,trainloder是一种迭代的形式,比如 for i, data in enumerate(testloader):这样训练每一张照片的数据都有了,而且这样每张照片所对应每张照片的label,顺序没有乱训练数据的核心就是有照片和label,比如分类的时候,以刚才的照片为例,它其中的label是0,1是自动生成的,你之后需要在写一个字典,将0,1代表你自己的分类类别,很明显这里我的0,1就代表cat和dog了
在做目标检测的时候给的是xml文件,我们一样还是需要提取xml文件中框的坐标点和分类信息的,我个人更喜欢csv文件保存信息。现在这一步我就先教如何制作csv文件
import pandas as pd
import os
PATH = 'G:/trainshibie/55/val/'
xml = []
i =1
for (path, dirnames, filenames) in os.walk(PATH):
for filename in filenames:
Path = os.path.join(path, filename)
if i < 11:
value = (Path, 0)
xml.append(value)
else:
value = (Path, 1)
xml.append(value)
i = i + 1
column_name =['path','label']
xml = pd.DataFrame(xml,columns=column_name)
print(xml)
xml.to_csv('G:/trainshibie/55/ee.csv',index=None)
这一步简单,我们要做的就是保存每张照片的绝对路径和label,我的if条件是照片的数量只有10张,前10张代表cat label=0,后10张代表dog label=1。
import pandas as pd
import numpy as np
path = []
data = pd.read_csv('G:/trainshibie/55/ee.csv')
c=data.shape[0]
label=np.zeros(c,dtype=np.int32)
for index,row in data.iterrows():
path.append(row['path'])
label[index] = row['label']
这个记住就行,label=np.zeros(c,dtype=np.int32)这是创建方便
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
a = 'G:/trainshibie/55/val/'
transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
])
dataset=ImageFolder(a,transform=transform)
trainloader = torch.utils.data.DataLoader(dataset, batch_size=2,shuffle=True, num_workers=0)
基本都是先处理成dataset包含所有的照片,在处理成trainloader迭代器的形式,这种也是没有用csv文件保存label的,也就分类没有,目标检测和语义分割都是要现将信息从xml提取出来。保存到cvs文件中
前面我讲过csv文件的生成和读取,这一步就是利用csv文件中的数据做成迭代器的形式,它的主要步骤是写自己的dataset类
from PIL import Image
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
root = 'G:/trainshibie/55/val/'
# 定义读取文件的格式
def default_loader(path):
return Image.open(path).convert('RGB')
class Mydataset(Dataset):
def __init__(self,csv,transform=None, target_transform=None, loader=default_loader):
super(Mydataset,self).__init__()
self.path = [] #保存读取路径的
self.data = pd.read_csv(csv)
self.num = int(self.data.shape[0]) #一共多少照片
self.label = np.zeros(self.num, dtype=np.int32)
for index, row in self.data.iterrows():
self.path.append(row['path'])
self.label[index] = row['label'] #将数据全部读取出来
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
img = default_loader(self.path[index])
labels = self.label[index]
if self.transform is not None:
img = self.transform(img) #转化tensor类型
return img,labels
def __len__(self):
return len(self.data)
val_data=Mydataset(csv='G:/trainshibie/55/ee.csv', transform=transforms.ToTensor())
trainloader = DataLoader(val_data, batch_size=10,shuffle=True, num_workers=0)
第一步是先要学习ET的使用,专门是读取xml文件使用
import xml.etree.ElementTree as ET
tree = ET.parse('G:\picture\label_cat/000005.xml')
root = tree.getroot()
xml_list = []
for member in root.findall('object'):
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text))
xml_list.append(value)
print(xml_list)
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
def xml_to_csv(path):
xml_list = []
for xml_file in glob.glob(path + '/*.xml'):
print(xml_file)
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
try:
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text)
)
except ValueError:
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][1][0].text),
int(member[4][1][1].text),
int(member[4][1][2].text),
int(member[4][1][3].text)
)
xml_list.append(value)
column_name = ['filename','width','height','class','xmin','ymin','xmax','ymax']
xml_df = pd.DataFrame(xml_list,columns=column_name)
return xml_df
def main():
image_path = os.path.join(os.getcwd(),'E://pytorch/Annotations')
xml_df = xml_to_csv(image_path)
xml_df.to_csv('E://pytorch/labes.csv',index=None) #带路径 'E://'
print('finish')
main()
这一步可以拿去直接用,改一下路径就可以,专门针对xml文件的处理
import os.path as osp
import os
import pandas as pd
import numpy as np
data = pd.read_csv('E://pytorch/labes.csv')
c=data.shape[0]
boxes = np.zeros((c,4), dtype=np.uint16)
gt_classes=np.zeros(c,dtype=np.int32)
_classes = ('__background__', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
_class_to_ind = dict(zip(_classes, range(21)))
_ind_to_class = dict(zip(range(21), _classes))
for index,row in data.iterrows():
boxes[index,:]=[row['xmin'],row['ymin'],row['xmax'],row['ymax']]
gt_classes[index] = _class_to_ind[row['class']]
print(data.head(6))
print(boxes[1])
print(gt_classes[1])
print(_ind_to_class[gt_classes[1]])
全部看懂完数据处理这方面就差不多全部掌握了,可以应对各种数据的处理了,目标检测的数据增强部分后面后在单独写一篇,这一篇处理目标分类很好用。目标检测的数据集处理成csv的这种格式我写在了yolo3训练自己的数据集这篇文章了,可以去我的博客看下这个写法,基本跟这个一样,对比学习就行了。