我们经常见到这样一段代码:
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
其中,datasets 是 torchvision 的一个模块,通过它可以导入各种像 MINIST 等常用的数据集; torch.utils.data.DataLoader 是 torch 提供的一个有关划分的模块。
对应 自定义数据集,我们可以使用 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 两个模块对数据集处理
类似 torchvision 中的 datasets, torch.utils.data.Dataset 可以加载数据集,并对数据经行必要的 transform,Dataset 的官方说明如下:
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite getitem(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite len(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
Dataset 类似于 C++ 的虚基类,其中的函数无具体定义,需要重载才能使用,我们需要重载以下方法:
通常会使用一个 csv 存储样本的路径和类别,训练集和测试机分别对应一个 csv 文件:
此时在 Dataset 的 __init__(self)
中需要传入 csv_path, 以及数据集需要的做的 transform,接着读取 csv 中的样本 X 的路径,以及类别信息,计算样本数目
import torch
import argparse
import numpy as np
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class ReadDataFromCSV(Dataset):
def __init__(self, csv_path, transform):
# Transforms
self.transform = transform
# read csv
self.data_info = pd.read_csv(csv_path, header=None)
self.image_arr = np.asarray(self.data_info.iloc[1:, 0])
self.label_arr = np.asarray(self.data_info.iloc[1:, 1])
self.data_len = len(self.data_info.index)
接着,重载 __getitem__
函数
self.transform()
对图片预处理(裁剪、缩放、转换为 tensor、归一化等) def __getitem__(self, index):
# get image
single_img_name = self.image_arr[index]
single_img_img = Image.open('../' + single_img_name)
single_img_tensor = self.transform(single_img_img)
# get label
single_image_label = self.label_arr[index]
return (single_img_tensor, single_image_label)
注意:csv_path
需要根据具体代码和 csv 的相对路径修改,图片相对代码的位置和csv中给出位置也要自适应调整一下!
然后,重构 __len__(self)
def __len__(self):
return self.data_len
接着测试一下,使用
if __name__ == '__main__':
csv_path = '../SARimage/train.csv'
torch_data = ReadDataFromCSV(
csv_path='../SARimage/train.csv',
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
)
for i, (X, y) in enumerate(torch_data):
print(f'index:{i}\nX:{X}\ny:{y}')
break
上面使用了 transforms
的 Resize、ToTensor、Normalize([0.5], [0.5]) 对数据进行变换,输出结构如下:
index:0
X:tensor([[[-0.7333, -0.7176, -0.7725, -0.7098, -0.7255, -0.7569, -0.7412,
-0.7490, -0.7255, -0.7961, -0.7804, -0.7333, -0.7569, -0.7255,
-0.7647, -0.7804, -0.7882, -0.7569, -0.7490, -0.8118, -0.7882,
-0.7961, -0.7961, -0.7725, -0.7020, -0.6549, -0.6627, -0.7569],
[-0.7569, -0.7333, -0.7725, -0.7412, -0.7412, -0.8039, -0.7569,
-0.6235, -0.7569, -0.7255, -0.7412, -0.7098, -0.7255, -0.7020,
-0.6471, -0.6784, -0.6941, -0.6471, -0.6784, -0.8039, -0.8196,
-0.8118, -0.7961, -0.7490, -0.7490, -0.7020, -0.6784, -0.7333],
...
[-0.7725, -0.7961, -0.7333, -0.7647, -0.7804, -0.8353, -0.8275,
-0.7725, -0.7961, -0.8118, -0.8118, -0.8039, -0.8118, -0.8039,
-0.8118, -0.7961, -0.7176, -0.7647, -0.7647, -0.7412, -0.7804,
-0.7647, -0.7569, -0.7725, -0.8510, -0.8353, -0.7961, -0.7647]]])
y:2S1
具体使用 DataLoader 划分数据集方法如下:
dataloader = torch.utils.data.DataLoader(
ReadDataFromCSV(
csv_path = '../../SARimage/train.csv',
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
drop_last=False,
)
当没有 csv 文件时,可以使用 pathlib2 读取文件夹,遍历其中的所有文件,用文件夹名称作为label;
当我们需要使用 label 时,如果 label 是一个 str,则需要将 label 变换成数字类型,比如使用一个 dict 映射:
self.label_dict = {'2S1': 0, 'BMP2': 1, 'BRDM_2': 2,
'BTR_60': 3, 'BTR70': 4, 'D7': 5, 'T62': 6, 'T72': 7, 'ZIL131': 8, 'ZSU_23_4': 9}
single_image_label = self.label_dict[self.label_arr[index]]
在 CGAN 中,我们可以这样子做:
class ReadDataFromCSV(Dataset):
def __init__(self, csv_path, transform):
# Transforms
self.transform = transform
# read csv
self.data_info = pd.read_csv(csv_path, header=None).iloc[1:, :]
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
self.label_arr = np.asarray(self.data_info.iloc[:, 1])
self.data_len = len(self.data_info.index)
self.label_dict = {'2S1': 0, 'BMP2': 1, 'BRDM_2': 2,
'BTR_60': 3, 'BTR70': 4, 'D7': 5, 'T62': 6, 'T72': 7, 'ZIL131': 8, 'ZSU_23_4': 9}
def __getitem__(self, index):
# get image
single_image_name = self.image_arr[index]
single_img_img = Image.open('../../' + single_image_name)
single_img_tensor = self.transform(single_img_img)
# get label
single_image_label = self.label_dict[self.label_arr[index]]
return (single_img_tensor, single_image_label)
def __len__(self):
return self.data_len
# 读取
dataloader = torch.utils.data.DataLoader(
ReadDataFromCSV(
csv_path = '../../SARimage/train.csv',
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
drop_last=False,
)
REFERENCES: