Pytorch分布式训练

接着上一篇写到加载Dataset,这里引进如何把Dataset分布在多卡进行训练
在多卡情况下分布式训练数据的读取用到了这两个代码

torch.nn.parallel.DistributedDataParallel
torch.utils.data.distributed.DistributedSampler
  1. dataparallel的做法是直接将batch切分到不同的卡。
  2. sampler确保dataloader只会load到整个数据集的一个特定子集的做法。
  3. DistributedSampler就是为每一个子进程划分出一部分数据集,以避免不同进程之间数据重复。

实例

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

dataset = your_dataset()
datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=datasampler)
model = your_model()

现在我们可以完整的看siamfc++里面怎么加载数据的,videoanalyst/data/builder.py

logger.info("Build real AdaptorDataset")
py_dataset = AdaptorDataset(task,
                            cfg,
                            num_epochs=cfg.num_epochs,
                            nr_image_per_epoch=cfg.nr_image_per_epoch)
# use DistributedSampler in case of DDP
if world_size > 1:
    py_sampler = DistributedSampler(py_dataset)
    logger.info("Use dist.DistributedSampler, world_size=%d" %
                world_size)
else:
    py_sampler = None
# build real dataloader
dataloader = DataLoader(
    py_dataset,
    batch_size=cfg.minibatch // world_size,
    shuffle=False,
    pin_memory=cfg.pin_memory,
    num_workers=cfg.num_workers // world_size,
    drop_last=True,
    sampler=py_sampler,
)

AdaptorDataset 构建数据集;
py_sampler 定义数据集划分;
dataloader 完成数据集加载

你可能感兴趣的:(siamfc++解析,pytorch)