深入浅出Pytorch系列(4):实战--FashionMNIST时装分类

时装分类的任务

FashionMNIST数据集中包含已经预先划分好的训练集和测试集,其中训练集共60,000张图像,测试集共10,000张图像。每张图像均为单通道黑白图像,大小为32*32pixel,分属10个类别。

首先导入必要的包

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

配置训练环境和超参数

配置GPU
配置超参数如:batch_size, num_workers, learning rate, 以及总的epochs

数据读入和加载

这里同时展示两种方式:

下载并使用PyTorch提供的内置数据集
从网站下载以csv格式存储的数据,读入并转成预期的格式
第一种数据读入方式只适用于常见的数据集,如MNIST,CIFAR10等,PyTorch官方提供了数据下载。这种方式往往适用于快速测试方法(比如测试下某个idea在MNIST数据集上是否有效)
第二种数据读入方式需要自己构建Dataset,这对于PyTorch应用于自己的工作中十分重要
同时,还需要对数据进行必要的变换,比如说需要将图片统一为一致的大小,以便后续能够输入网络训练;需要将数据格式转为Tensor类,等等。

模型设计

通过nn.module以及nn.sequential对网络结构进行搭建

设定损失函数

多分类问题一般使用torch.nn模块自带的nn.CrossEntropy损失

设定优化器

Adam优化器较为常用

训练和测试(验证)

常规做法是将训练和测试各自封装成函数,方便后续调用

两者的主要区别:

  • 模型状态设置
  • 是否需要初始化优化器
  • 是否需要将loss传回到网络
  • 是否需要每步更新optimizer

此外,对于测试或验证过程,可以计算分类准确率

学习链接:
https://github.com/datawhalechina/thorough-pytorch

你可能感兴趣的:(pytorch,分类,深度学习)