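In my view, Keras is one of the most convenient deep learning frameworks available right now. Keras can generally use either TensorFlow or Theano as its backend engine; I don't use Theano, so naturally I went with TensorFlow. Implementing a CIFAR-10 classification network in Keras is so easy that there isn't much worth pointing out... apart from marveling at how powerful the framework is, there's not much left to say...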
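That said, although Keras is very convenient, its building blocks all come pre-packaged, which makes things awkward when you need a specific structure the framework doesn't provide. In that case you can combine Keras with a lower-level framework such as TensorFlow or PyTorch.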
First, as usual, load the data:
import numpy as np
import keras
from keras.models import Model, save_model, load_model
from keras.layers import Input, Dense, Dropout, BatchNormalization
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D
## data
import pickle

def load_batch(path):
    # Each python-version CIFAR-10 batch is a pickled dict; keys are bytes under Python 3
    with open(path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    # b'data' is (N, 3072) uint8 stored channel-first; convert to (N, 32, 32, 3) channels-last
    X = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float32")
    Y = np.array(batch[b'labels'], dtype="int32")
    return X, Y

# five training batches of 10000 images each -> 50000 training images
train_parts = [load_batch("cifar-10-batches-py/data_batch_%d" % i) for i in range(1, 6)]
train_X = np.concatenate([X for X, _ in train_parts])
train_Y = keras.utils.to_categorical(np.concatenate([Y for _, Y in train_parts]), 10)
test_X, test_Y = load_batch("cifar-10-batches-py/test_batch")
test_Y = keras.utils.to_categorical(test_Y, 10)
# scale pixel values from [0, 255] to [0, 1]
train_X /= 255
test_X /= 255
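It may look a bit tedious, but it should at least be clear enough. Keras does ship a built-in CIFAR-10 loader (keras.datasets.cifar10.load_data()), but it always reads from a fixed download directory, so writing my own loader felt a bit nicer.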
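To make the two array operations in the loader concrete, here is a small NumPy-only sketch on synthetic data (no CIFAR files needed). It assumes the CIFAR-10 python-version row layout: each 3072-value row holds the 1024 red values, then 1024 green, then 1024 blue, each channel in row-major order. The helper one_hot is hypothetical and only mimics what keras.utils.to_categorical produces for integer labels.

```python
import numpy as np

# Synthetic "batch" of 2 images in CIFAR-10's flat layout: 3072 values per row,
# ordered as [R plane (32x32), G plane (32x32), B plane (32x32)].
n = 2
flat = np.arange(n * 3072).reshape(n, 3072)

# Same conversion as the loader: (N, 3, 32, 32) channel-first -> (N, 32, 32, 3) channels-last.
images = flat.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1)

# Pixel (row i, col j, channel c) comes from flat index c*1024 + i*32 + j.
i, j, c = 5, 7, 2
assert images[1, i, j, c] == flat[1, c * 1024 + i * 32 + j]
print(images.shape)  # (2, 32, 32, 3)

def one_hot(labels, num_classes=10):
    """Hypothetical helper: integer labels -> one-hot rows (like to_categorical)."""
    labels = np.asarray(labels, dtype="int64").ravel()
    out = np.zeros((labels.size, num_classes), dtype="float32")
    out[np.arange(labels.size), labels] = 1.0
    return out

print(one_hot([3, 0, 9]).argmax(axis=1))  # [3 0 9]
```

Getting the transpose wrong would not raise an error; it would silently scramble the images, which is why checking one pixel index against the flat row is a cheap sanity test.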
Next, build the network. Written in raw TensorFlow it looks fairly involved, but in Keras it is much simpler:
inpt = Input(shape=(32,32,3))
x = Conv2D(64, (3, 3), padding='same', activation='relu')(inpt)
x = Conv2D(64, (3, 3), padding='same', activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = Dropout(0.1)(x)
x = Conv2D(128, (3, 3), padding='same', activation='relu')(x)
x = Conv2D(128, (3, 3), padding='same', activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = AveragePooling2D(pool_size=(2, 2), strides=2)(x)
x = Dropout(0.1)(x)
x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)
x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(10, activation='softmax')(x)
model = Model(inpt, x)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
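Keras has one feature I particularly like: .summary(), which prints the network structure (each layer's output shape and parameter count) in a very readable way. Personally I find it far more pleasant than looking at TensorBoard, but that's just my own preference.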
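As a cross-check on what .summary() reports, the parameter counts for this particular network can be worked out by hand. The plain-Python sketch below only encodes the standard formulas: a 3x3 Conv2D with bias has 3*3*C_in*C_out + C_out weights, BatchNormalization keeps 4 values per channel (gamma and beta are trainable, the moving mean and variance are not), pooling and dropout layers have none, and Dense has in*out + out. The layer widths are transcribed from the model above; conv, bn and dense are helper names made up for this sketch.

```python
def conv(k, c_in, c_out):
    # k x k kernel weights + one bias per output channel
    return k * k * c_in * c_out + c_out

def bn(c):
    # (trainable gamma + beta, non-trainable moving mean + moving variance)
    return 2 * c, 2 * c

def dense(n_in, n_out):
    return n_in * n_out + n_out

trainable = 0
non_trainable = 0
channels = 3  # RGB input
for width in (64, 128, 256):
    # two 3x3 convs per stage; 'same' padding keeps the spatial size
    trainable += conv(3, channels, width) + conv(3, width, width)
    t, nt = bn(width)
    trainable += t
    non_trainable += nt
    channels = width
# Spatial size: 32 -> 16 (max pool) -> 8 (avg pool); GlobalAveragePooling2D
# turns 8x8x256 into a 256-vector that feeds Dense(10)
trainable += dense(channels, 10)

print("Trainable params:", trainable)              # 1148874
print("Non-trainable params:", non_trainable)      # 896
print("Total params:", trainable + non_trainable)  # 1149770
```

Most of the roughly 1.15M parameters sit in the two 256-channel convolutions; thanks to global average pooling, the classifier head adds only 2570.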
Then train:
for ii in range(10):
    print("Epoch:", ii+1)
    model.fit(train_X, train_Y, batch_size=100, epochs=1, verbose=1)
    score = model.evaluate(test_X, test_Y, verbose=1)
    print('Test loss =', score[0])
    print('Test accuracy =', score[1])
The core of it is a single call to model.fit, which is far simpler than writing the training loop yourself in TensorFlow...
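Roughly speaking, one epoch of model.fit(train_X, train_Y, batch_size=100, epochs=1) shuffles the 50,000 training samples (Keras shuffles array inputs by default) and takes one optimizer step per mini-batch of 100, i.e. 500 steps. The NumPy sketch below only illustrates that batching logic; iterate_minibatches is a hypothetical helper, and the actual forward pass and Adam update are left as a comment.

```python
import numpy as np

def iterate_minibatches(X, Y, batch_size, rng):
    """Yield shuffled (X_batch, Y_batch) pairs covering every sample exactly once."""
    idx = rng.permutation(len(X))
    for start in range(0, len(X), batch_size):
        batch = idx[start:start + batch_size]  # last batch may be smaller
        yield X[batch], Y[batch]

rng = np.random.default_rng(0)
# Stand-ins with the same number of rows as train_X / train_Y (tiny feature dims to save memory)
X = np.zeros((50000, 4), dtype="float32")
Y = np.zeros((50000, 10), dtype="float32")

steps = 0
for X_batch, Y_batch in iterate_minibatches(X, Y, batch_size=100, rng=rng):
    # here fit() would run the forward pass, compute the loss and apply one optimizer update
    steps += 1
print(steps)  # 500
```

Because the loop in the post calls fit with epochs=1 ten times, each call reshuffles the data, so it behaves much like a single fit call with epochs=10 plus a test evaluation after every epoch.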
Next comes saving and loading the model, and making predictions:
save_model(model,'cifar10.h5')
model = load_model('cifar10.h5')
pred_Y = model.predict(test_X)
score = model.evaluate(test_X, test_Y, verbose=0)
print('Test loss =', score[0])
print('Test accuracy =', score[1])
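model.predict above returns one row of 10 softmax probabilities per test image, so the predicted class is the arg-max of each row. The NumPy sketch below shows how the two numbers printed by evaluate relate to those probabilities in this setup (categorical cross-entropy plus accuracy against one-hot labels). It uses a tiny made-up probability array instead of real predictions; crossentropy_and_accuracy is a hypothetical helper, and its small-epsilon clipping is assumed to approximate what Keras does internally, so tiny numerical differences are possible.

```python
import numpy as np

def crossentropy_and_accuracy(pred, onehot, eps=1e-7):
    """Mean categorical cross-entropy and accuracy for softmax outputs vs one-hot labels."""
    pred = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    loss = -np.mean(np.sum(onehot * np.log(pred), axis=1))
    acc = np.mean(np.argmax(pred, axis=1) == np.argmax(onehot, axis=1))
    return loss, acc

# Made-up softmax outputs for 3 images (10 classes, each row sums to 1) and their true labels.
pred_Y = np.full((3, 10), 0.02)
pred_Y[0, 3] = 0.82  # confident and correct
pred_Y[1, 5] = 0.82  # confident but wrong (true class is 0)
pred_Y[2, 1] = 0.82  # confident and correct
true = np.eye(10)[[3, 0, 1]]  # one-hot labels, same format as test_Y

pred_labels = np.argmax(pred_Y, axis=1)
print(pred_labels)  # [3 5 1]
loss, acc = crossentropy_and_accuracy(pred_Y, true)
print(round(acc, 4))  # 0.6667
```

Note how the single wrong prediction dominates the loss: it contributes -log(0.02), about 3.9, while each correct one contributes only about 0.2.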
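Keras is absurdly convenient, but while you're still learning I think it's worth trying lower-level frameworks such as PyTorch or TensorFlow, since Keras is comparatively inflexible. On the other hand, plenty of top Kaggle competitors use Keras, which suggests that for work that isn't heavily research-oriented, Keras is perfectly up to the job.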
Finally, the full code is available at: https://github.com/PolarisShi/cifar10