deeplabv3+源码之慢慢解析4 第一章根目录(4)main.py--main函数

系列文章目录

第一章deeplabv3+源码之慢慢解析 根目录(1)main.py–get_argparser函数
第一章deeplabv3+源码之慢慢解析 根目录(2)main.py–get_dataset函数
第一章deeplabv3+源码之慢慢解析 根目录(3)main.py–validate函数
第一章deeplabv3+源码之慢慢解析 根目录(4)main.py–main函数
第一章deeplabv3+源码之慢慢解析 根目录(5)predict.py–get_argparser函数和main函数

第二章deeplabv3+源码之慢慢解析 datasets文件夹(1)voc.py–voc_cmap函数和download_extract函数
第二章deeplabv3+源码之慢慢解析 datasets文件夹(2)voc.py–VOCSegmentation类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(3)cityscapes.py–Cityscapes类
第二章deeplabv3+源码之慢慢解析 datasets文件夹(4)utils.py–6个小函数

第三章deeplabv3+源码之慢慢解析 metrics文件夹stream_metrics.py–StreamSegMetrics类和AverageMeter类

第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a1)hrnetv2.py–4个函数和可执行代码
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a2)hrnetv2.py–Bottleneck类和BasicBlock类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a3)hrnetv2.py–StageModule类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(a4)hrnetv2.py–HRNet类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b1)mobilenetv2.py–2个类和2个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(b2)mobilenetv2.py–MobileNetV2类和mobilenet_v2函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c1)resnet.py–2个基础函数,BasicBlock类和Bottleneck类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(c2)resnet.py–ResNet类和10个不同结构的调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d1)xception.py–SeparableConv2d类和Block类
第四章deeplabv3+源码之慢慢解析 network文件夹(1)backbone文件夹(d2)xception.py–Xception类和xception函数
第四章deeplabv3+源码之慢慢解析 network文件夹(2)_deeplab.py–ASPP相关的4个类和1个函数
第四章deeplabv3+源码之慢慢解析 network文件夹(3)_deeplab.py–DeepLabV3类,DeepLabHeadV3Plus类和DeepLabHead类
第四章deeplabv3+源码之慢慢解析 network文件夹(4)modeling.py–5个私有函数(4个骨干网,1个模型载入)
第四章deeplabv3+源码之慢慢解析 network文件夹(5)modeling.py–12个调用函数
第四章deeplabv3+源码之慢慢解析 network文件夹(6)utils.py–_SimpleSegmentationModel类和IntermediateLayerGetter类

第五章deeplabv3+源码之慢慢解析 utils文件夹(1)ext_transforms.py.py–[17个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(2)loss.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(3)scheduler.py–[1个类]
第五章deeplabv3+源码之慢慢解析 utils文件夹(4)utils.py–[1个类,4个函数]
第五章deeplabv3+源码之慢慢解析 utils文件夹(5)visualizer.py–[1个类]
总结

文章目录

  • 系列文章目录
    • 第一章deeplabv3+源码之慢慢解析根目录(4)main.py--main函数
    • 主函数,main函数
  • 总结


第一章deeplabv3+源码之慢慢解析根目录(4)main.py–main函数

本篇介绍main.py中的最后一个函数main,是整个程序的主函数,帮助了解整体思路。

主函数,main函数

提示:main函数是整个程序的主体思路,很多地方需要结合其他程序去理解,但新手不必着急,踏踏实实的理解每个程序段的意义,具体细节一步一步看就行。

def main():
    opts = get_argparser().parse_args()    #获得命令行参数,用get_argparser函数解析。
    if opts.dataset.lower() == 'voc':   #VOC数据集21个分类,cityscapes数据集19个分类,自己的数据集酌情处理。
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization   #可视化代码
    vis = Visualizer(port=opts.vis_port,      #Visualizer导入部分第21行,get_argparser函数第86、88行,详见utils文件夹下visualizer.py代码
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   #优先使用GPU。
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)   #随机种子设定。get_argparser函数第70行,默认值1。
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val: #crop_val裁剪验证,get_argparser函数第50行,默认是False.
        opts.val_batch_size = 1    #get_argparser函数第54行,默认是4.

    train_dst, val_dst = get_dataset(opts)     #get_dataset函数,获得数据集和验证集
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2,
        drop_last=True)  # drop_last=True to ignore single-image batches.
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model (all models are 'constructed at network.modeling) 详见network文件夹中modeling.py
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:    #可分离卷积,get_argparser函数第30行,默认是False.因此代码可选deeplab V3(无separable_conv)和V3 plus(有separable_conv)两个模型,所以用‘plus’一起判断。
        network.convert_to_separable_conv(model.classifier)  #后面补充可分离卷积链接。
    utils.set_bn_momentum(model.backbone, momentum=0.01)  #动量momentum学习率衰减,后面补充链接。#详见utils文件夹下的utils.py代码,设置momentum=0.01。

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes) #初始化metrics,详见metrics文件夹下的stream_metrics.py代码。
    
    # Set up optimizer 以下是设置优化参数
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':   #get_argparser函数第46行,默认值'poly'。关于学习率策略,很多内容,后面补充链接,很重要,建议新手看看。
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':      #get_argparser函数第62行,默认值'cross_entropy'.损失函数选择。
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):                  #保存模型
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt): #get_argparser函数第58行,恢复从checkpoint中断的训练。就是继续断点训练的意思。
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 从归一化恢复原图像

    if opts.test_only:
        model.eval()   #评估模式。而非训练模式。网络参数不会发生变化。
        val_score, ret_samples = validate(   #调用validate函数
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id) 
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:   #小于总迭代次数,get_argparser函数第41行,默认值30k。
        # =====  Train  =====训练过程
        model.train()     #torch.nn.modules中module.py内置的train(),可简单的理解为自动训练。
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()   #optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备,此代码为单单次循环清空。配合optimizer.step()使用。梯度累加和梯度清空后面有补充链接。
            outputs = model(images)
            loss = criterion(outputs, labels)   #使用前面选择的损失函数进行计算。
            loss.backward()                    #反向传播       
            optimizer.step()      #更新权重,使用没有经过optimizer.zero_grad()清空之前的所有梯度累加进行更新,此代码为单次循环更新。

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:        #每10次输出一次迭代和损失值。
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:      #opts.val_intervalget_argparser函数第74行,默认值100,即每100次,保存一次checkpoints。
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()    #模型评估,非训练。
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
                    ret_samples_ids=vis_sample_id)   #调用validate函数。
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model 保存当前Mean IoU值最高的模型。
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples #详见utils文件夹下visualizer.py代码
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)   #还原图像,用以显示。
                        target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)   #decode_target函数,详见第二章datasets文件夹(2)voc.py--VOCSegmentation类
                        lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()  #学习率更新

            if cur_itrs >= opts.total_itrs:
                return

Tips

  1. 补充可分离卷积说明
  2. 补充动量学习率衰减
  3. 补充关于学习率
  4. 补充梯度累加和梯度清空
  5. 主函数main已完成,应该对整体的流程有所了解。感兴趣的细节,需要后面逐个方法去探索。

总结

  1. 整个main.py的这些内容,学习后没记住没关系,入门学习要有耐心,反复查找的看也就了解了。

  2. 根目录下一共两个python文件,一是main.py,二是predict.py。下一章介绍的predict.py,相对main.py而言,这个纯预测的文件,要简单的多,很多代码是重复的。

你可能感兴趣的:(技术,deeplabV3+,语义分割,深度学习,人工智能)