PyTorch——使用 PyTorch 实现多分类问题（MNIST 手写数字识别）

一、代码中使用的 MNIST 数据集可以通过运行以下代码获取（首次运行时会自动下载到 dataset 目录）

import torchvision
from torchvision.transforms import ToTensor

# Download (if not already present) the MNIST training and test splits into
# ./dataset; ToTensor converts every image to a float tensor.
train_ds = torchvision.datasets.MNIST(root=r'dataset', train=True, transform=ToTensor(), download=True)
test_ds = torchvision.datasets.MNIST(root=r'dataset', train=False, transform=ToTensor(), download=True)

二、代码运行环境

PyTorch（GPU 版）==1.7.1
Python==3.7
其他依赖：torchvision、tqdm、matplotlib、numpy

三、数据集处理代码如下所示（保存为 data_loader.py，训练和预测脚本会从中导入 make_dataset）

import torchvision
from torchvision.transforms import ToTensor
import torch.utils.data
import matplotlib.pyplot as plt
import numpy as np


def make_dataset(root=r'dataset', batch_size=64):
    """Build the MNIST training and test DataLoaders.

    MNIST is downloaded into ``root`` if it is not already there, and every
    image is converted to a float tensor with ``ToTensor``.

    Args:
        root: Directory where the MNIST files are stored / downloaded to.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A tuple ``(train_dl, test_dl)``. The training loader reshuffles on
        every pass; the test loader keeps the dataset order.
    """
    train_ds = torchvision.datasets.MNIST(root=root, train=True, transform=ToTensor(), download=True)
    test_ds = torchvision.datasets.MNIST(root=root, train=False, transform=ToTensor(), download=True)
    train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    # Keep evaluation order fixed so a given batch/index always refers to the
    # same test sample.
    test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=batch_size, shuffle=False)
    return train_dl, test_dl


if __name__ == '__main__':
    # Preview the first 10 images of one (shuffled) training batch, each
    # titled with its label.
    train, test = make_dataset()
    images, labels = next(iter(train))
    plt.figure(figsize=(10, 3))
    for i, img in enumerate(images[:10]):
        # ToTensor yields a (1, H, W) tensor per image; squeeze drops the
        # single channel axis so imshow receives a 2-D array.
        np_img = np.squeeze(img.numpy())
        plt.subplot(1, 10, i + 1)
        # MNIST is single-channel; without cmap='gray' matplotlib renders it
        # with its default (colored) colormap.
        plt.imshow(np_img, cmap='gray')
        plt.axis('off')
        plt.title(str(labels[i].item()))
    plt.show()

四、模型的构建代码如下所示（保存为 model_loader.py，训练和预测脚本会从中导入 Model）

from torch import nn
import torch


class Model(nn.Module):
    """Three-layer fully connected classifier for 28x28 images.

    Each image is flattened to 28*28 features and mapped to 10 unnormalized
    class scores (logits). Pair it with ``nn.CrossEntropyLoss``, which
    applies log-softmax internally.

    Note: the ``liner_*`` attribute names (sic) are kept on purpose; they are
    the keys of the saved ``state_dict``. Renaming them would break loading
    previously saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.liner_1 = nn.Linear(in_features=28 * 28, out_features=120)
        self.liner_2 = nn.Linear(in_features=120, out_features=84)
        self.liner_3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, input):
        """Return logits of shape (N, 10).

        Args:
            input: Tensor whose element count is a multiple of 28*28, e.g. a
                batch of shape (N, 1, 28, 28); it is flattened to (N, 784).
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or permuted batches; view raises on those.
        x = input.reshape(-1, 28 * 28)
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        logits = self.liner_3(x)
        return logits

五、模型的训练代码如下所示

import torch
from data_loader import make_dataset
from model_loader import Model
from torch import nn
import tqdm
import os

if __name__ == '__main__':
    # Load the MNIST training/test DataLoaders.
    train_dl, test_dl = make_dataset()

    # Build the model on the GPU when available, otherwise on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model().to(device)

    # Training setup: cross-entropy on logits, plain SGD, and a step schedule
    # that multiplies the learning rate by 0.1 at epochs 25, 50 and 75.
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=opt, milestones=[25, 50, 75], gamma=0.1)
    epochs = 100

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
        train_tqdm.set_description_str('Train_Epoch {:3d}'.format(epoch))
        for image, label in train_tqdm:
            image, label = image.to(device), label.to(device)
            pred = model(image)
            loss = loss_fn(pred, label)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Report the loss already computed for this batch rather than
            # evaluating the loss function a second time.
            train_tqdm.set_postfix_str('Train_Loss is {:.14f}'.format(loss.item()))
        train_tqdm.close()

        # ---- evaluation pass ----
        model.eval()
        with torch.no_grad():
            test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
            test_tqdm.set_description_str('Test_Epoch {:3d}'.format(epoch))
            correct = 0
            seen = 0
            for image, label in test_tqdm:
                image, label = image.to(device), label.to(device)
                pred = model(image)
                loss = loss_fn(pred, label)
                # Running accuracy over the test batches seen so far this epoch.
                correct += (pred.argmax(dim=-1) == label).sum().item()
                seen += label.size(0)
                test_tqdm.set_postfix_str(
                    'Test_Loss is {:.14f}, Test_Acc is {:.4f}'.format(loss.item(), correct / seen))
            test_tqdm.close()

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

    # Save only the learned parameters. os.path.join produces a valid path on
    # every OS; a literal backslash is a separator only on Windows.
    os.makedirs('model_data', exist_ok=True)
    torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

六、模型的预测代码如下所示

from model_loader import Model
from data_loader import make_dataset
import torch
import matplotlib.pyplot as plt
import matplotlib

import os  # needed for the portable checkpoint path below

# Load the MNIST DataLoaders; only the test loader is used here.
train_dl, test_dl = make_dataset()

# Rebuild the network and load the trained weights. The model runs on the
# CPU here, so map_location='cpu' also lets a checkpoint that was saved from a
# CUDA device load on a machine without a GPU.
model = Model()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()

# Predict the first test batch and display the sample at position `index`
# (must be smaller than the batch size).
index = 5
image, label = next(iter(test_dl))
with torch.no_grad():
    pred = model(image)
    pred = torch.argmax(input=pred, dim=-1)
    # Select one (1, H, W) sample and drop only its channel axis for imshow.
    show_image = torch.squeeze(image[index], dim=0)
    # NOTE(review): 'Microsoft YaHei' is a Windows font used to render the
    # Chinese title; on systems without it the CJK glyphs may not display.
    matplotlib.rc("font", family='Microsoft YaHei')
    plt.imshow(show_image, cmap='gray')
    plt.title('预测结果为:' + str(pred[index].numpy()) + ',标签结果为:' + str(label[index].numpy()))
    plt.axis('off')
    plt.savefig('result.png')
    plt.show()

七、代码的运行结果如下所示

（图 1：模型预测结果，即预测脚本保存的 result.png）

你可能感兴趣的:(Pytorch,pytorch,分类,深度学习)