Note: install any package with the conda install xxx command.
Building your own development environment with Docker is the more recommended approach.
>>> import tensorflow
>>> import sonnet
>>> import torch
>>> import keras
>>> import mxnet
>>> import cntk
>>> import chainer
>>> import theano
>>> import lasagne
>>> import caffe
>>> import caffe2
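Fashion-MNIST is an image dataset intended as a replacement for the MNIST handwritten-digit dataset. It is provided by the research department of Zalando, a German fashion technology company. It contains front-view images of 70,000 different products across 10 categories. The size, format, and training/test split of Fashion-MNIST are identical to the original MNIST: 28x28 grayscale images, split into 60,000 training and 10,000 test examples. Download link
Kaggle link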
from keras.datasets import fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# First, look at the shape of the data
print(train_images.shape)
print(test_images.shape)
# Output:
(60000, 28, 28)
(10000, 28, 28)
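The training data consists of 60,000 images of 28x28 pixels, and the test data consists of 10,000 images of 28x28 pixels.
Let's look at what one of the images contains: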
import matplotlib.pyplot as plt
plt.imshow(train_images[0])
plt.savefig("train_images_0.png")
plt.show()
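The displayed result:
Next, reshape the data into the form expected for training (adding a trailing channel dimension) and scale the pixel values to the range 0-1.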
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
One-hot encode the labels:
from keras.utils import to_categorical
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)
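To make the encoding step concrete, here is a minimal NumPy sketch of what one-hot encoding produces. The helper name one_hot is hypothetical and is not part of Keras; it only mirrors the shape and content of the result of to_categorical for integer labels.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) float32 array with a single 1 per row."""
    labels = np.asarray(labels, dtype=np.int64)
    encoded = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set the column given by each label to 1, one row per example
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

example = one_hot([9, 0, 3], 10)
print(example.shape)           # (3, 10)
print(example.argmax(axis=1))  # recovers the original labels: [9 0 3]
```

Each row sums to 1, and taking argmax along axis 1 recovers the integer label, which is useful later when comparing predictions with the true classes.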
Combine all of the steps into a single function:
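from keras.datasets import fashion_mnist
from keras.utils import to_categorical

def load_data():
    # Load the raw 28x28 uint8 images and integer labels
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    # Add a channel dimension: (N, 28, 28) -> (N, 28, 28, 1)
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
    test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)
    # Scale pixel values from [0, 255] to [0, 1]
    train_images = train_images.astype('float32') / 255
    test_images = test_images.astype('float32') / 255
    # One-hot encode the labels over the 10 classes
    train_labels = to_categorical(train_labels, 10)
    test_labels = to_categorical(test_labels, 10)
    return (train_images, train_labels), (test_images, test_labels)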
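A quick sanity check on the arrays returned by load_data() can catch preprocessing mistakes early. The sketch below is only an illustration: check_split is a hypothetical helper, and it is run here on synthetic arrays so that it works without downloading the dataset.

```python
import numpy as np

def check_split(images, labels, num_classes=10):
    """Raise AssertionError if a split does not look like preprocessed output."""
    # Images should be (N, 28, 28, 1) float32 values in [0, 1]
    assert images.ndim == 4 and images.shape[1:] == (28, 28, 1), images.shape
    assert images.dtype == np.float32, images.dtype
    assert images.min() >= 0.0 and images.max() <= 1.0
    # Labels should be one-hot rows, one per image
    assert labels.shape == (images.shape[0], num_classes), labels.shape
    assert np.all(labels.sum(axis=1) == 1)
    return True

# Synthetic stand-in for one split, preprocessed the same way as in load_data()
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
fake_images = fake_images.reshape(8, 28, 28, 1).astype('float32') / 255
fake_labels = np.eye(10, dtype=np.float32)[rng.integers(0, 10, size=8)]
print(check_split(fake_images, fake_labels))
```

With the real data, the same check could be applied to each pair returned by load_data().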
The first model uses three convolution + pooling blocks followed by two fully connected layers.
The model summary is as follows:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 16)        80
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 16)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 32)        2080
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 32)          0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 64)          8256
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 3, 3, 64)          0
_________________________________________________________________
dropout_1 (Dropout)          (None, 3, 3, 64)          0
_________________________________________________________________
flatten_1 (Flatten)          (None, 576)               0
_________________________________________________________________
dense_1 (Dense)              (None, 500)               288500
_________________________________________________________________
dropout_2 (Dropout)          (None, 500)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5010
=================================================================
Total params: 303,926
Trainable params: 303,926
Non-trainable params: 0
_________________________________________________________________
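The original code for this model is not shown, so the following is only a sketch of one Keras definition that reproduces the summary above. The 2x2 kernels and "same" padding are inferred from the parameter counts and output shapes (for example, 80 = 2*2*1*16 + 16). The activations and the dropout rates are assumptions, since the summary does not record them.

```python
import keras
from keras import layers

def build_model_1(dropout_conv=0.3, dropout_dense=0.5):
    """Three conv + pooling blocks followed by two dense layers.

    dropout_conv and dropout_dense are assumed values, not taken from the source.
    """
    return keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, kernel_size=2, padding="same", activation="relu"),
        layers.MaxPooling2D(pool_size=2),   # 28x28 -> 14x14
        layers.Conv2D(32, kernel_size=2, padding="same", activation="relu"),
        layers.MaxPooling2D(pool_size=2),   # 14x14 -> 7x7
        layers.Conv2D(64, kernel_size=2, padding="same", activation="relu"),
        layers.MaxPooling2D(pool_size=2),   # 7x7 -> 3x3
        layers.Dropout(dropout_conv),
        layers.Flatten(),                   # 3 * 3 * 64 = 576 features
        layers.Dense(500, activation="relu"),
        layers.Dropout(dropout_dense),
        layers.Dense(10, activation="softmax"),
    ])

model = build_model_1()
print(model.count_params())  # 303926, matching "Total params" in the summary
```

Before training, a model like this would typically be compiled with a categorical cross-entropy loss, which fits the one-hot labels prepared earlier.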
The model architecture is shown in the figure:
The final accuracy on the test set is 87%.
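The second model is structurally simpler than the first: a single convolution + pooling block followed by two fully connected layers. Its wide first dense layer nevertheless gives it far more parameters (about 32.2 million, versus about 0.3 million for the first model).
The model summary is as follows: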
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 28, 28, 32)        320
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 6272)              0
_________________________________________________________________
dense_1 (Dense)              (None, 5128)              32167944
_________________________________________________________________
dense_2 (Dense)              (None, 10)                51290
=================================================================
Total params: 32,219,554
Trainable params: 32,219,554
Non-trainable params: 0
_________________________________________________________________
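The parameter counts in this summary can be checked by hand, which also shows where almost all of the parameters come from. The sketch below uses two small helper functions (hypothetical names) and assumes a 3x3 convolution kernel, inferred from 320 = 3*3*1*32 + 32.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Weights plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_features, units):
    """Weight matrix plus one bias per unit."""
    return in_features * units + units

conv = conv2d_params(3, 3, 1, 32)    # 320
flat = 14 * 14 * 32                  # 6272 features after pooling and flatten
dense_1 = dense_params(flat, 5128)   # 32167944
dense_2 = dense_params(5128, 10)     # 51290
total = conv + dense_1 + dense_2
print(total)  # 32219554
```

Nearly all of the parameters sit in the first dense layer, because it connects 6,272 flattened features to 5,128 units.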
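The model architecture is shown in the figure:
The final accuracy on the test set is 91.2%.
The prediction results are also displayed; images whose class name is shown in red were misclassified: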
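The plotting code behind that figure is not included, so here is a minimal sketch of how such a grid could be drawn with matplotlib. The function names, grid layout, and output file name are assumptions. The class names are the standard Fashion-MNIST labels, and the Agg backend is selected only so the sketch runs without a display. It is demonstrated on random images with one forced misclassification.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def title_colors(true_labels, pred_labels):
    """Red for a wrong prediction, black for a correct one."""
    return ["black" if t == p else "red" for t, p in zip(true_labels, pred_labels)]

def plot_predictions(images, true_labels, pred_labels, path, cols=5):
    """Draw each image with its predicted class name as the title."""
    n = len(images)
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2.2 * rows), squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")  # also hides unused cells in the last row
    colors = title_colors(true_labels, pred_labels)
    for i in range(n):
        ax = axes.ravel()[i]
        ax.imshow(images[i].reshape(28, 28), cmap="gray")
        ax.set_title(CLASS_NAMES[pred_labels[i]], color=colors[i])
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
    return path

# Synthetic demo: ten random images, with one prediction forced to be wrong
rng = np.random.default_rng(0)
demo_images = rng.random((10, 28, 28, 1), dtype=np.float32)
demo_true = np.arange(10)
demo_pred = demo_true.copy()
demo_pred[3] = 0
saved_path = plot_predictions(demo_images, demo_true, demo_pred, "predictions_demo.png")
```

With a trained model, pred_labels could be obtained as model.predict(test_images).argmax(axis=1) and true_labels as test_labels.argmax(axis=1), since the labels were one-hot encoded.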