第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数
第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–StreamSegMetrics类和AverageMeter类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a1)hrnetv2.py–4个函数和可执行代码
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a2)hrnetv2.py–Bottleneck类和BasicBlock类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a3)hrnetv2.py–StageModule类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a4)hrnetv2.py–HRNet类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b1)mobilenetv2.py–2个类和2个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b2)mobilenetv2.py–MobileNetV2类和mobilenet_v2函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c1)resnet.py–2个基础函数,BasicBlock类和Bottleneck类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c2)resnet.py–ResNet类和10个不同结构的调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d1)xception.py–SeparableConv2d类和Block类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d2)xception.py–Xception类和xception函数
第四章deeplabv3+源码之慢慢解析 network文件夹(2)_deeplab.py–ASPP相关的4个类和1个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(3)_deeplab.py–DeepLabV3类,DeepLabHeadV3Plus类和DeepLabHead类
第四章deeplabv3+源码之慢慢解析 network文件夹(4)modeling.py–5个私有函数(4个骨干网,1个模型载入)
第四章deeplabv3+源码之慢慢解析 network文件夹(5)modeling.py–12个调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(6)utils.py–_SimpleSegmentationModel类和IntermediateLayerGetter类
第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–[17个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)loss.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)scheduler.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)utils.py–[1个类,4个函数]
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)visualizer.py–[1个类]
总结
本篇介绍main.py中的最后一个函数main,是整个程序的主函数,帮助了解整体思路。
提示:main函数是整个程序的主体思路,很多地方需要结合其他程序去理解,但新手不必着急,踏踏实实的理解每个程序段的意义,具体细节一步一步看就行。
def main():
opts = get_argparser().parse_args() #获得命令行参数,用get_argparser函数解析。
if opts.dataset.lower() == 'voc': #VOC数据集21个分类,cityscapes数据集19个分类,自己的数据集酌情处理。
opts.num_classes = 21
elif opts.dataset.lower() == 'cityscapes':
opts.num_classes = 19
# Setup visualization #可视化代码
vis = Visualizer(port=opts.vis_port, #Visualizer导入部分第21行,get_argparser函数第86、88行,详见utils文件夹下visualizer.py代码
env=opts.vis_env) if opts.enable_vis else None
if vis is not None: # display options
vis.vis_table("Options", vars(opts))
os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #优先使用GPU。
print("Device: %s" % device)
# Setup random seed
torch.manual_seed(opts.random_seed) #随机种子设定。get_argparser函数第70行,默认值1。
np.random.seed(opts.random_seed)
random.seed(opts.random_seed)
# Setup dataloader
if opts.dataset == 'voc' and not opts.crop_val: #crop_val裁剪验证,get_argparser函数第50行,默认是False.
opts.val_batch_size = 1 #get_argparser函数第54行,默认是4.
train_dst, val_dst = get_dataset(opts) #get_dataset函数,获得数据集和验证集
train_loader = data.DataLoader(
train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2,
drop_last=True) # drop_last=True to ignore single-image batches.
val_loader = data.DataLoader(
val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
print("Dataset: %s, Train set: %d, Val set: %d" %
(opts.dataset, len(train_dst), len(val_dst)))
# Set up model (all models are 'constructed at network.modeling) 详见network文件夹中modeling.py
model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
if opts.separable_conv and 'plus' in opts.model: #可分离卷积,get_argparser函数第30行,默认是False.因此代码可选deeplab V3(无separable_conv)和V3 plus(有separable_conv)两个模型,所以用‘plus’一起判断。
network.convert_to_separable_conv(model.classifier) #后面补充可分离卷积链接。
utils.set_bn_momentum(model.backbone, momentum=0.01) #动量momentum学习率衰减,后面补充链接。#详见utils文件夹下的utils.py代码,设置momentum=0.01。
# Set up metrics
metrics = StreamSegMetrics(opts.num_classes) #初始化metrics,详见metrics文件夹下的stream_metrics.py代码。
# Set up optimizer 以下是设置优化参数
optimizer = torch.optim.SGD(params=[
{'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
{'params': model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
# optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
# torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
if opts.lr_policy == 'poly': #get_argparser函数第46行,默认值'poly'。关于学习率策略,很多内容,后面补充链接,很重要,建议新手看看。
scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
elif opts.lr_policy == 'step':
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
# Set up criterion
# criterion = utils.get_loss(opts.loss_type)
if opts.loss_type == 'focal_loss': #get_argparser函数第62行,默认值'cross_entropy'.损失函数选择。
criterion = utils.FocalLoss(ignore_index=255, size_average=True)
elif opts.loss_type == 'cross_entropy':
criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
def save_ckpt(path): #保存模型
""" save current model
"""
torch.save({
"cur_itrs": cur_itrs,
"model_state": model.module.state_dict(),
"optimizer_state": optimizer.state_dict(),
"scheduler_state": scheduler.state_dict(),
"best_score": best_score,
}, path)
print("Model saved as %s" % path)
utils.mkdir('checkpoints')
# Restore
best_score = 0.0
cur_itrs = 0
cur_epochs = 0
if opts.ckpt is not None and os.path.isfile(opts.ckpt): #get_argparser函数第58行,恢复从checkpoint中断的训练。就是继续断点训练的意思。
# https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state"])
model = nn.DataParallel(model)
model.to(device)
if opts.continue_training:
optimizer.load_state_dict(checkpoint["optimizer_state"])
scheduler.load_state_dict(checkpoint["scheduler_state"])
cur_itrs = checkpoint["cur_itrs"]
best_score = checkpoint['best_score']
print("Training state restored from %s" % opts.ckpt)
print("Model restored from %s" % opts.ckpt)
del checkpoint # free memory
else:
print("[!] Retrain")
model = nn.DataParallel(model)
model.to(device)
# ========== Train Loop ==========#
vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
np.int32) if opts.enable_vis else None # sample idxs for visualization
denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 从归一化恢复原图像
if opts.test_only:
model.eval() #评估模式。而非训练模式。网络参数不会发生变化。
val_score, ret_samples = validate( #调用validate函数
opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
print(metrics.to_str(val_score))
return
interval_loss = 0
while True: # cur_itrs < opts.total_itrs: #小于总迭代次数,get_argparser函数第41行,默认值30k。
# ===== Train =====训练过程
model.train() #torch.nn.modules中module.py内置的train(),可简单的理解为自动训练。
cur_epochs += 1
for (images, labels) in train_loader:
cur_itrs += 1
images = images.to(device, dtype=torch.float32)
labels = labels.to(device, dtype=torch.long)
optimizer.zero_grad() #optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备,此代码为单单次循环清空。配合optimizer.step()使用。梯度累加和梯度清空后面有补充链接。
outputs = model(images)
loss = criterion(outputs, labels) #使用前面选择的损失函数进行计算。
loss.backward() #反向传播
optimizer.step() #更新权重,使用没有经过optimizer.zero_grad()清空之前的所有梯度累加进行更新,此代码为单次循环更新。
np_loss = loss.detach().cpu().numpy()
interval_loss += np_loss
if vis is not None:
vis.vis_scalar('Loss', cur_itrs, np_loss)
if (cur_itrs) % 10 == 0: #每10次输出一次迭代和损失值。
interval_loss = interval_loss / 10
print("Epoch %d, Itrs %d/%d, Loss=%f" %
(cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
interval_loss = 0.0
if (cur_itrs) % opts.val_interval == 0: #opts.val_intervalget_argparser函数第74行,默认值100,即每100次,保存一次checkpoints。
save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
(opts.model, opts.dataset, opts.output_stride))
print("validation...")
model.eval() #模型评估,非训练。
val_score, ret_samples = validate(
opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
ret_samples_ids=vis_sample_id) #调用validate函数。
print(metrics.to_str(val_score))
if val_score['Mean IoU'] > best_score: # save best model 保存当前Mean IoU值最高的模型。
best_score = val_score['Mean IoU']
save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
(opts.model, opts.dataset, opts.output_stride))
if vis is not None: # visualize validation score and samples #详见utils文件夹下visualizer.py代码
vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
vis.vis_table("[Val] Class IoU", val_score['Class IoU'])
for k, (img, target, lbl) in enumerate(ret_samples):
img = (denorm(img) * 255).astype(np.uint8) #还原图像,用以显示。
target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8) #decode_target函数,详见第二章datasets文件夹(2)voc.py--VOCSegmentation类
lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
concat_img = np.concatenate((img, target, lbl), axis=2) # concat along width
vis.vis_image('Sample %d' % k, concat_img)
model.train()
scheduler.step() #学习率更新
if cur_itrs >= opts.total_itrs:
return
Tips
整个main.py的这些内容,学习后没记住没关系,入门学习要有耐心,反复查找的看也就了解了。
根目录下一共两个python文件,一是main.py,二是predict.py。下一章介绍的predict.py,相对main.py而言,这个纯预测的文件,要简单的多,很多代码是重复的。