第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数
第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–StreamSegMetrics类和AverageMeter类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a1)hrnetv2.py–4个函数和可执行代码
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a2)hrnetv2.py–Bottleneck类和BasicBlock类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a3)hrnetv2.py–StageModule类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a4)hrnetv2.py–HRNet类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b1)mobilenetv2.py–2个类和2个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b2)mobilenetv2.py–MobileNetV2类和mobilenet_v2函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c1)resnet.py–2个基础函数,BasicBlock类和Bottleneck类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c2)resnet.py–ResNet类和10个不同结构的调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d1)xception.py–SeparableConv2d类和Block类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d2)xception.py–Xception类和xception函数
第四章deeplabv3+源码之慢慢解析 network文件夹(2)_deeplab.py–ASPP相关的4个类和1个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(3)_deeplab.py–DeepLabV3类,DeepLabHeadV3Plus类和DeepLabHead类
第四章deeplabv3+源码之慢慢解析 network文件夹(4)modeling.py–5个私有函数(4个骨干网,1个模型载入)
第四章deeplabv3+源码之慢慢解析 network文件夹(5)modeling.py–12个调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(6)utils.py–_SimpleSegmentationModel类和IntermediateLayerGetter类
第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–[17个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)loss.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)scheduler.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)utils.py–[1个类,4个函数]
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)visualizer.py–[1个类]
总结
本篇介绍main.py中的第三个函数validate,主要涉及函数中所使用的验证数据。
提示:这个函数相对简单独立,主要是为了main函数服务,做交叉验证并返回具体的数据。很多地方需要结合main函数的语句一起了解,会在注释地方补充。
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
#私以为,此处最重要的就是搞清楚这6个参数都是做啥的,理清思路比较重要,剩下的语法是相对简单的。在main函数中,对validate函数的使用语句如下val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id),接下来对每个参数详解。
#opts前文get_argparser函数提过,就是在命令行窗口输入的命令参数解析后的结果。
#model就是所用的模型,,在main函数中model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride),具体需要在network文件夹中的modeling.py查看。后文详解。
#loader就是生成交叉验证的数据集。对应main函数中的val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2),其中data就是导入中的第11行,为了调用torch自带的dataloader(末尾有补充详细内容的链接。)
#device是指CPU或者选择的GPU,对应main函数中device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')。
#metrics度量,就是处理结果指标的矩阵,具体对应导入部分第14行,详见metrics文件夹下的stream_metrics.py文件,后文详解。
#ret_samples_ids默认值为None,对应main函数中 vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,np.int32) if opts.enable_vis else None 即可视化样本的索引,需要用到get_argparser函数中第84和90行,可视化是否开启和可视化样本数量。
"""Do validation and return specified samples"""
metrics.reset() #每次validation重置metrics矩阵,即初始化。
ret_samples = []
if opts.save_val_results: #get_argparser函数中第38行,即是否保存validation的结果。
if not os.path.exists('results'):
os.mkdir('results') #如果没有results文件夹,就建一个。
denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) #归一化还原,详见utils文件夹下的utils.py。具体值来自imagnet训练结果,详见get_dataset函数第14行注释和对应的链接。
img_id = 0
with torch.no_grad(): #交叉验证阶段,非训练,不需要梯度更新
for i, (images, labels) in tqdm(enumerate(loader)): #从loader中读取序号,图像数据和标签
images = images.to(device, dtype=torch.float32) #数值加载到device
labels = labels.to(device, dtype=torch.long)
outputs = model(images) #图像images进入模型,然后输出为outputs
preds = outputs.detach().max(dim=1)[1].cpu().numpy() #detach()方法从原计算图返回tensor并不影响原图。max返回其中最大结果的值([0])和索引([1]),即此处返回索引。放到cpu,转为numpy格式。
targets = labels.cpu().numpy()
metrics.update(targets, preds) #比较实际结果和预测值,详见metrics文件夹下的stream_metrics.py文件。
if ret_samples_ids is not None and i in ret_samples_ids: # get vis samples获得可视化样本
ret_samples.append(
(images[0].detach().cpu().numpy(), targets[0], preds[0]))
if opts.save_val_results: #如保存validation结果,见get_argparser函数第38行
for i in range(len(images)):
image = images[i].detach().cpu().numpy()
target = targets[i]
pred = preds[i]
#下面的.transpose(1,2,0),将数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),转换后才可以显示。
image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8) #归一化除以了255,还原*255.
target = loader.dataset.decode_target(target).astype(np.uint8)
pred = loader.dataset.decode_target(pred).astype(np.uint8)
Image.fromarray(image).save('results/%d_image.png' % img_id)
Image.fromarray(target).save('results/%d_target.png' % img_id)
Image.fromarray(pred).save('results/%d_pred.png' % img_id)
fig = plt.figure()
plt.imshow(image)
plt.axis('off')
plt.imshow(pred, alpha=0.7)
ax = plt.gca()
ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator()) #主刻度设置,详见下面的补充链接。
ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
plt.close()
img_id += 1
score = metrics.get_results() #获得结果,详见metrics文件夹下的stream_metrics.py文件。
return score, ret_samples
Tips