源码解析目标检测的跨界之星DETR(一)、概述与模型推断
源码解析目标检测的跨界之星DETR(二)、模型训练过程与数据处理
源码解析目标检测的跨界之星DETR(三)、Backbone与位置编码
Transformer 修炼之道(三)、Decoder
源码解析目标检测的跨界之星DETR(四)、Detection with Transformer
算法学习笔记(5):匈牙利算法
源码解析目标检测的跨界之星DETR(六)、模型输出与预测生成
init_distributed_mode() 方法是与分布式训练相关的设置,在该方法里,是通过环境变量来判断是否使用分布式训练,如果是,那么就设置相关参数,具体可参考 util/misc.py 文件中的源码,这里不作解析。
#分布式训练相关的设置
utils.init_distributed_mode(args)
print("git:\n {}\n".format(utils.get_sha()))
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)#设置设备
args.dist_backend = 'nccl' #设置后端为NCCL
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)#初始化默认的分布式进程组
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
参数项 frozen_weights 代表是否固定住参数的权重,类似于迁移学习的微调。如果是,那么需要同时指定 masks 参数,代表这种条件仅适用于分割任务。上图最后部分是固定随机种子,以便复现结果。
#冻结训练只用于分割
if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
print(args)
# fix the seed for reproducibility(固定种子的可复制性)
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
加载训练数据和验证数据
dataset_train = build_dataset(image_set='train', args=args)
dataset_val = build_dataset(image_set='val', args=args)
if args.distributed:
sampler_train = DistributedSampler(dataset_train)
sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True)#一个采样器以产生小批量索引
data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn, num_workers=args.num_workers)
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
如果需要模型冻结,加载冻结模型
if args.frozen_weights is not None:
checkpoint = torch.load(args.frozen_weights, map_location='cpu')
model_without_ddp.detr.load_state_dict(checkpoint['model']) #加载冻结模型