PyTorch provides a datasets interface that wraps many public datasets such as CIFAR10/100 and ImageNet, which can be used with a simple call like the one below. So how do we use PyTorch to load a training set we built ourselves? Let's find the answer in the source code!
train_data = datasets.CIFAR10('./cifar10', train=True, transform=train_transform, download=True)
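For context, here is the same call with the surrounding pieces it assumes spelled out; train_transform is only an example pipeline, not something fixed by the article:

from torchvision import datasets, transforms

# An example training transform; the exact pipeline is up to you.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_data = datasets.CIFAR10('./cifar10', train=True,
                              transform=train_transform, download=True)
print(len(train_data))  # 50000 training images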
From the source we can see that the CIFAR10 class inherits from VisionDataset, which is itself a subclass of Dataset, and implements the three methods __init__, __len__, and __getitem__. In fact, implementing a custom data interface and training with PyTorch is just as simple: subclass the base class Dataset and implement those three methods.
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
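To make the contract concrete, here is a minimal toy subclass (the class and its data are invented purely for illustration); these two overrides really are all that is required:

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset: item i is the pair (i, i**2)."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Tells consumers (e.g. DataLoader) how many samples exist.
        return self.n

    def __getitem__(self, index):
        # Integer indexing in range [0, len(self)), per the docstring above.
        return torch.tensor(index), torch.tensor(index ** 2)

ds = SquaresDataset(10)
print(len(ds), ds[3])  # 10 (tensor(3), tensor(9))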
For loading your own dataset, PyTorch likewise provides an interface, torchvision.datasets.ImageFolder, but it is somewhat restrictive: your data must follow its expected directory structure of root/class_name/*.jpg, one subdirectory per class.
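For example, a layout like the following works (the class names and files here are hypothetical):

root/
├── cat/
│   ├── 001.jpg
│   └── 002.jpg
└── dog/
    ├── 001.jpg
    └── 002.jpg

from torchvision import datasets
train_set = datasets.ImageFolder('/path/to/root', transform=train_transform)

ImageFolder itself is a thin subclass of DatasetFolder, whose __init__ does the real work: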
def __init__(self, root, loader, extensions=None, transform=None,
             target_transform=None, is_valid_file=None):
    super(DatasetFolder, self).__init__(root)
    self.transform = transform
    self.target_transform = target_transform
    classes, class_to_idx = self._find_classes(self.root)
    samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
    if len(samples) == 0:
        raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
                            "Supported extensions are: " + ",".join(extensions)))
    self.loader = loader
    self.extensions = extensions
    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples
    self.targets = [s[1] for s in samples]
Let's step through this method and see what it does.
First, self.root is the custom training-set directory we pass in, and the first operation on it is _find_classes. Let's look at the _find_classes source:
def _find_classes(self, dir):
    """
    Finds the class folders in a dataset.

    Args:
        dir (string): Root directory path.

    Returns:
        tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

    Ensures:
        No class is a subdirectory of another.
    """
    if sys.version_info >= (3, 5):
        # Faster and available in Python 3.5 and above
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    else:
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx
The return value classes is a list containing the sorted class IDs, i.e. the labels, while class_to_idx is the corresponding dictionary whose keys are the IDs and whose values are their indices, as follows:
['102091655-1-201811011700-16', '10209231-1-201811010900-2', '1020962212-2-201811010900-24', '1020966131-3-201811011700-0', '102097752-0-201811010900-6']
{'1020962212-2-201811010900-24': 2, '1020966131-3-201811011700-0': 3, '102097752-0-201811010900-6': 4, '10209231-1-201811010900-2': 1, '102091655-1-201811011700-16': 0}
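At prediction time you usually need the inverse mapping, from index back to class ID; inverting the dictionary is a one-liner:

# Invert class_to_idx so a predicted index maps back to its class ID.
idx_to_class = {v: k for k, v in class_to_idx.items()}
print(idx_to_class[0])  # '102091655-1-201811011700-16'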
Next, samples receives the return value of make_dataset, where extensions lists the image file extensions torchvision supports, and is_valid_file can be supplied instead to verify that a file is valid.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    images = []
    dir = os.path.expanduser(dir)
    if not ((extensions is None) ^ (is_valid_file is None)):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = (path, class_to_idx[target])
                    images.append(item)
    return images
A sample of samples is shown below: it is a list of tuples, each storing an image path and its corresponding label.
[('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000702_crop16.jpg', 0),
('test/102091655-1-201811011700-16/10.209.16.55-1-201811011700-201811011703_00000880_crop16.jpg', 0),
('test/10209231-1-201811010900-2/10.209.23.1-1-201811010900-201811010903_00000092_crop2.jpg', 1),
('test/1020962212-2-201811010900-24/10.209.62.212-2-201811010900-201811010903_00000756_crop24.jpg', 2),
('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000295_crop0.jpg', 3),
('test/1020966131-3-201811011700-0/10.209.66.131-3-201811011700-201811011703_00000302_crop0.jpg', 3),
('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000395_crop6.jpg', 4),
('test/102097752-0-201811010900-6/10.209.77.52-0-201811010900-201811010903_00000434_crop6.jpg', 4)]
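Because samples is a plain list of (path, label) tuples, checking the class balance takes one line of standard library; for the list above:

from collections import Counter

label_counts = Counter(label for _, label in samples)
print(label_counts)  # Counter({0: 2, 3: 2, 4: 2, 1: 1, 2: 1})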
Next comes the loader assignment. loader is a function argument; we usually load images with the pil_loader function.
from PIL import Image

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
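If the accimage package is installed, you can route default_loader through it by switching torchvision's image backend (set_image_backend is an existing torchvision API):

import torchvision

# With the 'accimage' backend, default_loader calls accimage_loader instead of pil_loader.
torchvision.set_image_backend('accimage')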
__getitem__ is the basis of DataLoader's scheduling: it takes an index as input and returns the transformed image together with its label, while __len__ returns the length of the dataset.
def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
        sample = self.transform(sample)
    if self.target_transform is not None:
        target = self.target_transform(target)
    return sample, target

def __len__(self):
    return len(self.samples)
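Both methods can be exercised by indexing the dataset directly, with no DataLoader involved; a small sketch (the path is hypothetical, and printing .size() assumes the transform is ToTensor()):

from torchvision import datasets, transforms

dataset = datasets.ImageFolder('/path/to/root', transform=transforms.ToTensor())
img, label = dataset[0]   # invokes __getitem__(0)
print(len(dataset))       # invokes __len__()
print(img.size(), label)  # e.g. torch.Size([3, 224, 224]) 0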
If you can follow how these functions work, you are ready to define the data interface you need. Suppose our train.txt, val.txt, and test.txt files have the format below, one image path and one integer label per line. Think about how we would implement the three methods above ourselves:
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002532_crop23.jpg 1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002521_crop23.jpg 1
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002535_crop23.jpg 2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002528_crop23.jpg 2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002523_crop23.jpg 2
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002529_crop23.jpg 3
/20190424/200001320002208-1556076900-23/CJ145YWJMK1-32130200001320002208-1556076900_00002527_crop23.jpg 3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000833_crop69.jpg 3
/20190424/200001320002208-1556067600-69/CJ145YWJMK1-32130200001320002208-1556067600_00000834_crop69.jpg 4
/20190424/00001320000179-1556104800-30/SZ009SZZP3-32130200001320000179-1556104800_00001954_crop30.jpg 4
Below is a sketch of such a class, mainly to illustrate the idea!
# _*_ coding:utf-8 _*_
import os

import torch.utils.data as data
from torchvision.datasets.folder import default_loader


class customData(data.Dataset):
    def __init__(self, root, txt_path, dataset=None, transforms=None, loader=default_loader):
        with open(txt_path) as data_input:
            lines = data_input.readlines()
        # Each line is "<image path> <integer label>"; lstrip('/') keeps
        # os.path.join from discarding root when the path starts with '/'.
        self.images = [os.path.join(root, line.split()[0].lstrip('/')) for line in lines]
        self.labels = [int(line.split()[1]) for line in lines]
        self.transform = transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        path = self.images[index]
        label = self.labels[index]
        img = self.loader(path)  # load the image itself, not just its path
        if self.transform is not None:
            img = self.transform(img)
        return img, label
It can be called like this, which completes the custom data loading process:
image_datasets = {x: customData(root='/home/badoo/person',
                                txt_path='/home/badoo/train_list/' + x + '.txt',
                                transforms=data_transforms,
                                dataset=x) for x in ['train', 'val']}
During training, as discussed earlier, the input is usually a tensor of shape [N, C, H, W]. PyTorch provides a batch-loading API, DataLoader, which draws samples from the dataset (already transformed, e.g. via ToTensor()) and collates them into batches. The source code is worth consulting:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=batch_size,
                                              shuffle=True) for x in ['train', 'val']}
Parameters (a combined usage sketch follows this list):
1. dataset: the output of an existing PyTorch data-reading interface (such as torchvision.datasets.ImageFolder) or of a custom one; it must be either a torch.utils.data.Dataset object or an object of a custom class that inherits from torch.utils.data.Dataset.
2. batch_size: set according to your situation.
3. shuffle: generally enabled for training data.
4. collate_fn: controls how individual samples from the dataset are wrapped into a batch; the default is fine unless your custom data interface produces very unusual output.
5. batch_sampler: as the docstring shows, it is mutually exclusive with batch_size, shuffle, and related parameters; usually left at the default.
6. sampler: as the code shows, it is mutually exclusive with shuffle; usually left at the default.
7. num_workers: must be >= 0; 0 means data loading runs in the main process, while any value greater than 0 loads data with that many worker processes, which can speed up data import.
8. pin_memory: the docstring says it clearly: "pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them." In other words, it is about data copying.
9. timeout: sets a timeout for reading data; if no data has been read within this time, an error is raised.
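As promised above the list, here is a sketch combining several of these parameters; the concrete values are only examples:

from torch.utils.data import DataLoader

train_loader = DataLoader(image_datasets['train'],
                          batch_size=64,     # (2) tune to your hardware
                          shuffle=True,      # (3) mutually exclusive with sampler
                          num_workers=4,     # (7) 0 would load in the main process
                          pin_memory=True,   # (8) faster host-to-GPU copies
                          timeout=0)         # (9) 0 disables the read timeout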
Below are two ways to call the interface; I prefer the second one ^_^
# Style 1:
train_data = torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    ...

# Style 2:
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
for i, (ids, labels) in enumerate(train_load):
    ...
Sticking with something may be hard, but seeing it through is definitely cool! ^_^