深度学习【30】pytorch 自定义数据迭代器

class myImageFloder(data.Dataset):
    def __init__(self,root,list,transform):#list:训练数据列表,比如:train.txt文件,每一行是一个训练样本
        fh = open(list)
        imgs=[]
        for line in fh.readlines():
            imgs.append(line.strip('\n'))
        self.root =root
        self.imgs = imgs
        self.transform=transform
        #transform:可以用来数据扩充,以及其他一系列操作。
        #本质上是通过transform对输入数据经过一系列的函数处理。
        #详细的可参考pytorch的transform
        #这里只定义img的transform函数,也可以定义一个label的transform函数。
    def __getitem__(self, index):
        name = self.imgs[index]
        img,label = getItem(name)#根据数据组织情况获取第index样本的img和标签
        if self.transform!=None:
            img = self.transform(img)
        return img,label

你可能感兴趣的:(深度学习)