pytorch导入自定义数据集

最近刚学图神经网络,数据集导入折腾了很久,终于开窍了一点。
目前常用的数据导入方法主要有两种:

(1)torchvision自带的导入方式:
这种导入方式使用了torchvision自带的库,打开函数进去看它的说明是这样的:
pytorch导入自定义数据集_第1张图片
直接翻译过来意思就是图片要放在相应类别的文件夹下,文件夹名字就是图片所属的类别。

导入代码如下:

from torchvision import datasets
'''transform可自行定义'''
train_transforms = transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
train_dataset=datasets.ImageFolder(train_dir,transform=train_transforms)

2.自定义数据导入方式
现实使用过程中经常会遇到图片跟标签是分开放置的情况,如下面两张图所示,图片和label分别放置的,那么torchvision自带的库就不能用了,需要自定义数据读取方式。
pytorch导入自定义数据集_第2张图片
pytorch导入自定义数据集_第3张图片
首先用os库遍历文件,提取图片的名字和对应的label,保存在CSV文件中(当然完整的程序不保存也可以,这里是为了方便后面用),遍历的方式参考这篇博客。

开始自定义导入数据的类,这部分的格式都是统一的,最开始先写上这几个必须的函数,再往里面填东西:

from torch.utils.data import Dataset#Dataset是必须要继承的
class LoadData(Dataset):
	def __init__(self,image_path,transform=None):
		#初始化,读取数据集
	def __getitem__(self,index):
		#对于指定id,获取该数据并返回
	def __len__(self):
		#获取数据及总大小

确定模板以后直接往里面填东西就可以了:

from torch.utils.data import Dataset#Dataset是必须要继承的
import pandas as pd
from PIL import Image
class LoadData(Dataset):
	def __init__(self,image_path,transform=None):
		self.imgs_info=pd.read_csv(image_path)
	def __getitem__(self,index):
		img_path,label=self.imgs_info['img_path'],self.imgs_info['weather']
		img=Image.open(img_path)#得到路径需要打开图片
		img=img.convert('RGB')#将图片转为张量
		if transform is not None:
			img=transform(img)#图像变换
		returnimg,label
	def __len__(self):
		return len(self.imgs_info)		

主函数中调用:

from torchvision import transforms
train_csv_path=r'./dataset/train.csv'
train_transforms=transforms.Compose(
        [transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
train_dataset=LoadData(train_csv_path,transform=train_transforms)
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=True)

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能)