基于Transform的目标检测(DETR模型)之模型解析

参考博文

源码解析目标检测的跨界之星DETR(一)、概述与模型推断

源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理

源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码

Transformer 修炼之道(三)、Decoder

源码解析目标检测的跨界之星DETR(四)、Detection with Transformer

算法学习笔记(5):匈牙利算法

源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成

init_distributed_mode() 方法是与分布式训练相关的设置,在该方法里,是通过环境变量来判断是否使用分布式训练,如果是,那么就设置相关参数,具体可参考 util/misc.py 文件中的源码,这里不作解析。

 #分布式训练相关的设置
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))

def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)#设置设备
    args.dist_backend = 'nccl' #设置后端为NCCL
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)#初始化默认的分布式进程组
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

 参数项 frozen_weights 代表是否固定住参数的权重,类似于迁移学习的微调。如果是,那么需要同时指定 masks 参数,代表这种条件仅适用于分割任务。上图最后部分是固定随机种子,以便复现结果。

 #冻结训练只用于分割
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)
# fix the seed for reproducibility(固定种子的可复制性)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

加载训练数据和验证数据 

dataset_train = build_dataset(image_set='train', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)#一个采样器以产生小批量索引

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

如果需要模型冻结,加载冻结模型 

if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model']) #加载冻结模型

你可能感兴趣的:(深度学习,深度学习)