# 导入相应包
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
plt.rcParams["font.sans-serif"] = "Microsoft YaHei"
plt.rcParams["axes.unicode_minus"] = False
#从 TensorFlow中导入和加载Fashion MNIST数据
fashion_mnist = keras.datasets.fashion_mnist.load_data()
(train_images, train_labels), (test_images, test_labels) = fashion_mnist
为了后期测试,这里利用OpenCV保存原尺寸的照片(训练集中的10张图)
# 利用OpenCV保存图片
for img in range(10):
cv.imwrite("{}.jpg".format(img), train_images[img])
# 获取当前工作路径查看
print("请复制路径自行查看:", os.getcwd())
设置中文名称便于对图像及类别标签的理解
标签 | 类别名称 | 中文名称 |
---|---|---|
0 | T-shirt/top | T恤 |
1 | Trouser | 裤子 |
2 | Pullover | 套衫 |
3 | Dress | 裙子 |
4 | Coat | 外套 |
5 | Sandal | 凉鞋 |
6 | Shirt | 汗衫 |
7 | Sneaker | 运动鞋 |
8 | Bag | 包 |
9 | Ankle boot | 短靴 |
# 设置中文名称
names = ["0-T恤", "1-裤子", "2-套衫", "3-裙子", "4-外套", "5-凉鞋", "6-汗衫", "7-运动鞋", "8-包", "9-短靴"]
for i in range(10):
plt.figure()
plt.title(names[test_labels[i]], fontsize=12)
plt.imshow(test_images[i])
plt.colorbar()
plt.show()
#训练模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10)])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_images, train_labels, batch_size=90, epochs=10)
# 准确率评估
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
#进行预测
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
predictions = probability_model.predict(test_images)
accuracy: 0.87841
# 模型保存
model.save('fashion.h5')
# 模型调用
new_model = keras.models.load_model('./fashion.h5')
# 查看网络结构
new_model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) (None, 784) 0
_________________________________________________________________
dense (Dense) (None, 128) 100480
_________________________________________________________________
dense_1 (Dense) (None, 10) 1290
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
# 从10张图中随机抽取6张进行预测
indexs = list(np.random.choice(10, 6, replace=False))
# 进行预测
plt.figure(figsize=(10, 8))
for img in range(6):
imgs = cv.imread("./{}.jpg".format(indexs[img]))
img_grays = cv.cvtColor(imgs, cv.COLOR_RGB2GRAY)
img_one = img_grays.reshape(1, 28, 28)/255.0
new_predictions = new_model.predict(img_one)
# 找出置信度最大的标签
print(names[np.argmax(new_predictions[0])])
# 查看预测图
image = plt.imread("./{}.jpg".format(indexs[img]))
predicted_label = names[np.argmax(new_predictions[0])]
plt.subplot(2,3,img+1)
plt.xticks([])
plt.yticks([])
plt.imshow(image)
plt.xlabel(names[np.argmax(new_predictions[0])], fontsize=16, color= 'c')
plt.show()
0-T恤
3-裙子
9-短靴
5-凉鞋
0-T恤
5-凉鞋
准确率需进行多次参数(批量数batch_size)调整及迭代次数(epochs)的选取 ↩︎