pytorch dataloader 和 dataset 数据加载的研究

一 pytorch 数据加载的研究

目录

  • 一 pytorch 数据加载的研究
  • 一、dataloader and dataset?
  • 二、类的实例化
    • 1.继承Dataset
    • 2.重写父类函数
  • 3.实例化
  • 总结

一、dataloader and dataset?

Dataset抽象类,所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。

DataLoader(): 迭代器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

二、类的实例化

大多数文章,并没有仔细探究Dataset这个类,究竟是怎么一步步完成数据和标签的加载的

first of all ,它是个类
所以,从类的角度,继承,重写,实例化,这个面向对象的思路,先研究一下

1.继承Dataset

代码如下(示例):xxx代表可以自己定义的内容

class myDataset(Dataset):
    def __init__(self, xxx):
     
    def __getitem__(self,index):
            return xxx,xxx
            
    def __len__(self):
        return len(xxx)

可见,getitem 和 len 需要自己重写,并返回一些东西

2.重写父类函数

这里,采用了kaggle 的Dog Breed Identification项目的数据
是个分类任务,使用resnet vgg 就可以解决
数据集包含 3个文件
train(文件夹)
test(文件夹)
label.csv
pytorch dataloader 和 dataset 数据加载的研究_第1张图片
可以到官网看 https://www.kaggle.com/competitions/dog-breed-identification/
代码如下(示例):

from torch.utils.data import Dataset
import pandas as pd
import cv2
class myDataset(Dataset):
    def __init__(self, dogdir):
        self.imgset =  dogdir["id"]
        self.labelset = dogdir["breed"]
        dog_breeds = sorted(list(set(self.labelset )))
        n_classes = len(dog_breeds)
        self.class_to_num = dict(zip(dog_breeds, range(n_classes)))
        
    def __getitem__(self,index):
            imgpath = "train/"+self.imgset[index] + ".jpg"
            img = cv2.imread(imgpath)
            labelname  = self.labelset[index]
            labelhot =  self.class_to_num.get(labelname)
            return img, labelhot
    
    def __len__(self):
        return len(self.imgset)

3.实例化

看看继承后的dataset

df = pd.read_csv('labels.csv')  #使用pandas读取csv  
myd = myDataset(df)
img,label = myd.__getitem__(4) #指定4这个item
lenth = myd.__len__()
#print(img)
print(label)
print(lenth)

49 #one hot 编码后的标签
10222 # 总体数量

总结

提示:这里对文章进行总结:

例如:以上就是今天要讲的内容

你可能感兴趣的:(硕士阶段,pytorch)