pytorch简单自定义Datasets

前言

  • 本文记录一下如何简单自定义pytorch中Datasets,官方教程
  • 文件层级目录如下:
    • images
      • 1.jpg
      • 2.jpg
      • 9.jpg
    • annotations_file.csv

数据说明

  • image文件夹中有需要训练的图片,annotations_file.csv中有2列,分别为image_idlabel,即图片名和其对应标签。
    pytorch简单自定义Datasets_第1张图片
image_id label
1 风景
2 风景
3 风景
4 星空
5 星空
6 星空
7 人物
8 人物
9 人物

代码展示

导入必要包

import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T

自定义Datasets

  • 自定义Datasets之前,首先我们需要准备两个信息:
    • 图片地址
    • 图片标签
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 读取包含图片id和图片标签的csv文件
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 读取图片标签
        label = self.img_labels.iloc[idx, 1]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

定义图片预处理方法

transform = {'train':T.Compose([
    # 将图片缩放至固定尺寸
    T.Resize((224,224)),
    # 应用CIFAR10自动增强策略
    T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
    # 像素值归一化,并转换为ternsor格式
    T.ToTensor(),])}
  • 接下来我们将Datasets实例化
annotations_file = '/kaggle/input/datasets-test/annotations_file.csv'
img_dir = '/kaggle/input/datasets-test/images'
train_data = CustomImageDataset(annotations_file = annotations_file ,img_dir = img_dir, transform = transform['train'])
  • 我们可以使用len(train_data)检查样本完整性,以及Datasets定义正确性,这里输出9,的确只有9张图片,正确无误。

使用DataLoaders加载数据

  • 因为这里数据较少,所以设置batch_size = 2,打乱数据shuffle = true,不丢弃数据drop_last=False,有关DataLoader的更多操作可以参照官方API
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
  • 使用iter()函数和next()函数,取1个batch,检查数据
train_features, train_labels = next(iter(train_dataloader))
  • 取第1个batch中的第1个图片,并将其可视化。
  • 由于train_features[0]的维度为(1,3,224,224),所以使用squeeze()函数从数组中删除单维度条目,即把为1的维度去掉。再使用permute()函数将维度变换(224,224,3),便于plt绘图。
# 维度整理
img = train_features[0].squeeze()
# 转置操作
img = img.permute(1,2,0)
# 绘图
plt.imshow(np.asarray(img))
# 关闭刻度线
plt.axis('off')
plt.show()

pytorch简单自定义Datasets_第2张图片

  • 打印图片标签print(train_labels[0]),输出'人物'

在Datasets中将字符串标签数值化

  • 我们发现上面打印出的标签为字符串,如果我们想要将其数值化,只需要在Datasets__getitem__部分改动一点
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 读取包含图片id和图片标签的csv文件
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 将字符串进行数值编码
        data_category, data_class = pd.factorize(self.img_labels.iloc[:, 1])
        label = data_category[idx]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 这样就可以了,可以看到其实array格式的数据也是可以读取的,只要保证idx一致,且对应就可以。

划分训练集与验证集

  • 根据前面的说明,其实训练集与验证集的划分就变的很简单了,只需要4个列表/数组,train_pathtrain_labelvaild_pathvaild_label分别表示训练集图片路径、标签、验证集图片路径、标签。DataSets可以这样写:
class CustomImageDataset(Dataset):
    def __init__(self, image_id, image_label, img_dir, transform=None, target_transform=None):
        # 读取包含图片id
        self.image_id = image_id
        # 读取图片标签
        self.image_label = image_label
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.image_id[idx]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 读取图片标签
        label = self.image.label[idx]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 实例化和加载器可以写作:
transform = {'train':T.Compose([T.Resize((224,224)),T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),T.ToTensor(),]),
             'valid':T.Compose([T.Resize((224,224)),T.ToTensor(),])}
             
# 实例化训练数据
train_data = CustomImageDataset(train_path, train_label, img_dir = './',transform = transform['train'])
# 实例化训练数据
valid_data = CustomImageDataset(valid_path, valid_label, img_dir = './',transform = transform['valid'])

# 训练集数据加载器
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
# 验证集数据加载器
valid_dataloader = DataLoader(valid_data, batch_size=2, shuffle=True, drop_last=False)

你可能感兴趣的:(pytorch,深度学习,python)