import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
def load_data(path="data/mnist.npz"):
    """Load the MNIST train/test split from a local ``.npz`` archive.

    Args:
        path: Location of the archive. It must contain the arrays
            ``x_train``, ``y_train``, ``x_test`` and ``y_test``. The default
            keeps the previous hard-coded location.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``, the same nesting as
        ``keras.datasets.mnist.load_data()``.
        NOTE(review): presumably the x arrays are (N, 28, 28) images, since
        callers reshape them to (N, 28, 28, 1). Confirm against the file.
    """
    # The context manager closes the archive even if a key is missing;
    # indexing an NpzFile reads the array into memory, so the returned
    # arrays stay valid after the file is closed.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
def train_y(y, num_classes=10):
    """Return a one-hot float vector of length ``num_classes`` for label ``y``.

    Args:
        y: Integer class label in ``[0, num_classes)``.
        num_classes: Length of the encoded vector (10 for MNIST digits).

    Returns:
        A float64 ``np.ndarray`` of zeros with a 1 at index ``y``.

    Raises:
        IndexError: If ``y`` is outside ``[0, num_classes)``. Without this
            check, a negative label would silently wrap around and mark the
            last class.
    """
    labels = np.asarray(y)
    if np.any(labels < 0) or np.any(labels >= num_classes):
        raise IndexError(
            "label %s out of range for %d classes" % (y, num_classes))
    y_ohe = np.zeros(num_classes)
    y_ohe[y] = 1
    return y_ohe
# Read in the data.
(X_train, y_train), (X_test, y_test) = load_data()
# Sanity check: print the shape of one raw image and its label.
print(X_train[0].shape)
print(y_train[0])

# Add a trailing channel axis to match the first Conv2D's
# input_shape=(28, 28, 1), and cast to float32 so the in-place division
# below yields fractional values.
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')
# Scale pixels into [0, 1] (assumes raw 0-255 intensities; TODO confirm).
X_train /= 255
X_test /= 255
# One-hot encode the labels: row i of the 10x10 identity matrix is the
# float64 one-hot vector for class i. Fancy-indexing it with the label array
# builds the whole (n, 10) target matrix in one vectorised step instead of a
# per-sample Python loop.
y_train_ohe = np.eye(10)[y_train]
y_test_ohe = np.eye(10)[y_test]
# Build the convolutional neural network.
model = Sequential()
# Three conv -> max-pool -> dropout stages with 64, 128 and 256 filters.
# 'same' padding keeps the spatial size through each Conv2D, and each 2x2
# pool halves it (28 -> 14 -> 7 -> 3).
# input_shape is only needed on the first layer of a Sequential model.
model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same',
                 input_shape=(28, 28, 1), activation='relu'))  # convolution layer
model.add(MaxPooling2D(pool_size=(2, 2)))  # pooling layer
model.add(Dropout(0.5))  # dropout layer
model.add(Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding='same',
                 activation='relu'))  # convolution layer
model.add(MaxPooling2D(pool_size=(2, 2)))  # pooling layer
model.add(Dropout(0.5))  # dropout layer
model.add(Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding='same',
                 activation='relu'))  # convolution layer
model.add(MaxPooling2D(pool_size=(2, 2)))  # pooling layer
model.add(Dropout(0.5))  # dropout layer
model.add(Flatten())  # flatten the feature maps into a vector

# Fully connected classifier head; the 10-way softmax outputs one
# probability per digit class.
model.add(Dense(128, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(10, activation='softmax'))
# Define the loss: categorical cross-entropy is the usual choice for
# multi-class classification with one-hot targets and a softmax output.
# NOTE(review): the default Adagrad learning rate differs between Keras
# versions and may be small; confirm the model converges in 20 epochs.
model.compile(loss='categorical_crossentropy', optimizer='adagrad', metrics=['accuracy'])

# Train in mini-batches of 128 samples.
# NOTE(review): the test set doubles as validation data. It is only
# monitored here (no callbacks), but the score reported below is then not
# a fully held-out estimate; consider validation_split instead.
model.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe),
          epochs=20, batch_size=128)

# Evaluate accuracy on the test set. With metrics=['accuracy'], evaluate()
# returns [loss, accuracy].
scores = model.evaluate(X_test, y_test_ohe, verbose=0)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])