在create_model方法中,创建YOLO v3的网络结构,其中参数:
代码如下
def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=2,
weights_path='model_data/yolo_weights.h5'):
**yolo_weight.h5是由.weight转换而来,转换程序为convert.py文件 **
逻辑
将参数进行处理:
代码如下
h, w = input_shape # 尺寸
image_input = Input(shape=(w, h, 3)) # 图片输入格式
num_anchors = len(anchors) # anchor数量
# YOLO的三种尺度,每个尺度的anchor数,类别数+边框4个+置信度1
y_true = [Input(shape=(h // {0: 32, 1: 16, 2: 8}[l], w // {0: 32, 1: 16, 2: 8}[l],
num_anchors // 3, num_classes + 5)) for l in range(3)]
h,w都经过了相应尺寸的缩放,对应各自的三种anchor值,有几个类别,就会设置多少个初始值为0的空间,最后只有预测的那一位置1
其中,真值y_true,真值即Ground Truth:
“//”是Python语法中的整除符号,通过循环创建3个Input层,组成列表,作为y_true,假设class为1,格式如下:
Tensor("input_2:0", shape=(?, 13, 13, 3, 6), dtype=float32)
Tensor("input_3:0", shape=(?, 26, 26, 3, 6), dtype=float32)
Tensor("input_4:0", shape=(?, 52, 52, 3, 6), dtype=float32)
其中,第1位是样本数,第2~3位是特征图的尺寸13x13,第4位是每个图的anchor数,第5位是:类别(n)+4个框值(x,y,w,h)+框的置信度(是否含有物体)。
通过图片输入Input层image_input、每个尺度的anchor数num_anchors//3、类别数num_classes,创建YOLO v3的网络结构,即:
model_body = yolo_body(image_input, num_anchors // 3, num_classes)
shape=(?,416,416,3,6)
接着,加载预训练模型:
根据预训练模型的地址weights_path,加载模型,按名称对应by_name,略过不匹配skip_mismatch;
选择冻结模式:
代码如下
if load_pretrained: # 加载预训练模型
model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
if freeze_body in [1, 2]:
# Freeze darknet53 body or freeze all but 3 output layers.
num = (185, len(model_body.layers) - 3)[freeze_body - 1]
for i in range(num):
model_body.layers[i].trainable = False # 将其他层的训练关闭
接着,设置模型损失层model_loss:
model_body.output指代的是经过神经网络后得到的模型输出
代码如下
model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
arguments={'anchors': anchors,
'num_classes': num_classes,
'ignore_thresh': 0.5}
)(model_body.output ,* y_true)
接着,创建最终模型:
模型的输入:model_body的输入层,即image_input,和y_true;
模型的输出:model_loss的输出,一个值,output_shape=(1,);
保存模型的网络图plot_model,和打印网络model.summary();
代码如下
model = Model(inputs=[model_body.input] + y_true, outputs=model_loss) # 模型
plot_model(model, to_file=os.path.join('model_data', 'model.png'), show_shapes=True, show_layer_names=True) # 存储网络结构
model.summary() # 打印网络
其中,model_body.input是任意(?)个(416,416,3)的图片,即
Tensor("input_1:0", shape=(?, 416, 416, 3), dtype=float32)
y_true是已标注数据转换的真值结构,即
[Tensor("input_2:0", shape=(?, 13, 13, 3, 6), dtype=float32),
Tensor("input_3:0", shape=(?, 26, 26, 3, 6), dtype=float32),
Tensor("input_4:0", shape=(?, 52, 52, 3, 6), dtype=float32)]
参考:
http://www.jintiankansha.me/t/D7hikktv8s