pytorch用自己的数据集进行Dataloader,并对其划分数据集

最近在学pytorch,今天晚上用pytorch的数据加载部分,一开始一直在纠结怎么划分数据集,后来还是手动分了,开始是用torch.utils.data.random_split但是后来一直报错,我也不知道哪里有错,解决不了,后来暴力解决了

1.重写dataset类,这是必须要写的

  • 主要继承Dataset类,重写__getitem__,and __len__的方法
  • 我的问题:针对一个文件夹有n张图片,然后一个csv文件中有每个图片对应的label,具体样式如下
  • pytorch用自己的数据集进行Dataloader,并对其划分数据集_第1张图片

步骤1:将image和label对应加载到一个数据集中

class SkinDataset(Dataset):
    def __init__(self,csv_file,root_dir,transform=None):
        self.csv=pd.read_csv(csv_file)
        self.root_dir=root_dir
        self.transform=transform
    def __len__(self):
        return len(self.csv)
    def __getitem__(self,idx):
        image_path=os.path.join(self.root_dir+self.csv.ix[idx,0]+'.jpg')
        image=io.imread(image_path)
        label=self.csv.ix[idx,1:].as_matrix()
        label=label.reshape(-1,1)
        sample={"image":image,"label":label}
        return sample

步骤2:将得到的数据集划分为train和test

  • 得到整个数据集
    dataset=SkinDataset(csv_file="...",root_dir="...")
  • 划分数据集
train_size=0.8*len(dataset)
#因为得到的dataset是一个数组字典,所以只能一个个往数组里添加
train_dataset=[]
teat_train=[]
for i in range(train_size):
	train_dataset.append(dataset[i])
for i in range(train_size,len(dataset)):
	test_dataset.append(dataset[i])
	

步骤3:对不同数据进行transform,并将其加载到dataloader中

train_transform=transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_trainsform=transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class Dataset2(Dataset):
    def __init__(self,dataset,transform):
        self.dataset=dataset
        self.transform=transform
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        img,label=self.dataset
        return img,label
#因为要将train和test分别进行trainsform,所以只能重新写一个类进行transform,实在想不到好办法了
train_dataset2=Dataset2(train_dataset,transform=train_transform)
test_dataset2=Dataset2(test_dataset,transform=val_trainsform)
traindata=DataLoader(train_dataset2,batch_size=32,shuffle=True,num_workers=4)
traindata=DataLoader(test_dataset2,batch_size=32,shuffle=True,num_workers=4)

哈哈哈,终于不报错了,是的random_split太坑了,出坑太难

你可能感兴趣的:(pytorch)