torchvision.datasets 中的各个数据集都是 torch.utils.data.Dataset 的实现(子类)。
包括如下数据集:
__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
           'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
           'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
import os

import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
# Demo: load the CIFAR-10 training set and display its first image.
# ToTensor is required: without a transform the dataset yields PIL images,
# which DataLoader's default collate function cannot stack into a batch.
# NOTE(review): with num_workers > 0 on platforms that spawn worker
# processes (Windows, macOS), this top-level code must be moved under an
# `if __name__ == "__main__":` guard.
cifarSet = torchvision.datasets.CIFAR10(
    root="../data/cifar/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
cifarLoader = torch.utils.data.DataLoader(
    cifarSet, batch_size=10, shuffle=False, num_workers=2,
)
# Each batch is an (images, labels) pair; only the first batch is inspected.
for images, _ in cifarLoader:
    print(images[0])
    # Convert the first image tensor back to a PIL image for display.
    img = transforms.ToPILImage()(images[0])
    img.show()
    break
import torch.utils.data as data
#定义myDataSet类来继承Dataset
#generate train_data or test_data...
def default_loader(path):
    """Load the image at ``path`` and return it converted to RGB.

    The file is opened explicitly and closed by the ``with`` statement, so
    no file handle outlives the call. ``convert`` makes Pillow read the pixel
    data while the file is still open.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
class myDataSet(data.Dataset):
    """Dataset of images whose paths are listed, one per line, in a text file.

    Each non-blank line of ``label_txt`` is split on whitespace, and its first
    field is used as the image path. Any further fields on the line are
    ignored. Items are images only; no targets are produced.

    Args:
        label_txt: path of the text file listing the images.
        transform: optional callable applied to each loaded image
            (e.g. ``transforms.ToTensor()``).
        target_transform: kept for API parity with torchvision datasets.
            This dataset yields no targets, so it is currently unused.
        loader: callable mapping a path to an image (``default_loader`` by default).
        root: optional directory joined in front of every listed path.
            ``os.path.join`` leaves absolute paths unchanged.
    """

    def __init__(self, label_txt, transform=None, target_transform=None,
                 loader=default_loader, root=''):
        super().__init__()
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.root = root
        # `with` guarantees the list file is closed even if reading fails.
        with open(label_txt, 'r') as fn:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing with IndexError on the first field.
            self.imgs = [line.split()[0] for line in fn if line.strip()]

    def __len__(self):
        """Return the number of listed images."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load image ``index`` and apply ``transform`` if one was given."""
        path = os.path.join(self.root, self.imgs[index])
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img
label_txt 的格式如下:
每一行是一个图像的绝对路径(按空白切分后只取第一个字段作为路径,其余字段被忽略)。
同时,自定义数据集需要重写 __len__ 与 __getitem__ 两个方法,如上所示。
def get_my_data(label_txt='', batch_size=None, num_workers=1):
    """Build a shuffled DataLoader over ``myDataSet(label_txt, transform=ToTensor())``.

    Args:
        label_txt: path of the text file listing one image path per line.
            The default ``''`` is only a placeholder; pass a real file path.
        batch_size: samples per batch. When ``None``, the module-level
            ``BATCH_SIZE`` constant is used, and it must be defined elsewhere.
        num_workers: number of DataLoader worker processes.

    Returns:
        The training ``DataLoader``.
    """
    if batch_size is None:
        # NOTE(review): BATCH_SIZE is not defined in this file; define it or
        # pass batch_size explicitly.
        batch_size = BATCH_SIZE
    # NOTE(review): default collate needs equally sized images to stack;
    # add a Resize/CenterCrop transform if the listed images differ in size.
    train_data = myDataSet(label_txt=label_txt, transform=transforms.ToTensor())
    return DataLoader(train_data, shuffle=True, batch_size=batch_size,
                      num_workers=num_workers)
参考文献:
https://blog.csdn.net/sinat_42239797/article/details/90641659
https://zhuanlan.zhihu.com/p/27434001