最近在做李沐发布的Kaggle树叶分类竞赛,在处理和加载Classify-Leaves数据集的时候遇到了一些问题,真是巧妇难为无米之炊啊,现在记录下来,希望可以帮助到更多的初学者!
数据集由一个images文件和三个csv文件组成
打开train.csv文件可以看到图片类别的标签是字符串,不满足深度学习训练的要求:
train_file = pd.read_csv('../classify-leaves/train.csv') # 读取的文件是DataFarame格式
data = train_file['label'].values # 将文件中label这一列拿出来 把格式转换为numpy格式
data_list = list(set(data)) # 去除重复项,转换为list形式
dict = {}
for i in range(len(data_list)):
dict[data_list[i]] = i # 构造映射字典
train_file['label_number'] = train_file['label'].map(dict) # 利用映射将字符串类别转换为数字
train_file.to_csv('../classify-leaves/train_.csv', index=False) # 将修改的文件保存到本地 不保存序号
修改后的train.csv文件: 一般官方数据集是通过torchvison.datasets读取,然后用DataLoader函数加载,例如MNIST等,但是Classify-Leaves数据集不可以直接通过torchvison.datasets读取,因此我们需要定义自己的Dataset。
代码如下(示例):
class Dataset(object):
#这个函数就是根据索引,迭代的读取路径和标签
def __init__(self, csv_file, root_dir, transform=None):
raise NotImplementedError
def __getitem__(self, index):
raise NotImplementedError
#返回数据的长度
def __len__(self):
raise NotImplementedError
代码如下:
class MyDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
"""
csv_file: 标签文件的路径.
root_dir: 所有图片的路径.
transform: 一系列transform操作
"""
self.data_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.data_frame) # 返回数据集长度
def __getitem__(self, idx):
img_path = os.path.join(self.root_dir,
self.data_frame.iloc[idx, 0]) #获取图片所在路径
img = Image.open(img_path).convert('RGB') # 防止有些图片是RGBA格式
label_number = self.data_frame.iloc[idx, 2] # 获取图片的类别标签
if self.transform:
img = self.transform(img)
return img, label_number # 返回图片和标签
定义好Dataset后,我们就可以进行下一步了
#读取数据集
train_dataset = MyDataset(csv_file='../classify-leaves/train_.csv',
root_dir='../classify-leaves',
transform=torchvision.transforms.ToTensor())
train_iter = DataLoader(train_dataset, batch_size=128) # 加载数据集
for X, y in train_iter: # 迭代batchsize中的数据
print(X.shape) # torch.Size([128, 3, 224, 224])
print(y.shape) # torch.Size([128])
break
以上就是今天要讲的内容,本文仅仅简单介绍了针对特定的数据集如何构造合适的Dataset,以方便通过DataLoader函数来加载。