PyTorch学习笔记-常用函数与数据加载

1. PyTorch常用函数

(1)路径相关的函数
假设我们数据集的目录结构如下:

PyTorch学习笔记-常用函数与数据加载_第1张图片

首先需要 import os,在 os 中常用的路径相关的函数有:

  • os.listdir(path):将 path 目录下的内容列成一个 list
  • os.path.join(path1, path2):拼接路径:path1\path2

例如:

import os

dir_path = 'dataset/hymenoptera_data/train/ants_image'
img_path_list = os.listdir(dir_path)
img_full_path = os.path.join(dir_path, img_path_list[0])
print(img_path_list)  # ['0013035.jpg', '1030023514_aad5c608f9.jpg', ...]
print(img_full_path)  # dataset/hymenoptera_data/train/ants_image\0013035.jpg

(2)辅助函数

  • dir():不带参数时,返回当前范围内的变量、方法和定义的类型列表;带参数时,返回参数的属性、方法列表。
  • help(func):查看函数 func 的使用说明。

例如:

import torch

print(dir(torch))  # ['AVG', 'AggregationType', ..., 'cuda', ...]
help(torch.cuda.is_available)  # Help on function is_available in module torch.cuda: is_available() -> bool...

2. 数据加载

PyTorch 数据集 (Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch 提供了很多方法来完成数据的读取和预处理。

(1)Dataset:torch.utils.data.Dataset 是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义 __len____getitem__ 这个两个函数即可,例如:

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir + '_image')
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, idx):
        img_path = self.img_path_list[idx]
        img_full_path = os.path.join(self.root_dir, self.label_dir + '_image', img_path)
        img = Image.open(img_full_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path_list)

root_dir = 'dataset/hymenoptera_data/train'
ants_label_dir = 'ants'

ants_data = MyData(root_dir, ants_label_dir)
img, label = ants_data[0]
print(img, label)
img.show()

通过上面的方式,可以定义我们需要的数据类,可以通过迭代的方式来获取每一个数据,但这样很难实现取 batch、shuffle 或者是多线程去读取数据。

(2)DataLoader:torch.utils.data.DataLoader() 构建可迭代的数据装载器,我们在训练的时候,每一个 for 循环,每一次 iteration,就是从 DataLoader 中获取一个 batch_size 大小的数据的。学习 DataLoader 之前需要先学一下 Transform:PyTorch学习笔记-Transform。

DataLoader 的参数很多,但我们常用的主要有以下几个:

  • dataset:Dataset 类,决定数据从哪读取以及如何读取。
  • bath_size:批大小。
  • num_works:是否多进程读取机制。
  • shuffle:每个 epoch 是否乱序。
  • drop_last:当样本数不能被 batch_size 整除时,是否舍弃最后一批数据。

要理解这个 drop_last,首先,得先理解 Epoch、Iteration 和 Batch_size 的概念:

  • Epoch:所有训练样本都已输入到模型中,称为一个 Epoch。
  • Iteration:一批样本输入到模型中,称为一个 Iteration。
  • Batch_size:一批样本的大小,决定一个 Epoch 有多少个 Iteration。

DataLoader 的作用就是构建一个数据装载器,根据我们提供的 batch_size 的大小,将数据样本分成一个个的 batch 去训练模型,而这个分的过程中需要把数据取到,这个就是借助 Dataset__getitem__ 方法。

例如:

你可能感兴趣的:(Artificial,Intelligence,pytorch,学习,深度学习,人工智能,python)