前言
- 本文记录一下如何简单自定义pytorch中Datasets,官方教程
- 文件层级目录如下:
images
annotations_file.csv
数据说明
image
文件夹中有需要训练的图片,annotations_file.csv
中有2列,分别为image_id
和label
,即图片名和其对应标签。
image_id |
label |
1 |
风景 |
2 |
风景 |
3 |
风景 |
4 |
星空 |
5 |
星空 |
6 |
星空 |
7 |
人物 |
8 |
人物 |
9 |
人物 |
代码展示
导入必要包
import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T
自定义Datasets
- 自定义Datasets之前,首先我们需要准备两个信息:
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
image = Image.open(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
定义图片预处理方法
transform = {'train':T.Compose([
T.Resize((224,224)),
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
T.ToTensor(),])}
annotations_file = '/kaggle/input/datasets-test/annotations_file.csv'
img_dir = '/kaggle/input/datasets-test/images'
train_data = CustomImageDataset(annotations_file = annotations_file ,img_dir = img_dir, transform = transform['train'])
- 我们可以使用
len(train_data)
检查样本完整性,以及Datasets
定义正确性,这里输出9
,的确只有9张图片,正确无误。
使用DataLoaders加载数据
- 因为这里数据较少,所以设置
batch_size = 2
,打乱数据shuffle = true
,不丢弃数据drop_last=False
,有关DataLoader
的更多操作可以参照官方API
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
- 使用
iter()
函数和next()
函数,取1个batch
,检查数据
train_features, train_labels = next(iter(train_dataloader))
- 取第1个
batch
中的第1个图片,并将其可视化。
- 由于
train_features[0]
的维度为(1,3,224,224)
,所以使用squeeze()
函数从数组中删除单维度条目,即把为1的维度去掉。再使用permute()
函数将维度变换(224,224,3)
,便于plt
绘图。
img = train_features[0].squeeze()
img = img.permute(1,2,0)
plt.imshow(np.asarray(img))
plt.axis('off')
plt.show()
- 打印图片标签
print(train_labels[0])
,输出'人物'
在Datasets中将字符串标签数值化
- 我们发现上面打印出的标签为字符串,如果我们想要将其数值化,只需要在
Datasets
中__getitem__
部分改动一点
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
image = Image.open(img_path)
data_category, data_class = pd.factorize(self.img_labels.iloc[:, 1])
label = data_category[idx]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
- 这样就可以了,可以看到其实
array
格式的数据也是可以读取的,只要保证idx
一致,且对应就可以。
划分训练集与验证集
- 根据前面的说明,其实训练集与验证集的划分就变的很简单了,只需要4个列表/数组,
train_path
、train_label
、vaild_path
、vaild_label
分别表示训练集图片路径、标签、验证集图片路径、标签。DataSets
可以这样写:
class CustomImageDataset(Dataset):
def __init__(self, image_id, image_label, img_dir, transform=None, target_transform=None):
self.image_id = image_id
self.image_label = image_label
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, str(self.image_id[idx]) + '.jpg')
image = Image.open(img_path)
label = self.image.label[idx]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
transform = {'train':T.Compose([T.Resize((224,224)),T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),T.ToTensor(),]),
'valid':T.Compose([T.Resize((224,224)),T.ToTensor(),])}
train_data = CustomImageDataset(train_path, train_label, img_dir = './',transform = transform['train'])
valid_data = CustomImageDataset(valid_path, valid_label, img_dir = './',transform = transform['valid'])
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_data, batch_size=2, shuffle=True, drop_last=False)