深度学习基础系列:AlexNet
本文用keras实现AlexNet网络,用于猫狗分类项目,没有使用数据增强
参考资料
代码参考地址:https://github.com/ShuaiGuo95/DeepLearning 下的dogs_vs_cats_CNN.ipynb文件。
数据下载地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data
数据介绍
数据包含两部分:训练集和测试集。训练集有25000张图片,测试集有12500张图片。训练集的图片名称中包含了图片的标签信息(例如cat.0.jpg、dog.0.jpg),而测试集的图片名称只代表图片的id。
AlexNet的模型的代码如下,使用keras框架
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
# Side length, in pixels, of the square 3-channel images fed to the network.
resize = 224


def model_alexnet():
    """Build, compile and print an AlexNet-style Keras ``Sequential`` model.

    The network takes inputs of shape ``(resize, resize, 3)`` and ends in a
    2-unit softmax layer. It is compiled with categorical cross-entropy
    loss, the ``'sgd'`` optimizer and the ``'accuracy'`` metric, and
    ``model.summary()`` is printed before the model is returned.

    Returns:
        The compiled ``Sequential`` model.
    """
    # All three max-pooling layers share the same configuration.
    pool_args = dict(pool_size=(3, 3), strides=(2, 2), padding='valid')

    layers = [
        # Stage 1: 11x11 convolution with stride 4, batch norm, pooling.
        Conv2D(filters=96, kernel_size=(11, 11),
               strides=(4, 4), padding='valid',
               input_shape=(resize, resize, 3),
               activation='relu'),
        BatchNormalization(),
        MaxPooling2D(**pool_args),
        # Stage 2: 5x5 convolution, batch norm, pooling.
        Conv2D(filters=256, kernel_size=(5, 5),
               strides=(1, 1), padding='same',
               activation='relu'),
        BatchNormalization(),
        MaxPooling2D(**pool_args),
    ]

    # Stage 3: three stacked 3x3 convolutions followed by a single pooling
    # layer (no batch normalization in this stage).
    for n_filters in (384, 384, 256):
        layers.append(Conv2D(filters=n_filters, kernel_size=(3, 3),
                             strides=(1, 1), padding='same',
                             activation='relu'))
    layers.append(MaxPooling2D(**pool_args))

    # Stage 4: flatten, then three ReLU dense layers, each followed by
    # dropout with rate 0.5.
    layers.append(Flatten())
    for n_units in (4096, 4096, 1000):
        layers.append(Dense(n_units, activation='relu'))
        layers.append(Dropout(0.5))

    # Output layer: 2 units with a separate softmax activation layer.
    layers.append(Dense(2))
    layers.append(Activation('softmax'))

    net = Sequential()
    for layer in layers:
        net.add(layer)

    net.compile(loss='categorical_crossentropy',
                optimizer='sgd',
                metrics=['accuracy'])
    net.summary()
    return net
模型的输出如下:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 54, 54, 96) 34944
_________________________________________________________________
batch_normalization_1 (Batch (None, 54, 54, 96) 384
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 26, 26, 96) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 26, 26, 256) 614656
_________________________________________________________________
batch_normalization_2 (Batch (None, 26, 26, 256) 1024
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 12, 12, 256) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 12, 12, 384) 885120
_________________________________________________________________
conv2d_4 (Conv2D) (None, 12, 12, 384) 1327488
_________________________________________________________________
conv2d_5 (Conv2D) (None, 12, 12, 256) 884992
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 5, 5, 256) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 6400) 0
_________________________________________________________________
dense_1 (Dense) (None, 4096) 26218496
_________________________________________________________________
dropout_1 (Dropout) (None, 4096) 0
_________________________________________________________________
dense_2 (Dense) (None, 4096) 16781312
_________________________________________________________________
dropout_2 (Dropout) (None, 4096) 0
_________________________________________________________________
dense_3 (Dense) (None, 1000) 4097000
_________________________________________________________________
dropout_3 (Dropout) (None, 1000) 0
_________________________________________________________________
dense_4 (Dense) (None, 2) 2002
_________________________________________________________________
activation_1 (Activation) (None, 2) 0
=================================================================
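训练模型的代码如下(只使用了训练集目录中编号0~9999的10000张图片:编号为偶数的读取cat图片,编号为奇数的读取dog图片;前5000张用于训练,其中20%被划分为验证集,后5000张作为测试数据被读入,但在下面的脚本中没有用于评估):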
import os
import numpy as np
import keras
from tqdm import tqdm
import matplotlib.pyplot as plt
from model_AlexNet import *
import datetime
import cv2
# Directory holding the labelled images, named 'cat.<i>.jpg' / 'dog.<i>.jpg'.
DATA_DIR = "J:/dataset/CatVsDog/train/"
# Side length, in pixels, that every image is resized to.
resize = 224


def _load_split(start, count):
    """Read ``count`` images whose file indices are ``start .. start+count-1``.

    For an even index ``i`` the file ``cat.<i>.jpg`` is read and labelled 0;
    for an odd index the file ``dog.<i>.jpg`` is read and labelled 1.

    Returns:
        ``(data, labels)``: int32 arrays of shape
        ``(count, resize, resize, 3)`` and ``(count,)``.

    Raises:
        OSError: if ``cv2.imread`` cannot read one of the image files.
    """
    data = np.empty((count, resize, resize, 3), dtype="int32")
    labels = np.empty((count, ), dtype="int32")
    for offset in tqdm(range(count)):
        index = start + offset
        label = index % 2
        prefix = 'dog.' if label else 'cat.'
        path = DATA_DIR + prefix + str(index) + '.jpg'
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if image is None:
            raise OSError('cv2.imread could not read image: %s' % path)
        data[offset] = cv2.resize(image, (resize, resize))
        labels[offset] = label
    return data, labels


def load_data(num_train=5000, num_test=5000):
    """Load resized cat/dog images from ``DATA_DIR`` as train and test arrays.

    File indices ``0 .. num_train-1`` form the training split and the next
    ``num_test`` indices form the test split. With the default arguments
    this reads the same 10000 files as the original hard-coded version.

    Args:
        num_train: number of images in the training split.
        num_test: number of images in the test split.

    Returns:
        ``(train_data, train_label, test_data, test_label)`` as int32 numpy
        arrays; the data arrays have shape ``(n, resize, resize, 3)`` and
        the label arrays have shape ``(n,)`` with 0 for cat and 1 for dog.

    Raises:
        OSError: if any expected image file cannot be read.
    """
    train_data, train_label = _load_split(0, num_train)
    test_data, test_label = _load_split(num_train, num_test)
    return train_data, train_label, test_data, test_label
# Load the images, convert them to float32 and scale the values by 1/255.
train_data, train_label, test_data, test_label = load_data()
train_data, test_data = train_data.astype('float32'), test_data.astype('float32')
train_data, test_data = train_data/255, test_data/255
# One-hot encode the 0/1 labels to match the 2-unit softmax output.
train_label = keras.utils.to_categorical(train_label, 2)
test_label = keras.utils.to_categorical(test_label, 2)
# NOTE(review): test_data / test_label are prepared but not used below;
# the validation curves come from validation_split on the training arrays.

# Time model construction plus training.
start = datetime.datetime.now()
model = model_alexnet()
his = model.fit(train_data, train_label,
                batch_size=64,
                epochs=50,
                validation_split=0.2,
                shuffle=True)
end = datetime.datetime.now()
# (end - start) is a timedelta; convert it so the value really is seconds.
print('Running time: %s Seconds' % (end - start).total_seconds())
print(his.history.keys())

# Keras versions differ in the history key used for the accuracy metric
# ('acc' in older releases, 'accuracy' in newer ones); accept either.
acc_key = 'acc' if 'acc' in his.history else 'accuracy'
plt.plot(his.history[acc_key])
plt.plot(his.history['val_' + acc_key])
plt.title('model_accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
# The second curve is the validation split, not the separate test set.
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

plt.plot(his.history['loss'])
plt.plot(his.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
最后一次训练结果如下:
loss: 0.0756 - acc: 0.9785 - val_loss: 0.6687 - val_acc: 0.8070
训练集上的acc达到0.9785,但验证集上的val_acc只有0.8070,并且val_loss(0.6687)远高于训练loss(0.0756),说明模型存在明显的过拟合。