基于CNN的FashionMNIST分类

基于CNN的FashionMNIST分类

  • 1卷积神经网络算法简介
    • 1.1卷积层
    • 1.2池化层
    • 1.3全连接层
  • 2实验设置及结果分析
    • 2.1环境配置
    • 2.2数据集
    • 2.3模型搭建
    • 2.4模型训练及测试
    • 2.5精度曲线和损失曲线
    • 2.6精确率和召回率
    • 2.7混淆矩阵
  • 3总结与展望

1卷积神经网络算法简介

卷积神经网络是一种多层神经网络,擅长处理图像特别是大图像的相关机器学习问题。
卷积网络通过一系列方法,成功将数据量庞大的图像识别问题不断降维,最终使其能够被训练。
CNN最早由Yann LeCun提出并应用在手写字体识别上(MINST)。

CNN是一种人工神经网络,CNN的结构可以分为3层:

  1. 卷积层(Convolutional Layer) - 主要作用是提取特征。
  2. 池化层(Max Pooling Layer) - 主要作用是下采样(downsampling),却不会损坏识别结果。
  3. 全连接层(Fully Connected Layer) - 主要作用是分类。

1.1卷积层

卷积应用在图像上,可以理解为拿一个滤镜放在图像上,找出图像中的某些特征,而我们需要找到很多特征才能区分某一物体,所以我们会有很多滤镜,通过这些滤镜的组合,我们可以得出很多的特征。

例如我们用一组过滤器(Filter)来对图片过滤,过滤的过程就是求卷积的过程。假设我们的Filter的大小为3 * 3,我们从图片的左上角开始移动Filter,并且把每次矩阵相乘的结果记录下来。可以通过下面的过程来演示。
基于CNN的FashionMNIST分类_第1张图片

1.2池化层

池化(pool)即下采样(downsamples),目的是为了减少特征图。池化操作对每个深度切片独立,规模一般为 2*2,相对于卷积层进行卷积运算,池化层进行的运算一般采用最大池化(Max Pooling)即取4个点的最大值。

池化层的数据丢失并不会产生较大影响,因为我们每次保留的输出都是局部最显著的一个输出,而池化之后,最显著的特征并没丢失。我们只保留了认为最显著的特征,而把其他无用的信息丢掉,来减少运算。

基于CNN的FashionMNIST分类_第2张图片

1.3全连接层

全连接层的作用主要是进行分类。前面通过卷积和池化层得出的特征,在全连接层对这些总结好的特征做分类。
全连接层就是一个完全连接的神经网络,根据权重每个神经元反馈的比重不一样,最后通过调整权重和网络得到分类的结果。

2实验设置及结果分析

2.1环境配置

系统:window 10
GPU:Nvidia 1060 6G
Python:3.8
tensorflow:2.4.1
keras:2.4.3

2.2数据集

Fashion-MNIST 是 Zalando 文章图像的数据集,包括一组 60000 个示例和一组 10000 个示例的测试集。
每个示例都是 28x28 灰度图像,与 10 个类的标签相关联。

@article{DBLP:journals/corr/abs-1708-07747,
  author    = {Han Xiao and
               Kashif Rasul and
               Roland Vollgraf},
  title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
               Algorithms},
  journal   = {CoRR},
  volume    = {abs/1708.07747},
  year      = {2017},
  url       = {http://arxiv.org/abs/1708.07747},
  archivePrefix = {arXiv},
  eprint    = {1708.07747},
  timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
  biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1708-07747},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
from keras import datasets
from keras.utils import to_categorical
(train_images, train_labels), (test_images,
                               test_labels) = datasets.fashion_mnist.load_data()
train_images = train_images.reshape([60000, 28, 28, 1]) / 255.0
test_images = test_images.reshape([10000, 28, 28, 1]) / 255.0
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

2.3模型搭建

from keras import models
from keras import layers
from keras.utils import plot_model
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________
plot_model(model,show_shapes=True)

基于CNN的FashionMNIST分类_第3张图片

2.4模型训练及测试

model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=10,
                    batch_size=64, validation_data=(test_images, test_labels))
Epoch 1/10
938/938 [==============================] - 11s 9ms/step - loss: 0.7610 - accuracy: 0.7141 - val_loss: 0.3790 - val_accuracy: 0.8670
Epoch 2/10
938/938 [==============================] - 7s 8ms/step - loss: 0.3433 - accuracy: 0.8749 - val_loss: 0.3435 - val_accuracy: 0.8757
Epoch 3/10
938/938 [==============================] - 7s 8ms/step - loss: 0.2791 - accuracy: 0.8974 - val_loss: 0.2928 - val_accuracy: 0.8940
Epoch 4/10
938/938 [==============================] - 7s 8ms/step - loss: 0.2476 - accuracy: 0.9088 - val_loss: 0.2804 - val_accuracy: 0.8978
Epoch 5/10
938/938 [==============================] - 7s 8ms/step - loss: 0.2227 - accuracy: 0.9186 - val_loss: 0.3001 - val_accuracy: 0.8886
Epoch 6/10
938/938 [==============================] - 7s 8ms/step - loss: 0.2041 - accuracy: 0.9265 - val_loss: 0.2849 - val_accuracy: 0.8970
Epoch 7/10
938/938 [==============================] - 7s 8ms/step - loss: 0.1847 - accuracy: 0.9307 - val_loss: 0.2833 - val_accuracy: 0.9053
Epoch 8/10
938/938 [==============================] - 7s 8ms/step - loss: 0.1688 - accuracy: 0.9382 - val_loss: 0.2675 - val_accuracy: 0.9071
Epoch 9/10
938/938 [==============================] - 7s 8ms/step - loss: 0.1565 - accuracy: 0.9433 - val_loss: 0.3248 - val_accuracy: 0.8952
Epoch 10/10
938/938 [==============================] - 7s 8ms/step - loss: 0.1457 - accuracy: 0.9462 - val_loss: 0.2828 - val_accuracy: 0.9107

准确率为91%左右

2.5精度曲线和损失曲线

import matplotlib.pyplot as plt
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc)+1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

基于CNN的FashionMNIST分类_第4张图片

基于CNN的FashionMNIST分类_第5张图片

2.6精确率和召回率

精确率是针对我们预测结果而言的,它表示的是预测为正的样本中有多少是真正的正样本。
计算公式如下:

在这里插入图片描述

召回率是针对我们原来的样本而言的,它表示的是样本中的正例有多少被预测正确了。
计算公式如下:

在这里插入图片描述

y_predict = model.predict(test_images).argmax(axis=1)
y_true = test_labels.argmax(axis=1)
from sklearn.metrics import precision_score
print("Precision:",precision_score(y_true, y_predict, average='micro'))
Precision: 0.9107
from sklearn.metrics import recall_score
print("Recall:",recall_score(y_true, y_predict, average='micro'))
Recall: 0.9107

2.7混淆矩阵

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(8, 8))
matrix = confusion_matrix(y_true, y_predict)
print("confusion_matrix\n",matrix)
sns.heatmap(matrix, annot=True, cmap="Blues", fmt='g')
confusion_matrix
 [[889   0  24  19   4   1  57   0   6   0]
 [  0 984   1  11   1   0   1   0   2   0]
 [ 16   1 863   7  60   0  53   0   0   0]
 [ 17  14   9 905  28   0  26   0   1   0]
 [  0   0  39  13 891   0  55   0   2   0]
 [  0   0   0   0   0 988   0   7   0   5]
 [152   2  57  22  74   0 682   0  11   0]
 [  0   0   0   0   0  15   0 973   0  12]
 [  5   1   4   3   4   2   2   1 978   0]
 [  0   0   0   0   0   5   1  40   0 954]]    

基于CNN的FashionMNIST分类_第6张图片

3总结与展望

本次实验针对FashinMNIST数据集,采用卷积神经网络进行测试,得到最终准确率为91%左右,实验效果较好。
主要原因在于CNN对于图像特征的提取非常有效,每个filter都可以被看作是特征标识符,这里的特征指的是直线边缘、曲线、颜色等,它们是每个图像的都具备的最简单的特征。通过卷积计算,激活反应强的表明图像中有存载着filterd标识的特征,这样就对图像的特征有过滤的作用。
在本次实验基础上,可以引入图像增强技术,对训练集进行一些处理,可能可以进一步提高分类的准确度。

你可能感兴趣的:(深度学习,卷积,深度学习,计算机视觉,神经网络,tensorflow)