pytorch加载自带数据集以及个人数据集的方式

pytorch加载数据集

  • 一、加载pytorch自带数据集
    • 1.使用torchvision.datasets加载数据集
    • 2.使用torch.utils.data.DataLoader来实例化
    • 3.测试
  • 二、加载个人的数据集
    • 1.继承Dataset类,生成数据集
    • 2.加载数据集

一、加载pytorch自带数据集

torchvison.datasets是torch.utils.data.Dataset的实现。
包括如下数据集:
all = (‘LSUN’, ‘LSUNClass’,
‘ImageFolder’, ‘DatasetFolder’, ‘FakeData’,
‘CocoCaptions’, ‘CocoDetection’,
‘CIFAR10’, ‘CIFAR100’, ‘EMNIST’, ‘FashionMNIST’, ‘QMNIST’,
‘MNIST’, ‘KMNIST’, ‘STL10’, ‘SVHN’, ‘PhotoTour’, ‘SEMEION’,
‘Omniglot’, ‘SBU’, ‘Flickr8k’, ‘Flickr30k’,
‘VOCSegmentation’, ‘VOCDetection’, ‘Cityscapes’, ‘ImageNet’,
‘Caltech101’, ‘Caltech256’, ‘CelebA’, ‘SBDataset’, ‘VisionDataset’,
‘USPS’, ‘Kinetics400’, ‘HMDB51’, ‘UCF101’, ‘Places365’)

1.使用torchvision.datasets加载数据集

import torch
import torchvision
from PIL import Image

cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)

2.使用torch.utils.data.DataLoader来实例化

cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)

3.测试

for i, data in enumerate(cifarLoader, 0):
    print(data[i][0])
    # PIL
    img = transforms.ToPILImage()(data[i][0])
    img.show()
    break

二、加载个人的数据集

1.继承Dataset类,生成数据集

import torch.utils.data as data
#定义myDataSet类来继承Dataset

#generate train_data or test_data...
def default_loader(path):
    return  Image.open(path).convert('RGB')

class myDataSet(data.Dataset):
    """"
    @:param
    label_txt:每个图像名称以及路径,one image one line
    """
    def __init__(self,label_txt,transform = None,target_transform = None, loader=default_loader):
        super(myDataSet, self).__init__()
        self.imgs = []
        self.transform =transform
        self.target_transform = target_transform
        self.loader =loader
        fn = open(label_txt,'r')
        imgs=[]
        for line in fn:
            line  = line.strip('\n')
            line = line.rstrip('\n')
            words = line.split()
            imgs.append(words[0])
        self.imgs = imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        fn = self.img[index]
        img = self.loader(os.path.join(self.root,fn))
        return  img

label_txt的格式如下:
每一行是一个图像的绝对路径
同时,需要重写__len__与__getitem__两个函数如上
pytorch加载自带数据集以及个人数据集的方式_第1张图片

2.加载数据集

def get_my_data():
    train_data = myDataSet(label_txt='',transforms=transform.ToTensor())
    test_data = myDataSet(label_txt='', transforms=transform.ToTensor())
    train_loader = DataLoader(train_data,shuffle=True,batch_size=BATCH_SIZE,num_workers=1)
    #test_loader = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE, num_workers=1)
    return train_loader

参考文献:
https://blog.csdn.net/sinat_42239797/article/details/90641659
https://zhuanlan.zhihu.com/p/27434001

你可能感兴趣的:(研发管理,pytorch,深度学习,python)