如何用Pytorch读取自己的数据集

在训练经典的数据集如cifar10,minsit等,可以用官方自带的数据集格式几行就写出来,如果是自己下载的数据集,那么我们应该如何用pytorch来读取呢?其实是有模板可以直接仿照着写的。

本次案例采用的是pokeman数据集,并用该数据集进行分类。该数据如下所示:
如何用Pytorch读取自己的数据集_第1张图片
如何用Pytorch读取自己的数据集_第2张图片
其中文件夹的名字便是标签。数据集大小划分为:皮卡丘 234、超梦239、杰尼龟223、小火龙 238、妙蛙种子234张图。

在深度学习中一般的流程是:加载数据—>构建模型—>训练和测试。

读取数据

在pytorch读取数据,采用3个步骤

  1. 继承torch中的通用的母类:torch.utils.data.Dataset
from torch.utils.data.Dataset
  1. __len __:这里需要返回定义数据的数量,返回整型数字
  2. __getitem __ :这里返回样本、标签等
一个简单的例子
from torch.utils.data import Dataset, DataLoader
class NumberDataset(Dataset):   #首先要继承Dataset母类
    def __init__(self, training=True):  #区分训练和测试
        if training:
            self.samples = list(range(1, 1001))   #加载数据,一般是存放数据的地址,不然内存爆炸
        else:
            self.samples = list(range(1001, 15001))

    def __len__(self):
        return len(self.samples)    #

    def __getitem__(self, idx):  # idx 是位置标号,在len(self.samples) 内,一个一个的读取该位置数据
        return self.samples[idx]

小结:1、首先得到所有的数据的地址名字(训练或测试);2、给出数据集长度;3、返回指定位置的数据内容,可以在该数据上进行任何预处理操作。

现在读取本次给的pokeman数据集

python代码框架为:

from torch.utils.data import Dataset, DataLoader  #自定义的母类,必须的
class Pokemon(Dataset):
    def __init__(self):        #去读数据路径
    	super(Pokemon, self).__init__()
    	pass
    def __len__(self):  #返回数据长度
    	pass
    def __getitem__(self, idx):  #返回当前位置的数据和标签
    	pass

接下来就是填充每一块函数里面的内容了。

1 将标签转化数字,且数据地址及其标签保存csv文件

首先需要加载数据和标签,因为标签需要转化成0,1,2,3,4,最好保存为csv文件,下次便可以直接加载csv文件。因此我们需要事先写一个函数保存csv文件,不写也可以,最好是写成csv。

下面这个函数可以单独写成一个文件,也可以放在class Pokemon(Dataset)里面。

	def load_csv(self, filename):
	    if not os.path.exists(os.path.join(self.root, filename)): 
	     #如果没有保存csv文件,那么我们需要写一个csv文件,如果有了直接读取csv文件
	        images = []
	        for name in self.name2label.keys():   
	            # 'pokemon\\mewtwo\\00001.png
	            images += glob.glob(os.path.join(self.root, name, '*.png'))
	            images += glob.glob(os.path.join(self.root, name, '*.jpg'))
	            images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
	
	        # 1167, 'pokemon\\bulbasaur\\00000000.png'
	        print(len(images), images)
	        
	        random.shuffle(images)
	        with open(os.path.join(self.root, filename), mode='w', newline='') as f:
	            writer = csv.writer(f)
	            for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
	                name = img.split(os.sep)[-2]        #从名字就可以读取标签
	                label = self.name2label[name]
	                # 'pokemon\\bulbasaur\\00000000.png', 0
	                writer.writerow([img, label])  #写进csv文件
	            print('writen into csv file:', filename)
	
	    # read from csv file
	    images, labels = [], []
	    with open(os.path.join(self.root, filename)) as f:
	        reader = csv.reader(f)
	        for row in reader:
	            # 'pokemon\\bulbasaur\\00000000.png', 0
	            img, label = row
	            label = int(label)
	            images.append(img)
	            labels.append(label)
	    assert len(images) == len(labels)
	    return images, labels
2 初始化函数

上面函数可以得到数据地址及其标签,接下来就是初始化,得到数据地址名和标签保存

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # "sq...":0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys()) #将英文标签名转化数字0-4
        # print(self.name2label)
        # image, label
        self.images, self.labels = self.load_csv('images.csv')  #csv文件存在 直接读取
        if mode == 'train':  # 60%                   
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 20% = 60%->80%
            self.images = self.images[int(
                0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(
                0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 20% = 80%->100%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
3 总体样本数量
    def __len__(self):
        return len(self.images)
4 取出当前位置的数据内容和标签等
    def __getitem__(self, idx):
        # idx~[0~len(images)]
        # self.images, self.labels
        # img: 'pokemon\\bulbasaur\\00000000.png'
        # label: 0
        img, label = self.images[idx], self.labels[idx]
        
        tf = transforms.Compose([   #常用的数据变换器
					            lambda x:Image.open(x).convert('RGB'),  # string path= > image data 
					            #这里开始读取了数据的内容了
					            transforms.Resize(   #数据预处理部分
					                (int(self.resize * 1.25), int(self.resize * 1.25))), 
					            transforms.RandomRotation(15), 
					            transforms.CenterCrop(self.resize), #防止旋转后边界出现黑框部分
					            transforms.ToTensor(),
					            transforms.Normalize(mean=[0.485, 0.456, 0.406],
					                                 std=[0.229, 0.224, 0.225])
       							 ])
        img = tf(img)
        label = torch.tensor(label)  #转化tensor
        return img, label       #返回当前的数据内容和标签
5 加载一个bathsize数据

完成上面的步骤,我们只能得到一个一个数据,且需用迭代器表示,即iter:

    db = Pokemon('pokemon', 64, 'train')
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)

因此还需要DataLoader来加载批量的数据:

  loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
  for x, y in loader: #此时x,y是批量的数据
  	pass 
6 可视化数据集

当我们完成数据集读取部分,可视化也是必须的。我们采用的是visdom来可视化。

    import visdom
    import time
    for x, y in loader:
	    viz.images(
		        db.denormalize(x), #因为对原始数据归一化,所以可视化需要返回去,该函数需要自己写下。
		        nrow=8,  #每行显示8张图
		        win='batch',
		        opts=dict(title='batch'))
	    viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
	    time.sleep(10)

如果visdom连接超时,那么需要:

>python -m visdom.server

可以在网页上显示:
如何用Pytorch读取自己的数据集_第3张图片

    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x-mean)/std
        # x = x_hat*std = mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        # print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x
7 简单的文件分级,可以用一行代码搞定

如果文件结构是二级目录,且代码和文件夹在同一个目录:
在这里插入图片描述如何用Pytorch读取自己的数据集_第4张图片
那么可以用一行代码来写:

    tf = transforms.Compose([
		                    transforms.Resize((64,64)),
		                    transforms.ToTensor(),
		   					 ])
    db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf) 
    loader = DataLoader(db, batch_size=32, shuffle=True)

    print(db.class_to_idx)
    for x,y in loader:
        viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)

用ImageFolder即可以写,不过该情况受限,因此不建议。还是用前面的函数自己去定义,方便对数据修改,或者额外引入标签。

接下来就是如何训练了,可参考我写的训练模板:https://blog.csdn.net/lifei1229/article/details/105530012
https://blog.csdn.net/lifei1229/article/details/105527312

你可能感兴趣的:(pytorch,深度学习,机器学习)