在前半部分,我们已经完成了前两大步,并决定使用ResNet50预训练网络来训练模型。那么接下来,就让我们引入keras中已经封装好的ResNet50预训练网络参数。
base_model = ResNet50(weights='imagenet',
include_top=False,
input_shape=(img_width, img_height, 3))
base_model.trainable = False
x = base_model.output
x = GlobalAveragePooling2D(name='average_pool')(x)
predictions = Dense(class_num, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
model.compile(loss='categorical_crossentropy',
optimizer=optimizers.RMSprop(lr=1e-3),
metrics=['acc'])
keras中已经封装好了resnet50的网络结构和预训练参数,可以通过
from keras.applications.resnet50 import ResNet50
来引入,并使用以上的代码进行实例化。
那么,让我们赶紧来跑一下吧。
最终,我们得到的验证集准确率为77.59%
这的确是非常大的提升,而且由于ResNet更加的轻量化,因此其训练速度更快。
可惜的是,依旧没有达到一期的任务期望。
optimizer=optimizers.RMSprop(lr=5e-4)
此时,我们得到的验证集准确率为79.35%
# 这个文档用来进行去除图片的边界,以及调整大小到356*356
import os
import cv2
from PIL import Image
import warnings
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.simplefilter("ignore", category=FutureWarning)
Image.MAX_IMAGE_PIXELS = None
base_path = "input"
image_path = os.path.join(base_path, 'image')
train_path = os.path.join(base_path, 'train')
def removeBorder(fileDir):
pathDir = os.listdir(fileDir) # 取图片的原始路径
for imgName in pathDir:
img = cv2.imread(fileDir + '/' + imgName)
imgtemp = cv2.resize(img, (456, 456))
cropped = imgtemp[50:406, 50:406]
cv2.imwrite(fileDir + '/' + imgName, cropped)
if __name__ == '__main__':
count = 0
for i in range(1, 1505):
fileDir = os.path.join(image_path, str(i))
if os.path.isdir(fileDir):
removeBorder(fileDir)
count += 1
print("processed :" + str(i) + ",and " + str(count) + "/168")
此时,我们得到的验证集准确率为83.57%,距离一期的任务期望仅一步之遥!
def preprocess(image):
mean = [R_MEAN, G_MEAN, B_MEAN]
image[..., 0] -= mean[0]
image[..., 1] -= mean[1]
image[..., 2] -= mean[2]
return image
# 可以对训练集进行数据增强处理
train_datagen = ImageDataGenerator(preprocessing_function=preprocess,
rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.1,
horizontal_flip=True,
fill_mode='constant'
)
# 测试集不许动,去均值中心化完了之后不许动
validation_datagen = ImageDataGenerator(preprocessing_function=preprocess)
最后,我们得到的验证集准确率为84.31%,虽然增加的不多,但是终于完成了一期的任务!
下一期,我会讲一下如何利用模型微调的方式完成二期任务。
有什么不明白的参数,可以查看keras官方中文文档
https://keras-cn.readthedocs.io/en/latest/other/application/