有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5
def main():
    """Parse the options, then set up visualization and the compute device.

    Only the opening of ``main`` appears in this excerpt; the later code
    excerpts in the article continue the same function body.
    """
    opts = get_argparser().parse_args()

    # The number of classes is derived from the selected dataset.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization (only when enabled on the command line).
    vis = Visualizer(port=opts.vis_port, env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    # Limit the visible GPUs before asking torch whether CUDA is available.
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)
设置可视化工具和配置训练设备:
opts.enable_vis 的值决定是否创建一个 Visualizer 对象,其端口和环境设置由 opts.vis_port 和 opts.vis_env 提供。Visualizer 是一个python可视化工具类,用于在训练过程中显示图像、图表等信息
如果 vis 不为 None,则调用 vis 对象的 vis_table 方法来显示配置选项。vars(opts) 是将 opts 对象转换为字典,其中包含了所有的命令行参数
设置环境变量 CUDA_VISIBLE_DEVICES,其值为 opts.gpu_id
# Continuation of main(): seed the RNGs and build the data loaders.
# The torch, NumPy and Python generators are all seeded from the same option.
torch.manual_seed(opts.random_seed)
np.random.seed(opts.random_seed)
random.seed(opts.random_seed)

# VOC validation without cropping is forced to batch size 1
# (presumably because un-cropped images differ in size -- confirm against
# get_dataset). The name is lower-cased to match the check that selects
# num_classes at the top of main().
if opts.dataset.lower() == 'voc' and not opts.crop_val:
    opts.val_batch_size = 1

train_dst, val_dst = get_dataset(opts)
train_loader = data.DataLoader(
    train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=0)
val_loader = data.DataLoader(
    val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=0)
print("Dataset: %s, Train set: %d, Val set: %d" %
      (opts.dataset, len(train_dst), len(val_dst)))
# Continuation of main(): build the segmentation network.
# Maps the value of opts.model to the constructor provided by `network`.
model_map = {
    'deeplabv3_resnet50': network.deeplabv3_resnet50,
    'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
    'deeplabv3_resnet101': network.deeplabv3_resnet101,
    'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
    'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
    'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet,
}

model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)

# Separable convolutions are only applied to the "plus" variants' classifier.
if opts.separable_conv and 'plus' in opts.model:
    network.convert_to_separable_conv(model.classifier)

# Only the backbone's batch-norm momentum is changed here.
utils.set_bn_momentum(model.backbone, momentum=0.01)
这部分设置网络的参数
定义一个模型映射字典,字典包括本项目可选择的多个网络:
deeplabv3的resnet50
deeplabv3+的resnet50
deeplabv3的resnet101
deeplabv3+的resnet101
deeplabv3的mobilenet
deeplabv3+的mobilenet,这些网络在Network文件夹中使用一定方法构建,在这里直接导入
根据预设要选择的网络名称、类别数、输出步长(output_stride)加载网络
检查是否启用可分离卷积、模型名称包含 ‘plus’
如果条件满足,则对模型的分类器部分应用可分离卷积的转换
调用set_bn_momentum函数,设置批量归一化的动量,set_bn_momentum函数:
def set_bn_momentum(model, momentum=0.1):
    """Set ``momentum`` on every ``nn.BatchNorm2d`` layer inside ``model``.

    Args:
        model: Module whose sub-modules are walked with ``model.modules()``.
        momentum: Value assigned to the ``momentum`` attribute of each
            ``nn.BatchNorm2d`` found. Other layer types are left untouched.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = momentum
# Continuation of main(): metrics, optimizer, learning-rate schedule and loss.
metrics = StreamSegMetrics(opts.num_classes)

# Two parameter groups: the backbone uses 0.1 * lr, the classifier uses lr.
optimizer = torch.optim.SGD(params=[
    {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
    {'params': model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

if opts.lr_policy == 'poly':
    scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
elif opts.lr_policy == 'step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
else:
    # Fail here instead of with a NameError when `scheduler` is first used.
    raise ValueError("Unsupported lr_policy: %r" % (opts.lr_policy,))

# Label value 255 is excluded from the loss in both variants.
if opts.loss_type == 'focal_loss':
    criterion = utils.FocalLoss(ignore_index=255, size_average=True)
elif opts.loss_type == 'cross_entropy':
    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
else:
    # Fail here instead of with a NameError when `criterion` is first used.
    raise ValueError("Unsupported loss_type: %r" % (opts.loss_type,))
评价指标、优化器、学习率、损失函数
# Continuation of main(): create the checkpoint directory and optionally
# restore the model (and training state) from an existing checkpoint.
utils.mkdir('checkpoints')

best_score = 0.0
cur_itrs = 0
cur_epochs = 0
if opts.ckpt is not None and os.path.isfile(opts.ckpt):
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
    # Weights are loaded into the bare model before it is wrapped.
    model.load_state_dict(checkpoint["model_state"])
    model = nn.DataParallel(model)
    model.to(device)
    if opts.continue_training:
        # Resume optimizer/scheduler state and the progress counters.
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_itrs = checkpoint["cur_itrs"]
        best_score = checkpoint['best_score']
        print("Training state restored from %s" % opts.ckpt)
    print("Model restored from %s" % opts.ckpt)
    del checkpoint  # free memory
else:
    # No usable checkpoint: start from the freshly constructed model.
    print("[!] Retrain")
    model = nn.DataParallel(model)
    model.to(device)
checkpoints,检查点
vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples, np.int32) if opts.enable_vis else None
denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # denormalization for ori images
if opts.test_only:
model.eval()
val_score, ret_samples = validate(
opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
print(metrics.to_str(val_score))
return
训练中的可视化设置、测试模式处理、初始化变量:
如果启用了可视化,则随机选取若干验证样本的索引作为 vis_sample_id,否则将 vis_sample_id 设为 None
如果处于仅测试模式,则调用 validate 函数,返回评分、选定的样本
interval_loss = 0
# Continuation of main(): the training loop. It has no loop condition; main()
# returns from inside it once cur_itrs reaches opts.total_itrs.
while True:
    model.train()
    cur_epochs += 1
    for (images, labels) in train_loader:
        cur_itrs += 1

        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.long)

        # Standard optimization step: forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        np_loss = loss.detach().cpu().numpy()
        interval_loss += np_loss
        if vis is not None:
            vis.vis_scalar('Loss', cur_itrs, np_loss)

        # Every 10 iterations print the mean loss of the interval and reset it.
        if (cur_itrs) % 10 == 0:
            interval_loss = interval_loss / 10
            print("Epoch %d, Itrs %d/%d, Loss=%f" %
                  (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
            interval_loss = 0.0
可视化、日志记录部分:'Loss' 是要记录的数据的名称,cur_itrs 是当前迭代次数,np_loss 是当前迭代的损失值
if (cur_itrs) % opts.val_interval == 0:
# Body of the `if (cur_itrs) % opts.val_interval == 0:` branch of the training
# loop, shown without the indentation of the enclosing loop and `if`.
save_ckpt('checkpoints/latest_%s_%s_os%d.pth' % (opts.model, opts.dataset, opts.output_stride))
print("validation...")
model.eval()
val_score, ret_samples = validate(
    opts=opts, model=model, loader=val_loader, device=device,
    metrics=metrics, ret_samples_ids=vis_sample_id)
print(metrics.to_str(val_score))
if val_score['Mean IoU'] > best_score:  # save best model
    best_score = val_score['Mean IoU']
    save_ckpt('checkpoints/best_%s_%s_os%d.pth' % (opts.model, opts.dataset, opts.output_stride))

if vis is not None:  # visualize validation score and samples
    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

    # One panel per returned sample: image | target | prediction.
    for k, (img, target, lbl) in enumerate(ret_samples):
        img = (denorm(img) * 255).astype(np.uint8)
        target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
        lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
        concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
        vis.vis_image('Sample %d' % k, concat_img)
# Switch back to training mode after validation.
model.train()
scheduler.step()
if cur_itrs >= opts.total_itrs:
return
验证、模型保存、可视化反馈:
遍历返回的样本,取出原始图像(img)、真实标签(target)、模型预测的标签(lbl),对返回的样本进行可视化。
deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5