SegFormer学习笔记(3)train续1

一、GPU设置

咱们关心一下GPU怎么设置的吧。

在上一篇文章中的train.py中,

第126行 gpu = setup_ddp()

对于我来说,gpu=0

那么,129行,main(cfg, gpu, save_dir)的gpu就等于0了。

上到第25行,def main(cfg, gpu, save_dir):就把gpu=0传进来了

你在看46-50行, if train_cfg['DDP']: 其实咱们再配置文件中,这个DDP选项都是false。其实这个DDP就是分布式训练,需要多GPU。因此,gpu变量,由于不是DDP,所以根本没用上,白传递到main中了。

    if train_cfg['DDP']: 
        sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)
        model = DDP(model, device_ids=[gpu])
    else:
        sampler = RandomSampler(trainset)

所以条件不满足,那么就跳到50行: sampler = RandomSampler(trainset)

这里,RandomSampler随机采样,其定义如下:

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:

replacement参数决定是否进行重复采样,如果replacement为True,将采用torch.randint来生成随机索引,其中索引序列是会存在重复值的,如果replacement为False,将采用torch.randperm函数来生成随机索引序列,此时序列不包含重复数值。

二、device设置

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #device = torch.device(cfg['DEVICE'])

上面注释掉的, #device = torch.device(cfg['DEVICE']),才是源码。我改成了

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

你再看看44行,73、64行,就知道,作者要把数据倒到device里面了。

你可能感兴趣的:(人工智能之SegFormer,学习)