Pytorch基本操作(2)——Dataset以及Dataloader

1 简介

在学习李沐在B站发布的《动手学深度学习》PyTorch版本教学视频中发现在操作使用PyTorch方面有许多地方看不懂,往往只是“动手”了,没有动脑。所以打算趁着寒假的时间好好恶补、整理一下PyTorch的操作,以便跟上课程。

学习资源:

  • B站up主:我是土堆的视频:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
  • PyTorch中文手册:(pytorch handbook)
  • Datawhale开源内容:深入浅出PyTorch(thorough-pytorch)

2 Dataset以及Dataloader

Dataset 以及 Dataloader 是Pytorch中读取数据需要用到的两个重要的类

  • Dataset :提供一种方式去获取数据及其lable,需要我们自己去写。
  • Dataloader :为后面的网络提供不同的数据形式

常见的图片数据集有两种形式:

  • label直接标在文件夹上
  • label另外放在另一个文件夹对应的txt文件中(OCR)
  • label写在图片的名称上

2.1 Dataset

功能:

  1. 如何获取每一个数据及其label。
  2. 告诉我们总共有多少的数据。

我们可以运行 from torch.utils.data import Dataset (其中 utils 有实用工具的意思,理解为工具区)来导入Dataset这个类,同时也可以使用 Dataset??help(Dataset) 来看如何使用Dataset

import torch
from torch.utils.data import Dataset
Dataset??
help(Dataset)
Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  __getitem__(self, index) -> +T_co
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  __parameters__ = (+T_co,)
 |  
 |  ----------------------------------------------------------------------
 |  Class methods inherited from typing.Generic:
 |  
 |  __class_getitem__(params) from builtins.type
 |  
 |  __init_subclass__(*args, **kwargs) from builtins.type
 |      This method is called when a class is subclassed.
 |      
 |      The default implementation does nothing. It may be
 |      overridden to extend subclasses.

简言之:Dataset是个抽象类,所有的子类都要重写__getitem__方法获取label,也可以重写__len__方法获取长度

2.2 Dataset类代码实战

其中用到的两个库:

  1. PIL中的Image
    PIL(Python Imaging Library):是Python的图像处理库

    • img = Image.open(image_path) 读取对应路径的图片为一个变量
    • img.show() 使用系统默认图片打开方式打开此图片
  2. os:operating system

    • os.path.join(root_dir, label_dir):将两个路径连起来,这个函数会根据操作系统自动调整路径的语法;其中dir = directory目录
    • os.listdir(dir_path):顾名思义,就是把括号里面的目录路径中的“所有文件的路径”生成一个列表,可以用类似a[0]的语句取出 对应图片的路径
from torch.utils.data import Dataset
from PIL import Image
import os

2.2.1 创建类

图片数据集格式为:label直接标在文件夹上

class MyData(Dataset):
    
    def __init__(self, root_dir, label_dir):
        # 初始化函数,为后面的getitem和next方法提供所需要的量
        self.root_dir = root_dir # 根目录的路径;self可以理解为当前类内部的一个全局变量
        self.label_dir = label_dir # label目录的路径,因为下一行要合起来,并且label名是文件夹名,所以这里的label_dir可以直接取对应的label,如:"ants"
        self.path = os.path.join(root_dir, label_dir) # 将两个路径连起来
        self.img_path = os.listdir(self.path) # 获取该路径下所有文件的路径列表
        
    def __getitem__(self, idx):
        """获取数据集中的每一个图片,输入索引,得到对应的图片"""
        # idx是index索引的缩写
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        image = Image.open(img_item_path)
        label = self.label_dir
        return image, label
    
    def __len__(self):
        return len(self.img_path)

2.2.2 创建个实例看看

对应的蜜蜂蚂蚁图片识别数据集见小土堆B站视频简介

root_dir = r"F:\Data and code\data\蚂蚁蜜蜂数据\hymenoptera_data\hymenoptera_data\train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

取一个试试看到底能不能取出来单独数据

img, label = ants_dataset[0]
# img.show()
img

Pytorch基本操作(2)——Dataset以及Dataloader_第1张图片

2.2.3 将数据集格式转化为txt存放label格式

"""下面的代码是将本节中所提到的格式转化成使用txt文档存放label的格式"""
"""这种格式就是新建一个ant_label文件夹,其中放的都是.txt文件。每一个文件的名字都是对应图片的名字,文件的内容则是对应的label"""
'''说实话还没看'''

root_dir = r"F:\Data and code\data\蚂蚁蜜蜂数据\hymenoptera_data\hymenoptera_data\train"
target_dir = "ants_image"
img_path_list = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"
for i in img_path_list:
    file_name = i.split('.jpg')[0]
    with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
        f.write(label)

你可能感兴趣的:(pytorch,深度学习,pytorch,深度学习,神经网络)