在进行深度学习任务时,一个完整的baseline通常分为以下几个部分:
Dataset
和DataLoader
这两个类。本文主要是对Pytorch中定义数据加载的方法做一个学习。
Dataset
是Pytorch中的一个数据读取类,它已经包含了很多常见的数据集,如下:
torchvision.datasets中包含了以下数据集
我们可以直接使用这个Dataset类里面的数据集,示例如下:
dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)
其中
Dataset
的定义如下:
class Dataset(object):
def __init__(self):
...
def __getitem__(self, index):
return ...
def __len__(self):
return ...
主要包含两个方法:
__getitem__()
__getitem__
函数的作用是根据索引index遍历数据,一般返回image的Tensor形式和对应标注。当然也可以多返回一些其它信息,这个根据需求而定。
__len__()
__len__
函数的作用是返回数据集的长度。
在我们训练自己的数据时,需要继承它,并需要重写__getitem__()
和__len__()
这两个方法。
示例如下:
class CarDataset(Dataset):
def __init__(self, img_df, transform=None):
self.img_df = img_df
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
# start_time = time.time()
# img = Image.open(self.img_df.iloc[index]['index']).convert('RGB')
img = cv2.imread(self.img_df.iloc[index]['filename'])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if self.transform is not None:
img = self.transform(image=img)
return img['image'], torch.from_numpy(np.array(self.img_df.iloc[index]['label']))
def __len__(self):
return len(self.img_df)
train_transform = Compose([
Resize(288,352),
HorizontalFlip(),
OneOf([
RandomContrast(),
RandomGamma(),
RandomBrightness(),
], p=0.3),
OneOf([
CLAHE(p=0.5),
GaussianBlur(3, p=0.3),
IAASharpen(alpha=(0.2,0.3), p=0.3),
], p=1),
Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
RandomCrop(256, 320),
ToTensor()
])
最后一个是ToTensor
()。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
构建一个可迭代的数据装载器,可以理解为在训练过程中,DataLoader
将自定义的Dataset
根据batch size
大小、是否shuffle
等封装成一个又一个batch大小的Tensor,数据给模型进行训练测试。
即在DataLoder中,会触发Mydataset中的getiterm函数读取一张图片的数据和标签,并拼接成一个batch返回,作为模型真正的输入。
参数表如下:
常用的是
自定义示例:
train_loader = torch.utils.data.DataLoader(
CarDataset(train_label,
train_transform,
), batch_size=batch_size, shuffle=True, num_workers=work_num, pin_memory=True
)
https://blog.csdn.net/u014380165/article/details/79058479?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
https://blog.csdn.net/g11d111/article/details/81504637?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task