# GPU setup: if any GPU is present, restrict TensorFlow to the first one and
# let its memory allocation grow on demand instead of reserving it all upfront.
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    # With several GPUs, only GPU 0 is used.
    gpus0 = gpus[0]
    # Must run before the GPU is initialized (i.e. before any op executes).
    tf.config.experimental.set_memory_growth(gpus0, True)
    # device_type must be "GPU" (the type reported by list_physical_devices);
    # an unknown type such as "GPUS" makes set_visible_devices raise ValueError.
    tf.config.set_visible_devices([gpus0], "GPU")
from tensorflow.keras import (
    datasets,
    layers,
    models,
)
import matplotlib.pyplot as plt

# Load the CIFAR-10 train and test splits as (images, labels) pairs.
(train_images, train_labels), (test_images, test_labels) = (
    datasets.cifar10.load_data()
)
由于像素点最大值为255,最小值为0,这里直接让它们除以255.0,即可将像素值归一化到[0,1]区间
# Pixel normalization: divide by 255.0 so pixel values fall in [0, 1].
train_images = train_images / 255.0
test_images = test_images / 255.0

# Notebook cell output: shapes of the image and label arrays for both splits.
(
    train_images.shape,
    test_images.shape,
    train_labels.shape,
    test_labels.shape,
)
# CIFAR-10 class names; a label value k maps to class_name[k].
class_name = ['airplane', 'automobile', 'bird', 'cat', 'deer',
              'dog', 'frog', 'horse', 'ship', 'truck']

# Preview the first 20 training images with their class names in a 2 x 10 grid.
# (The previous 5 x 10 grid left three of its five rows empty.)
plt.figure(figsize=(20, 5))
for i in range(20):
    plt.subplot(2, 10, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # The images are RGB, and imshow ignores a colormap for RGB data,
    # so no cmap is passed.
    plt.imshow(train_images[i])
    # Labels are indexed as label[i][0]: each entry holds a single class index.
    plt.xlabel(class_name[train_labels[i][0]])
# Render and close this figure so later plotting calls start on a fresh one.
plt.show()
# Build the model: three 3x3 conv layers with max pooling after the first two,
# followed by a small fully connected classifier head.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    # No softmax: this layer outputs raw logits, matching from_logits=True below.
    layers.Dense(10),
])

# Sparse categorical cross-entropy takes integer class indices as labels
# (no one-hot encoding needed).
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# NOTE(review): the test split doubles as validation data here, so
# `val_accuracy` is really test accuracy and is watched while training.
# Consider `validation_split=` on the training data for an unbiased test score.
history = model.fit(
    train_images,
    train_labels,
    epochs=20,
    validation_data=(test_images, test_labels),
)
Epoch 1/20
1563/1563 [==============================] - 29s 17ms/step - loss: 1.5689 - accuracy: 0.4253 - val_loss: 1.2995 - val_accuracy: 0.5338
Epoch 2/20
1563/1563 [==============================] - 28s 18ms/step - loss: 1.2116 - accuracy: 0.5707 - val_loss: 1.1985 - val_accuracy: 0.5798
Epoch 3/20
1563/1563 [==============================] - 27s 17ms/step - loss: 1.0499 - accuracy: 0.6286 - val_loss: 1.0068 - val_accuracy: 0.6500
Epoch 4/20
1563/1563 [==============================] - 26s 17ms/step - loss: 0.9501 - accuracy: 0.6669 - val_loss: 0.9589 - val_accuracy: 0.6664
Epoch 5/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.8733 - accuracy: 0.6961 - val_loss: 0.9424 - val_accuracy: 0.6728
Epoch 6/20
1563/1563 [==============================] - 26s 17ms/step - loss: 0.8169 - accuracy: 0.7130 - val_loss: 0.9840 - val_accuracy: 0.6608
Epoch 7/20
1563/1563 [==============================] - 27s 18ms/step - loss: 0.7566 - accuracy: 0.7348 - val_loss: 0.8962 - val_accuracy: 0.6925
Epoch 8/20
1563/1563 [==============================] - 28s 18ms/step - loss: 0.7127 - accuracy: 0.7510 - val_loss: 0.9304 - val_accuracy: 0.6812
Epoch 9/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.6721 - accuracy: 0.7640 - val_loss: 0.8733 - val_accuracy: 0.7022
Epoch 10/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.6339 - accuracy: 0.7771 - val_loss: 0.9032 - val_accuracy: 0.6993
Epoch 11/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.6008 - accuracy: 0.7880 - val_loss: 0.9038 - val_accuracy: 0.7067
Epoch 12/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.5655 - accuracy: 0.7982 - val_loss: 0.8896 - val_accuracy: 0.7067
Epoch 13/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.5337 - accuracy: 0.8115 - val_loss: 0.9292 - val_accuracy: 0.6994
Epoch 14/20
1563/1563 [==============================] - 26s 17ms/step - loss: 0.5067 - accuracy: 0.8173 - val_loss: 0.9405 - val_accuracy: 0.7077
Epoch 15/20
1563/1563 [==============================] - 28s 18ms/step - loss: 0.4762 - accuracy: 0.8319 - val_loss: 0.9913 - val_accuracy: 0.6900
Epoch 16/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.4546 - accuracy: 0.8382 - val_loss: 1.0190 - val_accuracy: 0.7094
Epoch 17/20
1563/1563 [==============================] - 26s 17ms/step - loss: 0.4279 - accuracy: 0.8476 - val_loss: 1.0170 - val_accuracy: 0.6988
Epoch 18/20
1563/1563 [==============================] - 27s 17ms/step - loss: 0.4046 - accuracy: 0.8553 - val_loss: 1.0816 - val_accuracy: 0.6956
Epoch 19/20
1563/1563 [==============================] - 26s 17ms/step - loss: 0.3849 - accuracy: 0.8621 - val_loss: 1.1430 - val_accuracy: 0.6888
Epoch 20/20
1563/1563 [==============================] - 25s 16ms/step - loss: 0.3599 - accuracy: 0.8723 - val_loss: 1.1654 - val_accuracy: 0.6961
# Show test image #1 on its own figure, titled with its ground-truth class,
# so the model's prediction for it can be checked by eye.
plt.figure()
plt.imshow(test_images[1])
plt.title(class_name[test_labels[1][0]])
# Render and close the figure so later plots (e.g. the accuracy curves) are not
# drawn onto this image's axes when run as a script.
plt.show()
模型预测:
import numpy as np

# The final Dense(10) layer has no activation, so each row holds 10 logits.
pre = model.predict(test_images)

# The largest logit picks the predicted class for test image #1.
print(class_name[pre[1].argmax()])
模型评估:
# Model evaluation: plot per-epoch training vs. validation accuracy, then
# score the trained model on the test split.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

# Start a fresh figure so the curves are not drawn onto a previous image.
plt.figure()
plt.plot(acc, label='accuracy')
plt.plot(val_acc, label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Keep the 0.5 floor, but lower it when any recorded accuracy is smaller.
# Otherwise early epochs get clipped off the plot (a fixed [0.5, 1] range hides
# a first-epoch accuracy below 0.5).
plt.ylim([min([0.5] + acc + val_acc), 1])
plt.legend(loc="lower right")
plt.show()

# Because the test split was also passed as validation_data to fit(), this
# reproduces the final epoch's val_loss / val_accuracy.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(test_acc)
通过本次彩色图片识别实验可以看到,我们搭建的CNN网络过于简单,准确率不高(测试集约70%),有待提升;而且训练后期训练准确率持续上升、验证损失却不断增大,说明模型出现了过拟合。建议大家尝试搭建更加复杂的网络,或使用其他网络结构进行实验,从中发现新的idea。