keras速度复习-CNN

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten#二维卷积,二维池化,扁平化
from keras.optimizers import Adam

#载入数据
(x_train,y_train),(x_test,y_test)=mnist.load_data()
#(60000,28,28)->(60000,28,28,1)
x_train=x_train.reshape(-1,28,28,1)/255.0#shape0就是60000,-1自动计算28*28
x_test=x_test.reshape(-1,28,28,1)/255.0
#换one hot格式:把输出训练成10个类
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)

#定义顺序模型
model=Sequential()
#第一个卷积层
#input_shape输入平面
#filters 卷积核/滤波器个数
#kernel_size 卷积窗口大小
#strides步长
#padding padding方式 same/valid
#activation激活函数
model.add(Convolution2D(
        input_shape=(28,28,1),
        filters=32,#滤波器
        kernel_size=5,#核径
        strides=1,#步长
        padding='same',#扩充边缘以备卷积
        activation='relu'#激活函数
        ))
#第一个池化层
model.add(MaxPooling2D(
        pool_size=2,#池化尺寸
        strides=2,#步长,池化后由28X28变成了14X14
        padding='same'#扩充边缘以备卷积
        ))
#第二个卷积层
model.add(Convolution2D(64,5,strides=1,padding='same',activation='relu'))
#第二个池化层
model.add(MaxPooling2D(2,2,'same'))#变成了7X7
#把第二个池化层的输出扁平化为1维
model.add(Flatten())#64个特征图,每个图都是7X7,一共有64*7*7这么多个一维的数据
#第一个全连接层
model.add(Dense(1024,activation='relu'))#上面的3136与当前1024是全连接的
#Dropout
model.add(Dropout(0.5))#训练时50%的神经元不工作
#第二个全连接层
model.add(Dense(10,activation='softmax'))#10个分类(0~9)

#定义优化器
adam=Adam(lr=1e-4)
#优化策略,损失估计,目标质量
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])

#训练模型,每次训练32组数据,共需60000/32次训练,这叫一个周期,一共训练3个周期
model.fit(x_train,y_train,batch_size=64,epochs=1)

#评估模型
loss,accuracy=model.evaluate(x_test,y_test)
print('\ntest loss',loss)
print('accuracy',accuracy)

loss,accuracy=model.evaluate(x_train,y_train)
print('\ntest loss',loss)
print('accuracy',accuracy)
#对比一一个程序,层中加了dropout算法,最后评估了训练数据

padding:https://blog.csdn.net/baidu_36161077/article/details/81165531

same填充,保证边缘信息且图片不会越卷越小,valid不填充

你可能感兴趣的:(机器学习(MOOC笔记代码))