【Keras】2. LeNet-5 实现

LeNet-5 结构
【keras】2.Lenet-5实现_第1张图片
【keras】2.Lenet-5实现_第2张图片
LeNet-5 代码

# -*- coding: utf-8 -*-
"""
Created on Tue Mar  6 19:45:01 2018

@author: yuyangyg

Train a LeNet-5 style convolutional network on MNIST with Keras and save the
trained model to ``myletnet.h5``.
"""

import keras
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, Flatten, MaxPool2D
from keras.models import Sequential

NUM_CLASSES = 10  # MNIST digits 0-9

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# One-hot encode the integer labels, as required by categorical_crossentropy.
y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)

# Scale pixels to [0, 1]. Cast to float32 first: plain `/ 255.` on the raw
# integer images would produce float64 arrays, doubling memory for no benefit.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Add a trailing channel axis. The model below declares input_shape=(28, 28, 1),
# i.e. channels-last: [batch, height, width, channel] (1 channel = grayscale).
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# LeNet-5 style network. Each conv layer and hidden dense layer gets a
# non-linearity; without one, the stacked layers reduce to (almost) a single
# linear map and the network loses most of its representational power.
lenet = Sequential()
# 6 feature maps, 3x3 kernel, 'same' padding: 28x28x1 -> 28x28x6
lenet.add(Conv2D(6, kernel_size=3, strides=1, padding='same',
                 activation='relu', input_shape=(28, 28, 1)))
# 2x2 pooling: -> 14x14x6
lenet.add(MaxPool2D(pool_size=2, strides=2))
# 16 feature maps, 5x5 kernel, no padding: -> 10x10x16
lenet.add(Conv2D(16, kernel_size=5, strides=1, padding='valid',
                 activation='relu'))
# 2x2 pooling: -> 5x5x16
lenet.add(MaxPool2D(pool_size=2, strides=2))
# Flatten the 5x5x16 feature maps into a 400-element vector.
lenet.add(Flatten())
lenet.add(Dense(120, activation='relu'))
lenet.add(Dense(84, activation='relu'))
lenet.add(Dense(NUM_CLASSES, activation='softmax'))

lenet.summary()

# Optional: keras.utils.plot_model(lenet, to_file='lenet.png', show_shapes=True)
# renders the architecture diagram (needs pydot and graphviz installed).

lenet.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['accuracy'])

# NOTE: the test split doubles as validation data here, so the reported
# validation metrics are not an independent held-out estimate.
lenet.fit(x_train, y_train, batch_size=64, epochs=50,
          validation_data=(x_test, y_test))

# Save the trained model. NOTE(review): 'myletnet' looks like a typo for
# 'mylenet', but the path is kept unchanged in case other code loads it.
lenet.save('myletnet.h5')

参考资料：
- https://www.jianshu.com/p/7a0a3eefeea4
- https://github.com/SherlockLiao/lenet/blob/master/Lenet.ipynb

你可能感兴趣的:(深度学习)