【Keras】- LeNet-5

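Contents

  • 【Keras】- LeNet-5
    • 1. Difficulties
      • 1. MNIST cannot be downloaded
      • 2. Selecting a GPU for Keras
    • 2. Callbacks during training
      • 1. Early stopping
      • 2. Logging training
      • 3. Persisting the trained model
      • 4. Reducing the learning rate
    • 3. Printing the network summary
    • Code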

1. Difficulties

1. MNIST cannot be downloaded

mnist.npz

>>> find /home -name keras
/home/ggp/anaconda3/lib/python3.6/site-packages/keras

'''
Open mnist.py in the Keras datasets directory.
The default download URL inside it is:
'''
path = get_file(path, origin='https://s3.amazonaws.com/img-datasets/mnist.npz')

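If that URL is unreachable, download mnist.npz by other means and write your own mnist.py whose load_data function reads the local file: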

import numpy as np

def load_data(path='./mnist.npz'):
    # Read the four arrays from a local copy of mnist.npz
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)
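
To check the loader without the real dataset, a small stand-in archive with the same four keys can be written with np.savez and read back. This is only a sketch: the array sizes and the temporary path are made up for illustration, while the real mnist.npz holds 60000 training and 10000 test images of 28x28 pixels.

```python
import os
import tempfile

import numpy as np


def load_data(path='./mnist.npz'):
    # Same loader as above: read the four arrays from a local .npz archive
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)


# Hypothetical stand-in archive: 6 training and 2 test images (assumed sizes)
tmp_dir = tempfile.mkdtemp()
fake_path = os.path.join(tmp_dir, 'mnist.npz')
np.savez(
    fake_path,
    x_train=np.zeros((6, 28, 28), dtype=np.uint8),
    y_train=np.arange(6, dtype=np.uint8),
    x_test=np.zeros((2, 28, 28), dtype=np.uint8),
    y_test=np.arange(2, dtype=np.uint8),
)

(x_train, y_train), (x_test, y_test) = load_data(fake_path)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
```

If the shapes printed for the real file are not (60000, 28, 28) and (10000, 28, 28), the download is probably incomplete.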

2. Selecting a GPU for Keras

CUDA_VISIBLE_DEVICES=1 python train.py

# List running processes and their PIDs
>>> ps -au
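
The same environment variable can also be set from inside the script instead of on the command line. A minimal sketch, assuming the machine has a second GPU with index 1; the assignment has to happen before keras (or its backend) is imported, because the CUDA runtime reads the variable when it initializes.

```python
import os

# Expose only physical GPU 1 to this process; inside the process it
# will then appear as device 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Import the framework only after the variable is set, for example:
# import keras

print(os.environ['CUDA_VISIBLE_DEVICES'])
```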

2. Callbacks during training

1. Early stopping

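keras.callbacks.EarlyStopping(monitor='val_loss', # quantity monitored during training
                              min_delta=0,        # an absolute change smaller than min_delta counts as no improvement
                              patience=0,         # number of epochs without improvement before training stops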
                              verbose=0,
                              mode='auto',
                              baseline=None,
                              restore_best_weights=False
                             )
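'''
@ mode 1. 'min': training stops when the monitored quantity stops decreasing
       2. 'max': training stops when the monitored quantity stops increasing
       3. 'auto': the direction is inferred from the name of the monitored quantity
'''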
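
How min_delta and patience interact in 'min' mode can be sketched in plain Python. This is a simplified model of the rule, not the Keras source; the function name stopping_epoch and the loss values are invented for the example.

```python
def stopping_epoch(losses, min_delta=0.0, patience=0):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            # Improved by more than min_delta: remember it, reset the counter
            best = loss
            wait = 0
        else:
            # No sufficient improvement in this epoch
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up loss curve: after epoch 1 the loss improves by less than 0.1
losses = [1.0, 0.8, 0.79, 0.78, 0.77]
print(stopping_epoch(losses, min_delta=0.1, patience=2))  # 3
```

With min_delta=0.1 the small gains after epoch 1 do not count as improvement, so two such epochs in a row end the run at epoch 3.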

2. Logging training

The CSVLogger callback writes the results of each epoch to a CSV file.
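
The resulting file can be read back with the standard csv module, for instance to find the epoch with the lowest validation loss. The sample below is made up: it assumes an epoch column followed by one column per logged metric, and the metric names (acc, val_acc and so on) depend on the Keras version and on the metrics passed to compile.

```python
import csv
import io

# Illustrative log contents only; a real run would read './log.csv' instead
sample_log = io.StringIO(
    "epoch,acc,loss,val_acc,val_loss\n"
    "0,0.91,0.30,0.95,0.15\n"
    "1,0.96,0.12,0.97,0.09\n"
    "2,0.97,0.09,0.97,0.10\n"
)

rows = list(csv.DictReader(sample_log))

# Pick the row whose validation loss is smallest
best = min(rows, key=lambda r: float(r['val_loss']))
print(best['epoch'], best['val_loss'])
```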

3. Persisting the trained model

Save the model after every epoch.

keras.callbacks.ModelCheckpoint(filepath, 
                                monitor='val_loss', 
                                verbose=0, 
                                save_best_only=False, 
                                save_weights_only=False, 
                                mode='auto', 
                                period=1)
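
filepath may contain named formatting fields, which are filled in from the epoch number and the logged metrics when a checkpoint is written. The effect of such a template can be previewed with str.format; the template and the metric values below are hypothetical.

```python
# Hypothetical template and log values, for illustration only
template = './model/LeNet-5.{epoch:02d}-{acc:.2f}.hdf5'
logs = {'acc': 0.98765, 'loss': 0.0412}

name = template.format(epoch=3, **logs)
print(name)  # ./model/LeNet-5.03-0.99.hdf5

# Without the dot, '2f' means minimum width 2 at the default precision
wide = '{acc:2f}'.format(**logs)
print(wide)  # 0.987650
```

A field name that is not among the logged metrics raises a KeyError, so the names in the template have to match what the training run actually logs.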

4. Reducing the learning rate

Reduce the learning rate when the monitored metric has stopped improving.

keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 
                                  factor=0.1,   # new learning rate = learning rate * factor
                                  patience=10, 
                                  verbose=0, 
                                  mode='auto', 
                                  min_delta=0.0001,
                                  cooldown=0,
                                  min_lr=0)
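
The schedule this produces can be approximated in plain Python. The sketch below is a simplified model of the 'min' mode rule that ignores cooldown; it is not the Keras implementation, and the function name and loss values are made up.

```python
def simulate_reduce_lr(losses, lr, factor=0.1, patience=10,
                       min_delta=1e-4, min_lr=0.0):
    """Return the learning rate in effect after each epoch."""
    best = float('inf')
    wait = 0
    history = []
    for loss in losses:
        if loss < best - min_delta:
            # Sufficient improvement: reset the plateau counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                # Plateau detected: shrink the rate, but not below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history


# Made-up loss curve that plateaus at 0.9 from epoch 1 onward
schedule = simulate_reduce_lr([1.0, 0.9, 0.9, 0.9, 0.9, 0.9],
                              lr=0.05, factor=0.1, patience=2)
print(schedule)
```

In this run the rate drops from 0.05 to about 0.005 after two flat epochs and to about 0.0005 after two more.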

3. Printing the network summary

model.summary() prints each layer with its output shape and parameter count.

Code


import mnist

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Conv2D, MaxPooling2D, Flatten
from keras.optimizers import SGD
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, CSVLogger,ModelCheckpoint
from keras.utils import np_utils

patience = 10
log_file_path = './log.csv'
trained_models_path = './model/LeNet-5'

(x_train, y_train),(x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0],28,28,1) / 255.0
x_test = x_test.reshape(x_test.shape[0],28,28,1) / 255.0
# Convert the labels to one-hot vectors
y_train = np_utils.to_categorical(y_train, num_classes= 10)
y_test = np_utils.to_categorical(y_test, num_classes=10)

# Callbacks
early_stop = EarlyStopping(monitor='loss', min_delta=0.1, patience=patience)
reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=int(patience/2), verbose=1)
csv_logger = CSVLogger(log_file_path, append=False)
# Newer Keras releases log this metric as 'accuracy'; adjust the field name if so
models_names = trained_models_path + '.{epoch:02d}-{acc:.2f}.hdf5'

model_checkpoint = ModelCheckpoint(models_names, monitor='loss', verbose=1, save_best_only=True, save_weights_only=False)
callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]

# LeNet-5 ---------------------------------------------------------------------
# Sequential model (a linear stack of layers)
model = Sequential()
model.add(Conv2D(filters=6, kernel_size=(5,5), padding='valid',
                 input_shape=(28,28,1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(filters=16,kernel_size=(5,5),padding='valid',activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(120,activation='relu'))
model.add(Dense(84,activation='relu'))
model.add(Dense(10,activation='softmax'))
'''------------------------------------------------------------------------'''
sgd = SGD(lr=0.05,decay=1e-6,momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy',metrics=['accuracy'])
model.summary()

model.fit(x_train, y_train, batch_size=128, epochs=100, validation_data=(x_test, y_test),
          callbacks=callbacks, verbose=1, shuffle=True)
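
The output shapes and parameter counts that model.summary() reports for this network can be checked by hand. The sketch below redoes the arithmetic in plain Python, assuming 'valid' padding with stride 1 for the convolutions and non-overlapping 2x2 pooling, as in the model above; the helper names are made up.

```python
def conv_out(size, kernel):
    # 'valid' convolution with stride 1 shrinks each side by kernel - 1
    return size - kernel + 1


def conv_params(kernel, in_channels, filters):
    # One kernel x kernel x in_channels weight block plus one bias per filter
    return (kernel * kernel * in_channels + 1) * filters


def dense_params(n_in, n_out):
    # One weight per input plus one bias, for every output unit
    return (n_in + 1) * n_out


side = conv_out(28, 5)            # 28 -> 24
p_conv1 = conv_params(5, 1, 6)    # 156
side //= 2                        # 2x2 max pooling: 24 -> 12
side = conv_out(side, 5)          # 12 -> 8
p_conv2 = conv_params(5, 6, 16)   # 2416
side //= 2                        # 2x2 max pooling: 8 -> 4
flat = side * side * 16           # Flatten: 4 * 4 * 16 = 256

p_fc1 = dense_params(flat, 120)   # 30840
p_fc2 = dense_params(120, 84)     # 10164
p_out = dense_params(84, 10)      # 850

total = p_conv1 + p_conv2 + p_fc1 + p_fc2 + p_out
print(flat, total)  # 256 44426
```

Because the input here is 28x28 rather than the 32x32 of the original LeNet-5, the flattened feature vector has 256 values instead of 400.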

