PyTorch Learning (15): Training a Custom CNN on the FashionMNIST Dataset

Preface

When getting started with deep learning, most people begin with the classic combination of MNIST and LeNet-5: the LeNet-5 architecture is simple and MNIST is small, which makes it convenient and friendly for beginners. Once you are comfortable with the basics of PyTorch, a natural next step is to write a CNN by hand and train and test it on a dataset yourself.

FashionMNIST is a good dataset for that next step. In this experiment we train and test a CNN on FashionMNIST from start to finish.


The FashionMNIST Dataset

Overview

https://github.com/zalandoresearch/fashion-mnist


Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. We intend Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.

Key facts about FashionMNIST:

  • 60,000 training samples + 10,000 test samples
  • 28x28 grayscale images
  • 10 classes

Labels
Each training and test example is assigned to one of the following labels:

Label Description
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
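
For later use it is handy to keep this mapping in code, so that predicted label indices can be turned back into readable names. This is only a small sketch: the class_names list simply restates the table above, and label_to_name is a hypothetical helper, not part of torchvision.

# Hypothetical helper: map FashionMNIST label indices to class names.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def label_to_name(label):
    # label is an integer in [0, 9], e.g. from torch.max(outputs, 1)
    return class_names[label]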

Why we made Fashion-MNIST

The original MNIST dataset contains a lot of handwritten digits. Members of the AI/ML/Data Science community love this dataset and use it as a benchmark to validate their algorithms. In fact, MNIST is often the first dataset researchers try. "If it doesn't work on MNIST, it won't work at all", they said. "Well, if it does work on MNIST, it may still fail on others."

To Serious Machine Learning Researchers

Seriously, we are talking about replacing MNIST. Here are some good reasons:

  • MNIST is too easy. Convolutional nets can achieve 99.7% on MNIST. Classic machine learning algorithms can also achieve 97% easily. Check out our side-by-side benchmark for Fashion-MNIST vs. MNIST, and read "Most pairs of MNIST digits can be distinguished pretty well by just one pixel."
  • MNIST is overused. In this April 2017 Twitter thread, Google Brain research scientist and deep learning expert Ian Goodfellow calls for people to move away from MNIST.
  • MNIST can not represent modern CV tasks, as noted by deep learning expert and Keras author François Chollet in this April 2017 Twitter thread.

Experiment

Getting the dataset

You could download the data yourself from the website, but PyTorch offers a more convenient way: the torchvision.datasets API downloads the data automatically (when download=True is passed).

Since this example runs in CPU mode, the batch size is set to 4. In GPU mode, if you have enough memory, you can use a larger batch size; on an NVIDIA 1080 Ti I set batch size = 16.
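
As a minimal sketch of that choice (16 and 4 are just the values mentioned above, not tuned constants), the batch size can be picked based on whether a GPU is available:

import torch

# Assumption: batch size 16 when a GPU (e.g. a 1080 Ti) is available, 4 on CPU.
batch_size = 16 if torch.cuda.is_available() else 4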

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils


# https://blog.csdn.net/weixin_41278720/article/details/80778640

# --------------------------- Dataset -------------------------------------

data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches the data automatically if it is not already in data_dir
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()

After the download completes, the dataset files are stored under the data directory.


Defining a CNN

The general pattern for defining a network:

  • Subclass nn.Module
  • Define the layers of the network in __init__()
  • Override the parent class's forward() method

Unlike the earlier LeNet-5 definition, this network uses nn.Sequential with an OrderedDict so that every layer gets a name. BatchNorm and Dropout layers are added, and there is no pooling after the first convolution, which keeps more information flowing into the next layer.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class Net(nn.Module):
    '''
    Custom CNN: 4 convolutional layers with batch norm, 3 max-pooling layers,
    and 3 fully connected layers with dropout.
    Input: 28x28x1
    '''
    def __init__(self):
        super(Net, self).__init__()
        self.feature = nn.Sequential(
            OrderedDict(
                [
                    # 28x28x1
                    ('conv1', nn.Conv2d(in_channels=1,
                                        out_channels=32,
                                        kernel_size=5,
                                        stride=1,
                                        padding=2)),

                    ('relu1', nn.ReLU()),
                    ('bn1', nn.BatchNorm2d(num_features=32)),

                    # 28x28x32
                    ('conv2', nn.Conv2d(in_channels=32,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),

                    ('relu2', nn.ReLU()),
                    ('bn2', nn.BatchNorm2d(num_features=64)),
                    ('pool1', nn.MaxPool2d(kernel_size=2)),

                    # 14x14x64
                    ('conv3', nn.Conv2d(in_channels=64,
                                        out_channels=128,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),

                    ('relu3', nn.ReLU()),
                    ('bn3', nn.BatchNorm2d(num_features=128)),
                    ('pool2', nn.MaxPool2d(kernel_size=2)),

                    # 7x7x128
                    ('conv4', nn.Conv2d(in_channels=128,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),

                    ('relu4', nn.ReLU()),
                    ('bn4', nn.BatchNorm2d(num_features=64)),
                    ('pool3', nn.MaxPool2d(kernel_size=2)),

                    # out 3x3x64

                ]
            )
        )

        self.classifier = nn.Sequential(


            OrderedDict(
                [
                    ('fc1', nn.Linear(in_features=3 * 3 * 64,
                                      out_features=128)),
                    ('dropout1', nn.Dropout(p=0.5)),  # plain Dropout: the input here is already flattened

                    ('fc2', nn.Linear(in_features=128,
                                      out_features=64)),

                    ('dropout2', nn.Dropout(p=0.6)),

                    ('fc3', nn.Linear(in_features=64, out_features=10))
                ]
            )

        )

    def forward(self, x):
        out = self.feature(x)
        out = out.view(-1, 64 * 3 * 3)
        out = self.classifier(out)
        return out
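
A quick way to check the layer-size comments above is to push a dummy input through the network and look at the shapes. This is only a verification sketch, assuming the class above is saved as net.py as in the project layout later on:

import torch
from net import Net

model = Net()
dummy = torch.randn(1, 1, 28, 28)    # one fake 28x28 grayscale image
print(model.feature(dummy).shape)    # expected: torch.Size([1, 64, 3, 3])
print(model(dummy).shape)            # expected: torch.Size([1, 10])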

Training the CNN

  • The number of epochs is set to 100; on a GPU this finishes fairly quickly.
  • Every 20 iterations, the model is evaluated on the test set; the accuracy and the running loss are printed and also appended to log files for later analysis.
  • During training, call net.train() to put the model in training mode, and call net.eval() before evaluation. Otherwise the results are wrong, because the network uses BatchNorm and Dropout, which behave differently in train() and eval() mode; see the PyTorch documentation for details (a short illustration follows this list).
  • After training finishes, save the model.
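
As a tiny illustration of the train()/eval() point (a sketch, not part of the training script below):

from net import Net

model = Net()
model.train()           # BatchNorm uses batch statistics, Dropout is active
print(model.training)   # True
model.eval()            # BatchNorm uses running statistics, Dropout is disabled
print(model.training)   # False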
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils


# https://blog.csdn.net/weixin_41278720/article/details/80778640

# --------------------------- Dataset -------------------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches the data automatically if it is not already in data_dir
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()

# ------------------------- Network definition and hyperparameters --------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = net.Net()
print(net)
net = net.to(device)

loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# ----------------------------- Training -----------------------------------------
file_running_loss = open('./log/running_loss.txt', 'w')
file_test_accuracy = open('./log/test_accuracy.txt', 'w')

epoch_num = 100
for epoch in range(epoch_num):
    running_loss = 0.0
    accuracy = 0.0
    for i, sample_batch in enumerate(train_dataloader):

        inputs = sample_batch[0]
        labels = sample_batch[1]

        inputs = inputs.to(device)
        labels = labels.to(device)

        net.train()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fc(outputs, labels)
        loss.backward()
        optimizer.step()

        print(i, loss.item())

        # Track statistics: running loss and test accuracy
        running_loss += loss.item()
        if i % 20 == 19:
            correct = 0
            total = 0
            net.eval()
            with torch.no_grad():
                for val_inputs, val_labels in val_dataloader:
                    # move the validation batch to the same device as the model
                    val_inputs = val_inputs.to(device)
                    val_labels = val_labels.to(device)
                    outputs = net(val_inputs)
                    _, prediction = torch.max(outputs, 1)
                    correct += ((prediction == val_labels).sum()).item()
                    total += val_labels.size(0)

            accuracy = correct / total
            print('[{},{}] running loss = {:.5f} acc = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20, accuracy))
            file_running_loss.write(str(running_loss / 20) + '\n')
            file_test_accuracy.write(str(accuracy) + '\n')
            running_loss = 0.0

    # In PyTorch >= 1.1 the LR scheduler should be stepped after the optimizer,
    # i.e. once at the end of each epoch.
    scheduler.step()

print('\nTraining finished')
file_running_loss.close()
file_test_accuracy.close()
torch.save(net.state_dict(), './model/model_100_epoch.pth')
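
Because the running loss and test accuracy are appended to ./log/running_loss.txt and ./log/test_accuracy.txt, they can be plotted once training is done. A small analysis sketch (the plot styling is my own choice, not from the original scripts):

import matplotlib.pyplot as plt

# Each log file contains one float per line, written every 20 iterations.
with open('./log/running_loss.txt') as f:
    losses = [float(line) for line in f if line.strip()]
with open('./log/test_accuracy.txt') as f:
    accuracies = [float(line) for line in f if line.strip()]

plt.figure()
plt.plot(losses, label='running loss')
plt.plot(accuracies, label='test accuracy')
plt.xlabel('evaluation step (every 20 iterations)')
plt.legend()
plt.show()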


Training results


The training results are quite good: accuracy peaks at around 93%.


Testing the network

Feed in one batch (batch size = 4) and load the trained model.
Note: the model was trained on a GPU, so the saved state dict is laid out for GPU tensors. When loading a GPU-trained model in CPU mode, pass map_location:
net.load_state_dict(torch.load('xxx.pth', map_location='cpu'))
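
Conversely, to run the same checkpoint on a GPU machine, the usual pattern is to map the weights onto the target device. A short sketch reusing the net.py module from this project:

import torch
import net

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = net.Net()
model.load_state_dict(torch.load('./model/model_100_epoch.pth', map_location=device))
model = model.to(device)
model.eval()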


import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils


data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])


test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)

plt.figure()
utils.imshow_batch(next(iter(test_dataloader)))

net = net.Net()
net.load_state_dict(torch.load(f='./model/model_100_epoch.pth', map_location='cpu'))
net.eval()  # switch BatchNorm/Dropout to inference mode
print(net)

images, labels = next(iter(test_dataloader))
with torch.no_grad():
    outputs = net(images)
_, prediction = torch.max(outputs, 1)
print('label:', labels)
print('prediction:', prediction)

plt.show()
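
Going beyond a single batch, the whole test set can be scored with the same model. A small sketch reusing net and test_dataloader from the script above:

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_dataloader:
        outputs = net(images)
        _, prediction = torch.max(outputs, 1)
        correct += (prediction == labels).sum().item()
        total += labels.size(0)
print('test accuracy: {:.4f}'.format(correct / total))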


Complete project

  • Network definition
    net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class Net(nn.Module):
    '''
    Custom CNN: 4 convolutional layers with batch norm, 3 max-pooling layers,
    and 3 fully connected layers with dropout.
    Input: 28x28x1
    '''
    def __init__(self):
        super(Net, self).__init__()
        self.feature = nn.Sequential(
            OrderedDict(
                [
                    # 28x28x1
                    ('conv1', nn.Conv2d(in_channels=1,
                                        out_channels=32,
                                        kernel_size=5,
                                        stride=1,
                                        padding=2)),

                    ('relu1', nn.ReLU()),
                    ('bn1', nn.BatchNorm2d(num_features=32)),

                    # 28x28x32
                    ('conv2', nn.Conv2d(in_channels=32,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),

                    ('relu2', nn.ReLU()),
                    ('bn2', nn.BatchNorm2d(num_features=64)),
                    ('pool1', nn.MaxPool2d(kernel_size=2)),

                    # 14x14x64
                    ('conv3', nn.Conv2d(in_channels=64,
                                        out_channels=128,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),

                    ('relu3', nn.ReLU()),
                    ('bn3', nn.BatchNorm2d(num_features=128)),
                    ('pool2', nn.MaxPool2d(kernel_size=2)),

                    # 7x7x128
                    ('conv4', nn.Conv2d(in_channels=128,
                                        out_channels=64,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)),

                    ('relu4', nn.ReLU()),
                    ('bn4', nn.BatchNorm2d(num_features=64)),
                    ('pool3', nn.MaxPool2d(kernel_size=2)),

                    # out 3x3x64

                ]
            )
        )

        self.classifier = nn.Sequential(


            OrderedDict(
                [
                    ('fc1', nn.Linear(in_features=3 * 3 * 64,
                                      out_features=128)),
                    ('dropout1', nn.Dropout(p=0.5)),  # plain Dropout: the input here is already flattened

                    ('fc2', nn.Linear(in_features=128,
                                      out_features=64)),

                    ('dropout2', nn.Dropout(p=0.6)),

                    ('fc3', nn.Linear(in_features=64, out_features=10))
                ]
            )

        )

    def forward(self, x):
        out = self.feature(x)
        out = out.view(-1, 64 * 3 * 3)
        out = self.classifier(out)
        return out

  • Training
    train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils


# https://blog.csdn.net/weixin_41278720/article/details/80778640

# --------------------------- Dataset -------------------------------------
data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches the data automatically if it is not already in data_dir
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=True, transform=transform, download=True)
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, num_workers=4, shuffle=False)

# Show one random batch
plt.figure()
utils.imshow_batch(next(iter(train_dataloader)))
plt.show()

# ------------------------- Network definition and hyperparameters --------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = net.Net()
print(net)
net = net.to(device)

loss_fc = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# ----------------------------- Training -----------------------------------------
file_running_loss = open('./log/running_loss.txt', 'w')
file_test_accuracy = open('./log/test_accuracy.txt', 'w')

epoch_num = 100
for epoch in range(epoch_num):
    running_loss = 0.0
    accuracy = 0.0
    for i, sample_batch in enumerate(train_dataloader):

        inputs = sample_batch[0]
        labels = sample_batch[1]

        inputs = inputs.to(device)
        labels = labels.to(device)

        net.train()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fc(outputs, labels)
        loss.backward()
        optimizer.step()

        print(i, loss.item())

        # Track statistics: running loss and test accuracy
        running_loss += loss.item()
        if i % 20 == 19:
            correct = 0
            total = 0
            net.eval()
            with torch.no_grad():
                for val_inputs, val_labels in val_dataloader:
                    # move the validation batch to the same device as the model
                    val_inputs = val_inputs.to(device)
                    val_labels = val_labels.to(device)
                    outputs = net(val_inputs)
                    _, prediction = torch.max(outputs, 1)
                    correct += ((prediction == val_labels).sum()).item()
                    total += val_labels.size(0)

            accuracy = correct / total
            print('[{},{}] running loss = {:.5f} acc = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20, accuracy))
            file_running_loss.write(str(running_loss / 20) + '\n')
            file_test_accuracy.write(str(accuracy) + '\n')
            running_loss = 0.0

    # In PyTorch >= 1.1 the LR scheduler should be stepped after the optimizer,
    # i.e. once at the end of each epoch.
    scheduler.step()

print('\nTraining finished')
file_running_loss.close()
file_test_accuracy.close()
torch.save(net.state_dict(), './model/model_100_epoch.pth')


  • Visualization utilities
    utils.py
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np


def imshow_batch(sample_batch):
    """Show a batch of images in a grid, with the labels as the figure title."""
    images = sample_batch[0]
    labels = sample_batch[1]
    images = make_grid(images, nrow=4, pad_value=255)
    # convert from (C, H, W) to (H, W, C) for matplotlib
    images_transformed = np.transpose(images.numpy(), (1, 2, 0))
    plt.imshow(images_transformed)
    plt.axis('off')
    labels = labels.numpy()
    plt.title(labels)


  • Testing
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import net
import utils


data_dir = '/media/weipenghui/Extra/FashionMNIST'
transform = transforms.Compose([transforms.ToTensor()])


test_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=transform, download=True)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=4, num_workers=4, shuffle=False)

plt.figure()
utils.imshow_batch(next(iter(test_dataloader)))

net = net.Net()
net.load_state_dict(torch.load(f='./model/model_100_epoch.pth', map_location='cpu'))
net.eval()  # switch BatchNorm/Dropout to inference mode
print(net)

images, labels = next(iter(test_dataloader))
with torch.no_grad():
    outputs = net(images)
_, prediction = torch.max(outputs, 1)
print('label:', labels)
print('prediction:', prediction)

plt.show()



End
