图像分类算法篇--基于LeNet的手写体数字识别

1、加载数据

from tensorflow import keras

# 加载数据

(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]


# 数据归一化
X_mean = X_train.mean(axis=0, keepdims=True)
X_std = X_train.std(axis=0, keepdims=True) + 1e-7
X_train = (X_train - X_mean) / X_std
X_valid = (X_valid - X_mean) / X_std
X_test = (X_test - X_mean) / X_std


构建LeNet-5

model=keras.models.Sequential([
    keras.layers.Conv2D(filters=6,kernel_size=[5,5],strides=1,padding="VALID",input_shape=(32,32,1),activation="tanh"),
    keras.layers.MaxPooling2D(pool_size=2),
    keras.layers.Conv2D(filters=16,kernel_size=[5,5],strides=1,padding="VALID",activation="tanh"),
    keras.layers.MaxPooling2D(pool_size=2),
    keras.layers.Conv2D(filters=120,kernel_size=[5,5],strides=1,padding="VALID",activation="tanh"),
    keras.layers.Flatten(),
    keras.layers.Dense(84,activation="tanh"),
    keras.layers.Dense(10,activation="softmax")
])

查看模型结构

from keras.utils.vis_utils import plot_model
plot_model(model=model,to_file="model_lenet.png",show_shapes=True)

图像分类算法篇--基于LeNet的手写体数字识别_第1张图片

 训练模型

model.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])
history = model.fit(X_train,y_train,epochs=10,validation_data=(X_valid,y_valid))

结果展示

图像分类算法篇--基于LeNet的手写体数字识别_第2张图片

你可能感兴趣的:(图像分类,计算机视觉,目标检测,tensorflow,分类,图像处理)