最近刚学图神经网络,数据集导入折腾了很久,终于开窍了一点。
目前常用的数据导入方法主要有两种:
(1)torchvision自带的导入方式:
这种导入方式使用了torchvision自带的库,打开函数进去看它的说明是这样的:
直接翻译过来意思就是图片要放在相应类别的文件夹下,文件夹名字就是图片所属的类别。
导入代码如下:
from torchvision import datasets
'''transform可自行定义'''
train_transforms = transforms.Compose(
[transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
train_dataset=datasets.ImageFolder(train_dir,transform=train_transforms)
2.自定义数据导入方式
现实使用过程中经常会遇到图片跟标签是分开放置的情况,如下面两张图所示,图片和label分别放置的,那么torchvision自带的库就不能用了,需要自定义数据读取方式。
首先用os库遍历文件,提取图片的名字和对应的label,保存在CSV文件中(当然完整的程序不保存也可以,这里是为了方便后面用),遍历的方式参考这篇博客。
开始自定义导入数据的类,这部分的格式都是统一的,最开始先写上这几个必须的函数,再往里面填东西:
from torch.utils.data import Dataset#Dataset是必须要继承的
class LoadData(Dataset):
def __init__(self,image_path,transform=None):
#初始化,读取数据集
def __getitem__(self,index):
#对于指定id,获取该数据并返回
def __len__(self):
#获取数据及总大小
确定模板以后直接往里面填东西就可以了:
from torch.utils.data import Dataset#Dataset是必须要继承的
import pandas as pd
from PIL import Image
class LoadData(Dataset):
def __init__(self,image_path,transform=None):
self.imgs_info=pd.read_csv(image_path)
def __getitem__(self,index):
img_path,label=self.imgs_info['img_path'],self.imgs_info['weather']
img=Image.open(img_path)#得到路径需要打开图片
img=img.convert('RGB')#将图片转为张量
if transform is not None:
img=transform(img)#图像变换
returnimg,label
def __len__(self):
return len(self.imgs_info)
主函数中调用:
from torchvision import transforms
train_csv_path=r'./dataset/train.csv'
train_transforms=transforms.Compose(
[transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
train_dataset=LoadData(train_csv_path,transform=train_transforms)
train_loader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=True)