1、所使用数据集已被集成在pytorch内,如:CIFAR-10,CIFAR-100,MNIST等等。对于这种数据集,可以直接使用pytorch内置函数:torchvision.datasets.CIFAR100来直接加载,比较方便。例程如下:
transform_train = transforms.Compose([
#transforms.ToPILImage(),
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True,
download=True, transform=transform_train)
cifar100_training_loader = DataLoader(
cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
2、所使用数据集为被集成,这个类别是本文的主要讲述内容。
对于自定义数据集pytorch实际上是有一个函数的:torchvision.datasets.ImageFolder(),但是此函数只能加载特定形式的数据集(图片已被分类好,并放在相应文件夹下了,其标签就是其上层目录的名称,在下面会解释为什么)。有时候我们需要使用的数据集可能不是这样的,其标签可能不在目录上,而在一个独立的文件里。这时候,直接使用ImageFolder会导致训练结果与预期结果毫无关系,这就需要我们自己重新构造一个类似于DatasetFolder类(ImageFolder就继承了DatasetFolder类)的新类别来加载数据集。
pytorch中DatasetFolder类的官方实现如下:
class DatasetFolder(data.Dataset):
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
classes, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
其中三个地方需要修改:find_classes(root),make_dataset(root, class_to_idx, extensions),和sample = self.loader(path)。
find_classes(root)源码如下:
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
# 遍历dir目录下的所有子目录名称并将其存在classes中
classes.sort()
# 由于Python版本的不同可能需要更换为sorted(classes)
class_to_idx = {classes[i]: i for i in range(len(classes))}
# 创建一个字典,将类别与数字对应
return classes, class_to_idx
该函数的参数:dir = root,调用DatasetFolder类时使用的目录,多为travel,test等。通过上面的程序以及注释,可以看出该函数时默认我们将图片分好类放在相应的文件夹下,如果需要使用其他形式的数据集,首先就要修改该函数得到正确的类别对应表,当我们使用tiny-imagenet-200时,修改为:
def find_classes(class_file): # 此时class_file要是对应的标签对应表
with open(class_file) as r:
classes = map(lambda s : s.strip(), r.readlines())
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
当使用ImageNet时,修改为:
def find_classes(img_label): # 此时img_label要是对应的标签对应表
classes = scio.loadmat(img_label)['synsets']
class_to_idx = {}
for i in range(1000):
class_to_idx[label_array[i][0][1][0]] = i
return classes, class_to_idx
可以看到只是在读取文件的路径和形式有些许区别,在使用不同的数据集时需要我们自己去调整具体细节。
make_dataset(root, class_to_idx, extensions)源码如下:
def make_dataset(dir, class_to_idx, extensions):
images = []
dir = os.path.expanduser(dir) # 把path中包含的"~"和"~user"转换成用户目录
for target in sorted(os.listdir(dir)): # os.listdir()函数返回一个包含dir目录下所有文件或目录的列表
d = os.path.join(dir, target) # 将dir和target连接形成新的路径 d为类别目录
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions): # 判断fnames的后缀是否正确(JPEG.JPG等等)
path = os.path.join(root, fname) # 得到文件的路径和文件名
item = (path, class_to_idx[target]) # 依据class_to_idx得到图片类别对应的数字
images.append(item)
return images
sample = self.loader(path)注意对于loader,源码中使用的是default_loader,我在实现时,其相关函数一直无法导入。因此我使用img = Image.open(data).convert(‘RGB’)来替代。
以上就是加载数据集的方法,其他方法大同小异,可以依次类推。
参考链接:
https://github.com/budlbaram/tiny_imagenet/blob/master/tiny_imagenet.py
pytorch官方源码