Image classification with TensorFlow
Dataset
https://www.kaggle.com/datasets/gpiosenka/coffee-bean-dataset-resized-224-x-224
1. Download and unzip the dataset into a folder named coffee-bean, then load the data and preview a few images
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
data_dir = pathlib.Path("coffee-bean")

# Each class has its own subfolder under train/. The folder names below assume the
# dataset's roast-level labels (Dark, Green, Light, Medium); adjust them if your
# extracted directories are named differently.
Dark = list(data_dir.glob("train/Dark/*"))
PIL.Image.open(str(Dark[0]))

Light = list(data_dir.glob("train/Light/*"))
PIL.Image.open(str(Light[0]))
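A single image per class does not show much; as a small sketch (reusing the Dark list and the matplotlib import above), a grid of a few samples gives a quicker feel for the data:
# Preview the first few images of one class in a grid as a quick sanity check.
plt.figure(figsize=(8, 3))
for i, img_path in enumerate(Dark[:4]):
    plt.subplot(1, 4, i + 1)
    plt.imshow(PIL.Image.open(str(img_path)))
    plt.title(img_path.parent.name)
    plt.axis("off")
plt.show()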
2. Define the image width and height, then build the training and validation sets with keras.utils.image_dataset_from_directory() (the dataset's test folder is used here as the validation split)
batch_size = 32
img_width = 180
img_height = 180
train_data_dir = pathlib.Path("coffee-bean/train")
val_data_dir = pathlib.Path("coffee-bean/test")
train_data = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
val_data = tf.keras.utils.image_dataset_from_directory(
    val_data_dir,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
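Before going further, it is worth checking what the loaders actually return: each batch should hold 32 images of shape (180, 180, 3) plus 32 integer labels. A quick check:
# Inspect one batch: images are (batch, height, width, channels), labels are integers.
for image_batch, labels_batch in train_data.take(1):
    print(image_batch.shape)   # (32, 180, 180, 3)
    print(labels_batch.shape)  # (32,)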
3. Inspect the dataset's class names and set up normalization
class_names = train_data.class_names
print(class_names)

# Normalize pixel values from [0, 255] to [0, 1]. Note that the model below also
# applies the same Rescaling as its first layer, so this standalone layer is shown
# mainly for illustration.
normalization_layer = layers.Rescaling(1./255)
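Following the linked TensorFlow tutorial, one way to see what the rescaling layer does is to map it over the training set and print the pixel range of a single batch:
# Apply the rescaling layer with Dataset.map and verify pixels now lie in [0, 1].
normalized_ds = train_data.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
print(np.min(image_batch[0]), np.max(image_batch[0]))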
4. Build the model, whose final layer has one output per class, then compile and train it
num_classes = len(class_names)

model = Sequential([
    layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, 3, padding="same", activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding="same", activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding="same", activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)  # raw logits, one per class
])
model.summary()
# Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
epochs = 5
history = model.fit(
    train_data,
    validation_data=val_data,
    epochs=epochs
)
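Before plotting the curves, model.evaluate gives a single final loss/accuracy figure on the held-out split (a small optional check):
# Final loss and accuracy on the split used for validation.
test_loss, test_acc = model.evaluate(val_data)
print(f"accuracy: {test_acc:.3f}, loss: {test_loss:.3f}")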
5. Visualize the training and validation accuracy and loss
# Plot accuracy and loss for the training and validation sets
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
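Because the last Dense layer outputs raw logits, predictions on a new image need a softmax to become class probabilities. A minimal sketch, assuming a placeholder image path under the test folder (replace it with any real file):
# Classify a single new image; sample_path is a placeholder, point it at a real file.
sample_path = "coffee-bean/test/Dark/sample.png"
img = tf.keras.utils.load_img(sample_path, target_size=(img_height, img_width))
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # add a batch dimension

predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])  # convert logits to probabilities
print("Predicted class: {} ({:.2f}% confidence)".format(
    class_names[np.argmax(score)], 100 * np.max(score)))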
Reference: https://tensorflow.google.cn/tutorials/images/classification