Dataset类似建立一个数组,建立数据集和数据标签之间的联系(就像数组下标和元素之间的联系)。
例如:CIFAR10是一个关于图片的数据,下面代码就是它的引入
data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
FaceDataset = datasets.ImageFolder('./data', transform=img_transform)
ImageFolder对文件夹类型的数据集进行引入,这里文件夹内部存储的数据集要求是同一类型的图片。
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
参数:
root:指定保持图片的文件夹路径
transform:对PIL Image进行的转换操作
target_transform:对label的转换
loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
from torch.utils.data import Dataset
class MyData(Dataset):
举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签
(如果有的话);len(dataset)则会返回这个数据集的容量。
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):#而且这里的self标签可以建立self.xxx达到xxx全局化的作用
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
from torch.utils import data
import numpy as np
from PIL import Image
class face_dataset(data.Dataset):
def __init__(self):
self.file_path = './data/faces/'
f=open("final_train_tag_dict.txt","r")
self.label_dict=eval(f.read())
f.close()
def __getitem__(self,index):
label = list(self.label_dict.values())[index-1]
img_id = list(self.label_dict.keys())[index-1]
img_path = self.file_path+str(img_id)+".jpg"
img = np.array(Image.open(img_path))
return img,label
def __len__(self):
return len(self.label_dict)
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path =os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)#URL的拼接
img=Image.open(img_item_path)
label=self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir="练手数据集/val"
ant_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ant_label_dir)
bees_dataset=MyData(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset #将两个数据集合并。
img,label=train_dataset[123]
img.show() #可以展示图片
这里的self在class类里面的作用就是def init(self,root_dir,label_dir)将初始化的root_dir和label_dir在class类内部进行公有化。
例如:接下来在__getitem__函数内部就直接进行使用了,img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)这里就体现了它的公有化。