DAFormer代码学习

DAFormer复现

1、demo

文章:链接
先下载tools/download_checkpoints.sh,运行demo文件python -m demo.image_demo demo/demo.png work_dirs/211108_1622_gta2cs_daformer_s0_7f24c/211108_1622_gta2cs_daformer_s0_7f24c.json work_dirs/211108_1622_gta2cs_daformer_s0_7f24c/latest.pth(-m 把demo文件夹下的demo_image.py文件作为demo.image_demo模块,参考链接)
在image_demo.py文件中,

def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')              # 图片文件
    parser.add_argument('config', help='Config file')          # 配置文件
    parser.add_argument('checkpoint', help='Checkpoint file')  # 模型文件
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='cityscapes',
        help='Color palette used for segmentation map')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    args = parser.parse_args()

下载数据集,数据结构:

DAFormer
├── ...
├── data
│   ├── acdc (optional)
│   │   ├── gt
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── rgb_anon
│   │   │   ├── train
│   │   │   ├── val
│   ├── cityscapes
│   │   ├── leftImg8bit
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── gtFine
│   │   │   ├── train
│   │   │   ├── val
│   ├── dark_zurich (optional)
│   │   ├── gt
│   │   │   ├── val
│   │   ├── rgb_anon
│   │   │   ├── train
│   │   │   ├── val
│   ├── gta
│   │   ├── images
│   │   ├── labels
│   ├── synthia (optional)
│   │   ├── RGB
│   │   ├── GT
│   │   │   ├── LABELS
├── ...

输出:
Save prediction to demo/demo_pred.png

2、 数据预处理:

python tools/convert_datasets/gta.py data/gta --nproc 8
python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8
python tools/convert_datasets/synthia.py data/synthia/ --nproc 8
# tools/convert_datasets/gta.py
def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert GTA annotations to TrainIds')
    parser.add_argument('gta_path', help='gta data path')
    parser.add_argument('--gt-dir', default='labels', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=4, type=int, help='number of process')
    args = parser.parse_args()
    return args
    
# tools/convert_datasets/cityscapes.py
def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert Cityscapes annotations to TrainIds')
    parser.add_argument('cityscapes_path', help='cityscapes data path')
    parser.add_argument('--gt-dir', default='gtFine', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of process')
    args = parser.parse_args()
    return args
    
# tools/convert_datasets/synthia.py
def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert SYNTHIA annotations to TrainIds')
    parser.add_argument('synthia_path', help='gta data path')
    parser.add_argument('--gt-dir', default='GT/LABELS', type=str)
    parser.add_argument('-o', '--out-dir', help='output path')
    parser.add_argument(
        '--nproc', default=4, type=int, help='number of process')
    args = parser.parse_args()
    return args

前两个数据数据集没问题,第三个数据集报错:

"""
Traceback (most recent call last):
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "tools/convert_datasets/synthia.py", line 19, in convert_to_train_id
    label = cv2.imread(file, cv2.IMREAD_UNCHANGED)[:, :, -1]
TypeError: 'NoneType' object is not subscriptable
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tools/convert_datasets/synthia.py", line 123, in 
    main()
  File "tools/convert_datasets/synthia.py", line 110, in main
    sample_class_stats = mmcv.track_parallel_progress(
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/site-packages/mmcv/utils/progressbar.py", line 164, in track_parallel_progress
    for result in gen:
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/multiprocessing/pool.py", line 865, in next
    raise value
TypeError: 'NoneType' object is not subscriptable

TypeError:' NoneType '对象不是下标,在synthia中找到一张图片有问题,data/synthia/GT/LABELS/0002675.png图片大小为0KB,删除图片没有问题,剩余813张图片可以跑通。

3、训练:

提供了最终DAFormer带注释的配置文件(configs/daformer/gta2cs_uda_warm_fdthings_rcs_croppl_a999_daformer_mitb5_s0.py)

python run_experiments.py --config configs/daformer/gta2cs_uda_warm_fdthings_rcs_croppl_a999_daformer_mitb5_s0.py

报错:

2022-10-21 15:28:16,899 - mmseg - INFO - Use load_from_local loader
Traceback (most recent call last):
  File "run_experiments.py", line 106, in 
    train.main([config_files[i]])
  File "/home/l*/environ/DAFormer/tools/train.py", line 145, in main
    model.init_weights()
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/site-packages/mmcv/runner/base_module.py", line 55, in init_weights
    m.init_weights()
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/site-packages/mmcv/runner/base_module.py", line 55, in init_weights
    m.init_weights()
  File "/home/l*/environ/DAFormer/mmseg/models/backbones/mix_transformer.py", line 349, in init_weights
    checkpoint = _load_checkpoint(
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/site-packages/mmcv/runner/checkpoint.py", line 451, in _load_checkpoint
    return CheckpointLoader.load_checkpoint(filename, map_location, logger)
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/site-packages/mmcv/runner/checkpoint.py", line 244, in load_checkpoint
    return checkpoint_loader(filename, map_location)
  File "/home/l*/anaconda3/envs/daformer/lib/python3.8/site-packages/mmcv/runner/checkpoint.py", line 260, in load_from_local
    raise IOError(f'{filename} is not a checkpoint file')
OSError: pretrained/mit_b5.pth is not a checkpoint file

.pth的路径写错了???仔细看报错问题,开始是在/home/l*/environ/DAFormer/tools/train.py", line 145,所以找到train.py文件,

# tools/train.py
def parse_args(args):
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--load-from', help='the checkpoint file to load weights from')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='custom options')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args(args)

有一个参数--load-from,加上checkpoints的路径,试试,还是不行。

仔细看OSError: pretrained/mit_b5.pth is not a checkpoint filemit_b5.pth已经下载,在config同级目录下,新建pretained文件夹,在将所有下载的mit_b*.pth文件放在pretrained文件夹下。运行成功。

训练了不知道多少个小时之后,得到结果。
DAFormer代码学习_第1张图片

对于文中的实验(例如:网络体系结构比较、组件消融、. . .),使用一个系统来自动生成和训练配置:

python run_experiments.py --exp <ID>

# 参考位置:run_experiments.py
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        '--exp',
        type=int,
        default=None,
        help='Experiment id as defined in experiment.py',
    )
    group.add_argument(
        '--config',
        default=None,
        help='Path to config file',
    )
    parser.add_argument(
        '--machine', type=str, choices=['local'], default='local')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    assert (args.config is None) != (args.exp is None), \
        'Either config or exp has to be defined.'

    GEN_CONFIG_DIR = 'configs/generated/'
    JOB_DIR = 'jobs'
    cfgs, config_files = [], []

run_experiments.py中查找参数含义,'--exp', type=int, default=None, help='Experiment id as defined in experiment.py',所以在experiment.py文件中,查找,得到id=1/2/3/4/5/6/7/8/100/101。(有关可用实验及其分配ID的更多信息,可以在Experiment.py 中找到。生成的配置将存储在configs / generate / 中。)这部分实验没做~~~

4、测试和预测

提供的DAFormer checkpoint在GTA→Cityscapes上进行了训练,先下载tools/download _checkpoint.sh,在Cityscapes验证集上测试使用:

sh test.sh work_dirs/211108_1622_gta2cs_daformer_s0_7f24c

sh是linux中运行shell的命令,是shell的解释器,shell脚本是linux中壳层与命令行界面,用户可以在shell脚本输入命令来执行各种各样的任务。要运行shell脚本,首选需要给shell脚本权限,这里里以test.sh文件为例,接着先给hello.sh文件添加x权限chmod u+x hello.sh,输入sh hello.sh就开始执行shell脚本了。
实际情况是:报错

test.sh: 9: Bad substitution

参考链接
修改命令为

bash -x test.sh work_dirs/211108_1622_gta2cs_daformer_s0_7f24c

运行成功。
DAFormer代码学习_第2张图片

将预测结果保存到work _ dirs/211108 _ 1622 _ gta2cs _ daformer _ s0 _ 7f24c/preds 进行检验,并将模型的mIoU打印到控制台。提供的检查点应达到68.85 mIoU。参考work _ dirs/211108 _ 1622 _ gta2cs _ daformer _ s0 _ 7f24c/ 20211108 _ 164105.log的结尾,以获得更多的信息,如分类的IoU。

5、其他实验

当评估在Synthia→Cityscapes上训练的模型时,请注意,评估脚本计算所有19个Cityscapes类的mIoU。然而,Synthia只包含这些类中16个类的标签。因此,仅在这16个类上报告Synthia→CityscapesmIoU是UDA中的常见做法。由于3个缺失类的Iou为0,因此可以进行转换mIoU16 = mIoU19 * 19 / 16
在目标数据集的测试分割上报告了Cityscapes→ACDCCityscapes→Dark_Zurich的结果。为了生成测试集的预测,请运行:

python -m tools.test path/to/config_file path/to/checkpoint_file --test-set --format-only --eval-option imgfile_prefix=labelTrainIds to_label_id=False

预测结果可以提交给各自数据集的公共评估服务器以获得测试分数。

你可能感兴趣的:(python,人工智能,深度学习,ubuntu)