VGG的网络模型由 ssd_pascal.py 文件生成,可生成 VGG,ZF,Resnet101 和 Resnet152 四种网络。
# Create train net. # 生成训练网络 往下还有 test net 和 deploy net
net = caffe.NetSpec()
net.data, net.label = CreateAnnotatedDataLayer(train_data, batch_size=batch_size_per_device,
train=True, output_label=True, label_map_file=label_map_file,
transform_param=train_transform_param, batch_sampler=batch_sampler)
VGGNetBody(net, from_layer='data', fully_conv=True, reduced=True, dilated=True,
dropout=False) # 生成 VGGNet
AddExtraLayers(net, use_batchnorm, lr_mult=lr_mult)
mbox_layers = CreateMultiBoxHead(net, data_layer='data', from_layers=mbox_source_layers,
use_batchnorm=use_batchnorm, min_sizes=min_sizes, max_sizes=max_sizes,
aspect_ratios=aspect_ratios, steps=steps, normalizations=normalizations,
num_classes=num_classes, share_location=share_location, flip=flip, clip=clip,
prior_variance=prior_variance, kernel_size=3, pad=1, lr_mult=lr_mult)
VGGNetBody 函数在 caffe/python/caffe/model_libs.py 文件中定义,此外,还定义了其他网络,因此若要修改VGGNet网络模型,需修改model_libs.py 文件。
def VGGNetBody(net, from_layer, need_fc=True, fully_conv=False, reduced=False,
dilated=False, nopool=False, dropout=True, freeze_layers=[], dilate_pool4=False):
kwargs = {
'param': [dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)],
'weight_filler': dict(type='xavier'),
'bias_filler': dict(type='constant', value=0)}
assert from_layer in net.keys()
net.conv1_1 = L.Convolution(net[from_layer], num_output=64, pad=1, kernel_size=3, **kwargs)
net.relu1_1 = L.ReLU(net.conv1_1, in_place=True)
net.conv1_2 = L.Convolution(net.relu1_1, num_output=64, pad=1, kernel_size=3, **kwargs)
net.relu1_2 = L.ReLU(net.conv1_2, in_place=True)
准确率和召回率
P-R曲线
主要修改的文件
在src/caffe/proto/caffe.proto中的SolverParameter这个message下加上一个参数rec_prec_thr,该参数是判断样本是否为true positive (tp) 的score阈值,我们给他一个默认值0.6,代码如下(注意序列号在自己的SolverParameter最后的序列号上加1)
optional float rec_prec_thr = 46 [default = 0.6];
src/caffe/solver.cpp的void Solver::TestDetection(const int test_net_id)函数中加入计算recall和precision的代码
本人与参考博主的代码有所差异,修改后的文件放在了这里
1.测试集图片数量 = batch_size * test_iter
其中,batch_size 在 test.prototxt 文件中查看,test_iter 在 solver_test.prototxt 文件中查看