PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)

来自 B 站刘二大人的《PyTorch深度学习实践》P8 的学习笔记

上一篇 处理多维特征的输入 的时候我们提到了 Mini-Batch,也就是训练时,一个 epoch 只遍历 N N N 个 samples(行), N ≤ 总 行 数 N \le 总行数 N,要实现这个,就要用到 PyTorch 强大的数据加载工具类:Dataset、DataLoader。

  • 一次性输入全部的 Batch 的优势在于计算性能,而且更容易用于并行计算;

  • 一次性输入一个样本的优势在于数据的随机性使得网络优化更容易走出鞍点,更易于训练收敛;

  • 我们折中使用 Mini-Batch,集成计算性能和优化性能。

Mini-Batch

使用 Mini-Batch 之后,训练的代码需要两层 for 循环:

  • 一个 Epoch:外层 for 循环一次,遍历完所有的样本(所有的样本都前馈和反向传播了一次)
  • 一个 Mini-Batch:内层 for 循环一次,前馈和反向传播一个 Batch-Size 的样本(一次性输入 batch-size 数量的样本到 model() 里)
  • 一个 Iteration:一次迭代一个 Mini-Batch

PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第1张图片

Dataset、DataLoader

DataLoader:batch_size=2,shuffle=True

shuffle 的意义在于打乱数据,提供随机性。
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第2张图片
Dataset 是一个抽象类,不能直接实例化,我们要自定义自己的数据类并继承 Dataset。

由于我们的数据类的实例要作为参数放到 DataLoader 类里面返回一个数据生成器,所以继承了 Dataset 的数据类要至少实现下面三个方法:
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第3张图片
这些双下划线方法都是 Python 自带的魔法方法,在自定义的类里面实现它们之后,类的实例就会具备相应的取数功能:

  • 例如定义了 def __getitem__(self, index),类的实例 dataset 就可以通过数组下标 dataset[index] 的方式调用 __getitem__(self, index)
  • 例如定义了 def __len__(self),类的实例 dataset 就可以通过 Python 内置方法 len(dataset) 来调用 __len__(self),并希望这个函数返回数据长度。
    PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第4张图片
    把我们自定义的数据类实例化之后,传参到 DataLoader() 中,得到
    PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第5张图片
    在 windows 下,训练的循环需要在 if __name__ == "__main__": 里面,否则不能设置 num_worker 参数来使用多线程加载数据。
    PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第6张图片

Using DataLoader

每一次迭代 train_loader 得到的是 batch_size 大小的样本矩阵:
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第7张图片
train_loader 的返回值是调用 __getitem__() 的结果
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第8张图片
神经网络的建模步骤从构建 Mini-Batch 数据加载器开始
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第9张图片

torchvision.datasets

torchvision.datasets 这个包里面有很多常用的数据集,这些 dataset 都继承自 PyTorch 的 Dataset 类,并实现了 __getitem____len__ 方法,因此可以直接实例化之后传入 DataLoader 来获取 data_loader。
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第10张图片
例如:指定的数据集位置如果没有相应的数据集,它会下载,如果有,则不会再下载。
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第11张图片

完整代码

  • diabetes 数据集
import copy

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


class DiabetesDataset(Dataset):
    def __init__(self):
        xy = np.loadtxt("../datasets/diabetes/diabetes.csv.gz", delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """很明显 Index 是行坐标"""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset()
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
print(len(dataset))


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(8, 6)
        self.linear2 = nn.Linear(6, 4)
        self.linear3 = nn.Linear(4, 1)
        self.activate = nn.Sigmoid()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        x = self.activate(self.linear3(x))
        return x


model = Model()

criterion = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.003)  # SGD 效果不好,那效果看起来像没有优化一样

if __name__ == '__main__':
    for epoch in range(50):
        TP = 0  # 预测正确的总个数
        loss_lst = []
        for i, (x, y) in enumerate(data_loader):
            y_pred = model(x)
            loss = criterion(y_pred, y)

            loss_lst.append(loss.item())
            # 以下是计算一个 Mini-Batch 的精确度
            y_hat = copy.copy(y_pred.data.numpy())
            y_hat[y_hat >= 0.5] = 1.0
            y_hat[y_hat < 0.5] = 0.0
            TP += np.sum(y.data.numpy().flatten() == y_hat.flatten())
            # 也可以下面这样计算,后面TP要转成numpy()才能计算acc
            # y_hat = torch.round(y_pred.data)  # 取每个元素的最接近的整数,注:0.5接近0
            # TP += (y.data.flatten() == y_hat.flatten()).sum()  # 这里不能用 numpy.sum(), torch.sum()可以

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = TP / len(dataset)
        print("epoch:", epoch, "loss:", np.mean(loss_lst), "acc:", acc, "TP:", TP)
  • MNIST 数据集
import os
import copy

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt


trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_set = datasets.MNIST(root="../datasets/mnist",
                           train=True,
                           transform=trans,  # 原始是 PIL Image 格式
                           download=True)
test_set = datasets.MNIST(root="../datasets/mnist",
                          train=False,
                          transform=trans,
                          download=True)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=True)

# examples = enumerate(train_loader)
# _, (x, y) = next(examples)
# print(x.shape, y.shape)


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(28*28, 60)  # MNIST 每个图像大小为 28*28
        self.linear2 = nn.Linear(60, 10)
        self.activate = nn.ReLU()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = F.softmax(self.linear2(x))
        return x


model = Model()


def train(model, train_loader, save_dst="./models/acc_0.96.pth"):
    global acc

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(5):

        TP = 0
        loss_lst = []
        for i, (imgs, labels) in enumerate(train_loader):
            x = imgs.reshape(-1, 28*28)  # 若不reshape,imgs矩阵默认为[32*28, 28]
            y_pred = model(x)
            # print("x:", x.shape, "y:", labels.shape, "y_pred:", y_pred.shape)

            loss = criterion(y_pred, labels)

            y_hat = copy.copy(y_pred)
            TP += torch.sum(labels.flatten() == torch.argmax(y_hat, dim=1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = TP.data.numpy() / len(train_set)
        print("epoch:", epoch, "loss:", np.mean(loss_lst), "acc:", round(acc, 3), "TP:", TP)

    # 保存模型
    torch.save(model.state_dict(), os.path.join(save_dst, f"acc_{
       round(acc, 2)}.pth"))


def test(model, test_loader, load_dst="./models/acc_0.96.pth"):
    TP = 0
    model.load_state_dict(torch.load(load_dst))

    for i, (imgs, labels) in enumerate(test_loader):
        x = imgs.reshape(-1, 28*28)
        with torch.no_grad():
            y_pred = model(x)
        # print("x:", x.shape, "y:", labels.shape, "y_pred:", y_pred.shape)

        y_hat = copy.copy(y_pred)
        TP += torch.sum(labels.flatten() == torch.argmax(y_hat, dim=1))
    acc = TP.data.numpy() / len(test_set)
    print("acc:", round(acc, 4), "TP:", TP)


def draw(model, test_loader, load_dst="./models/acc_0.96.pth"):
    model.load_state_dict(torch.load(load_dst))

    examples = enumerate(test_loader)
    _, (imgs, labels) = next(examples)
    x = imgs.reshape(-1, 28 * 28)

    with torch.no_grad():
        y_pred = model(x)

    for i in range(6):
    """选取mini-batch中32个图像的前6个"""
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(imgs[i][0], cmap='gray', interpolation='none')
        plt.title("Prediction: {}".format(
            y_pred.data.max(1, keepdim=True)[1][i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


if __name__ == '__main__':
    train(model, train_loader)
    # test(model, test_loader, load_dst="./models/acc_0.96.pth")
    draw(model, test_loader, load_dst="./models/acc_0.96.pth")

预测结果如下:
PyTorch 入门与实践(三)加载数据集(Dataset、DataLoader)_第12张图片

绘图代码参考:Pytorch+torchvisio MNIST手写数字识别

课后练习

  • 在 Kaggle 获取 Titanic dataset:https://www.kaggle.com/c/titanic/data
  • 构建一个分类神经网络,使用 DataLoader 作为数据加载器,训练并得出结果,按照比赛的说明将结果格式化后提交。

你可能感兴趣的:(PyTorch,pytorch,DataLoader)