从1张图片开始制作数据集并制作DataLoader

代码部分使用Jupyter Notebook编写。
首先建立了cat和dog两个文件夹，并从网上为每一类各随便找了一张图片放进去。
主要代码如下：
主要思路是建立一个csv文件：先对每一类图片做简单的数据增强（本文选用旋转，从0°到180°、每次多转20°），为猫和狗各生成十张图片；然后给猫的图片赋予标签0、狗的图片赋予标签1，把“图片路径,标签”逐行写入csv文件。接着编写自定义的Dataset类，主要是实现__init__、__getitem__和__len__三个方法，其实就是把csv中的图片路径和标签读成一个列表，再按索引逐条读取；之后利用DataLoader实现按批次迭代读取数据。最后测试了一下DataLoader里的图片和标签是否正确。

# The code below uses these names, but the listing never imported them
# (presumably they lived in an earlier notebook cell).
import csv
import os

import torch
from PIL import Image
from torchvision import transforms

# Path of the label CSV. Raw strings keep the Windows backslashes literal:
# in plain strings '\教' and '\C' are invalid escape sequences (a
# SyntaxWarning since Python 3.12). The resulting path string is unchanged.
# NOTE(review): this resolves to E:\教程\CSDNlabel.csv\label.csv, i.e. a file
# inside a *folder* named "CSDNlabel.csv". Presumably it is kept outside dir1
# so the CSV is not listed as a class folder -- confirm that folder exists,
# since open(csv_dir, 'w') will not create it.
csv_dir = os.path.join(r'E:\教程\CSDNlabel.csv', 'label.csv')
# Dataset root folder; its sub-folders ('cat', 'dog') are the classes.
dir1 = r'E:\教程\CSDN'
# Class folder names. os.listdir order is arbitrary, so sort it to make any
# position-based use deterministic (alphabetical puts 'cat' before 'dog').
dir2 = sorted(os.listdir(dir1))
# Class folder name -> label written to the CSV (cat=0, dog=1). Only these
# folders are augmented and listed in the CSV.
_CLASS_LABELS = {'cat': '0', 'dog': '1'}
_N_ROTATIONS = 10    # rotated copies generated per class
_ROTATION_STEP = 20  # degrees between consecutive copies (0, 20, ..., 180)


def _augment_folder(folder, cls):
    """Save rotated copies of the folder's source image(s) into ``folder``.

    Copies are named ``<cls><k>.jpg`` for k in ``range(_N_ROTATIONS)`` and are
    rotated by ``_ROTATION_STEP * k`` degrees. Files that already carry one of
    those generated names are not used as sources, so re-running does not
    rotate already-rotated copies. If the folder holds nothing but generated
    names, all of them are used, as before. With several source images, each
    one writes the same names, so the last one processed wins.
    """
    generated = {cls + str(k) + '.jpg' for k in range(_N_ROTATIONS)}
    names = os.listdir(folder)
    sources = [name for name in names if name not in generated] or names
    for name in sources:
        # convert() loads the pixels, so the source file is closed before we
        # overwrite anything. RGB also lets RGBA/palette sources be saved as
        # JPEG.
        with Image.open(os.path.join(folder, name)) as img:
            base = img.convert('RGB')
        for k in range(_N_ROTATIONS):
            rotated = base.rotate(_ROTATION_STEP * k)
            rotated.save(os.path.join(folder, cls + str(k) + '.jpg'))


def _write_label_csv():
    """Write one ``path,label`` row to ``csv_dir`` per file in each class folder."""
    with open(csv_dir, 'w', newline='') as f:
        writer = csv.writer(f)
        for cls in dir2:
            label = _CLASS_LABELS.get(cls)
            if label is None:
                continue  # not a known class folder
            folder = os.path.join(dir1, cls)
            for name in os.listdir(folder):
                writer.writerow([os.path.join(folder, name), label])


def showpath():
    """Augment the class folders under ``dir1`` and write the label CSV.

    For each entry of ``dir2`` that is a known class ('cat' or 'dog'), the
    rotated copies are saved into that class folder. Afterwards every file in
    those folders is written to ``csv_dir`` with its class label ('0' for cat,
    '1' for dog). Labels are chosen by folder name, not by listing position.

    Despite its name, this function displays nothing; it only writes files.
    """
    for cls in dir2:
        if cls in _CLASS_LABELS:
            _augment_folder(os.path.join(dir1, cls), cls)
    # Listing happens after augmentation, so the new copies are included.
    _write_label_csv()

def default_loader(path):
    """Load the image at ``path`` and return it as an RGB ``PIL.Image``.

    ``Image.open`` is lazy and keeps the file handle open. Converting inside
    a ``with`` block reads the pixels and closes the file before returning.
    The RGB conversion also guarantees three channels, as a 3-channel
    ``Normalize`` (like the one in ``train_transform``) expects, even for
    grayscale, palette or RGBA files.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
class Dataset():
    """Map-style dataset over the ``path,label`` rows of a CSV file.

    Implements ``__len__`` and ``__getitem__``, which is what
    ``torch.utils.data.DataLoader`` needs to index and batch it.

    Args:
        loader: callable that turns an image path into an image object.
        transform: optional callable applied to the loaded image; when
            ``None`` the loaded image is returned unchanged.
        csv_path: CSV file to read; defaults to the module-level ``csv_dir``.
    """

    def __init__(self, loader=default_loader, transform=None, csv_path=None):
        if csv_path is None:
            csv_path = csv_dir
        with open(csv_path, 'r', newline='') as f:
            # csv.reader (rather than str.split) honours the quoting that
            # csv.writer applies to paths containing commas. Blank rows are
            # skipped instead of raising IndexError.
            self.imgs = [(row[0], int(row[1])) for row in csv.reader(f) if row]
        self.loader = loader
        self.transform = transform

    def __len__(self):
        """Return the number of (path, label) samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        path, label = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
import torchvision
import matplotlib.pyplot as plt

# Resize/crop to 256x256, convert to a tensor and map each channel to
# (x - 0.5) / 0.5.
train_transform = transforms.Compose([
    transforms.Resize(280),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Generate the augmented images and the label CSV first. Dataset reads the CSV
# in __init__, so building it before showpath() read a missing or stale file.
showpath()
trainset = Dataset(transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True)

# Sanity check: display one batch and print its labels.
image, labels = next(iter(trainloader))
image = torchvision.utils.make_grid(image)
# Undo Normalize(0.5, 0.5) so values are back in [0, 1]; imshow would
# otherwise clip the negative part of the normalized range.
image = (image * 0.5 + 0.5).numpy().transpose(1, 2, 0)
print(labels.tolist())
plt.imshow(image)
plt.show()

你可能感兴趣的:(从1制作数据集制作DataLoader)