Model script:
from tensorflow.keras import layers, models, Sequential

# Define the classification head, i.e. the final fully connected layers
def VGG(feature, im_height=224, im_width=224, class_num=1000):  # feature is the feature-extraction backbone
    # tensor channel ordering in TensorFlow is NHWC
    input_image = layers.Input(shape=(im_height, im_width, 3), dtype="float32")
    x = feature(input_image)         # run the backbone to extract features
    x = layers.Flatten()(x)          # flatten the feature maps
    x = layers.Dropout(rate=0.5)(x)  # dropout to reduce overfitting
    x = layers.Dense(2048, activation='relu')(x)  # half the nodes of the original paper, to save parameters
    x = layers.Dropout(rate=0.5)(x)
    x = layers.Dense(2048, activation='relu')(x)
    x = layers.Dense(class_num)(x)
    output = layers.Softmax()(x)
    model = models.Model(inputs=input_image, outputs=output)
    return model
# Build the feature-extraction backbone from a configuration list
def features(cfg):
    feature_layers = []  # holds the layer objects
    for v in cfg:        # iterate over the configuration list
        if v == "M":     # "M" marks a max-pooling layer
            feature_layers.append(layers.MaxPool2D(pool_size=2, strides=2))
        else:            # otherwise v is the kernel count of a 3x3 conv layer
            conv2d = layers.Conv2D(v, kernel_size=3, padding="SAME", activation="relu")
            feature_layers.append(conv2d)
    return Sequential(feature_layers, name="feature")  # name labels this sub-network
# Dictionary of model configurations: the key is the model name, the value is a list in which
# each number is the kernel count of a conv layer and "M" marks a max-pooling layer
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def vgg(model_name="vgg16", im_height=224, im_width=224, class_num=1000):  # instantiate a model; model_name is a key of cfgs
    try:
        cfg = cfgs[model_name]  # look up the configuration list
    except KeyError:
        print("Warning: model name {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(features(cfg), im_height=im_height, im_width=im_width, class_num=class_num)
    return model

model = vgg(model_name='vgg11')
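As a quick sanity check (a minimal sketch, assuming the definitions above are in scope, e.g. run in the same session as the model script), the backbone built from the 'vgg11' configuration should reduce a 224x224 input to a 7x7x512 feature map, since each of the five 2x2 max-pool layers halves the spatial size:

import tensorflow as tf

feat = features(cfgs['vgg11'])
dummy = tf.random.normal((1, 224, 224, 3))  # one random NHWC image
print(feat(dummy).shape)                    # expected: (1, 7, 7, 512)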
Training script:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from model import vgg
import tensorflow as tf
import json
import os
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dir = image_path + "train"
validation_dir = image_path + "val"

# create directory for saving weights
if not os.path.exists("save_weights"):
    os.makedirs("save_weights")
im_height = 224
im_width = 224
batch_size = 10
epochs = 10
# Image preprocessing / augmentation
train_image_generator = ImageDataGenerator(rescale=1. / 255,      # simple rescaling to [0, 1]
                                           horizontal_flip=True)  # random horizontal flip
validation_image_generator = ImageDataGenerator(rescale=1. / 255)  # generator for the validation set
# read training images from the directory
train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                            batch_size=batch_size,
                                                            shuffle=True,
                                                            target_size=(im_height, im_width),
                                                            class_mode='categorical')
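# flow_from_directory assumes one subfolder per class under the given directory.
# A hypothetical layout (class_a/class_b are placeholders; the real folder names
# are whatever classes the data set provides):
#   data_set/flower_data/
#       train/
#           class_a/xxx.jpg
#           class_b/yyy.jpg
#       val/
#           class_a/zzz.jpg
#           class_b/www.jpg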
total_train = train_data_gen.n  # number of training samples
# dict mapping each class name to its index
class_indices = train_data_gen.class_indices
# invert the dict (so that at prediction time an index maps directly to a class name)
inverse_dict = dict((val, key) for key, val in class_indices.items())
# serialize the dict to a JSON string; note that json turns the integer keys into strings,
# which is why the prediction script later indexes with class_indict[str(...)]
json_str = json.dumps(inverse_dict, indent=4)
with open('class_indices.json', 'w') as json_file:  # write the dict to a json file
    json_file.write(json_str)
# read validation images from the directory (use the validation generator, which only rescales)
val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                              batch_size=batch_size,
                                                              shuffle=True,
                                                              target_size=(im_height, im_width),
                                                              class_mode='categorical')
total_val = val_data_gen.n  # number of validation samples
model = vgg("vgg16", 224, 224, 5)  # instantiate the network for the 5 flower classes
model.summary()
# using keras high level api for training
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=["accuracy"])
# save the weights whenever validation loss improves
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myVGG_{epoch}.h5',
                                                save_best_only=True,
                                                save_weights_only=True,
                                                monitor='val_loss')]
# TensorFlow 2.1 recommends using fit (it accepts generators as well)
history = model.fit(x=train_data_gen,
                    steps_per_epoch=total_train // batch_size,
                    epochs=epochs,
                    validation_data=val_data_gen,
                    validation_steps=total_val // batch_size,
                    callbacks=callbacks)
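The training script imports matplotlib.pyplot but never uses it; a minimal sketch (using the history object returned by model.fit above) for plotting the recorded loss and accuracy curves:

# plot training and validation loss
history_dict = history.history
plt.figure()
plt.plot(range(epochs), history_dict["loss"], label="train_loss")
plt.plot(range(epochs), history_dict["val_loss"], label="val_loss")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("loss")

# plot training and validation accuracy
plt.figure()
plt.plot(range(epochs), history_dict["accuracy"], label="train_accuracy")
plt.plot(range(epochs), history_dict["val_accuracy"], label="val_accuracy")
plt.legend()
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.show()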
Prediction script:
from model import vgg
from PIL import Image
import numpy as np
import json
import matplotlib.pyplot as plt
im_height = 224
im_width = 224
# load image
img = Image.open("../tulip.jpg")
# resize image to 224x224
img = img.resize((im_width, im_height))
plt.imshow(img)
# scaling pixel value to (0-1)
img = np.array(img) / 255.
# Add the image to a batch where it's the only member.
img = (np.expand_dims(img, 0))
# read the index -> class name mapping written by the training script
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)
model = vgg("vgg16", 224, 224, 5)
model.load_weights("./save_weights/myVGG.h5")  # path must point to a weights file saved by the training script
result = np.squeeze(model.predict(img))  # drop the batch dimension from the output
predict_class = np.argmax(result)        # index of the highest-probability class
print(class_indict[str(predict_class)], result[predict_class])
plt.show()
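Optionally, the full probability distribution can be inspected as well; a small sketch using only the variables already defined in the prediction script:

# print the probability of every class, highest first
for i in np.argsort(result)[::-1]:
    print("{:12} {:.3f}".format(class_indict[str(i)], result[i]))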