在pytorch DDP数据并行时会对数据集进行切分,每个rank节点只处理部分数据。使用DistributedSampler来会把dataset数据集采样为一个子数据集。定义如下:
torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
world_size
rank
DistributedSampler
时,torch.util.data.Dataloader创建时的shuffle参数,相当于把随机的功能交给了DistributedSampler
。默认为Truenum_replicas
整除;为False的话Sampler为增加额外的indices;默认为False注意在分布式模式下,每个epoch启动前要调用set_epoch()
方法,用于在多个epoch执行时打乱顺序,不调用的话读取顺序都会一样。
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
... sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)
WebDataset是专门针对大数据训练服务的。基于Pytorch IterableDataset实现的数据DataLoader,数据存储在一系列的POSIX tar包中,使用squence/streaming的数据访问方式。跟AIStore服务器和Tensorcom RDMA库结合,可以提供高性能的数据访问方式。
WebDataset会把一个大数据集切分为多个shard,一个tar包就是一个shard。跟pytorch的DataLoader不同的是,WebDataset是以shard为粒度进行I/O并行访问和shuffle。一组shard可以用一个文件的列表来表示,也可以写到一个大括号的方式来进行表示,例如字符串openimages-train-{000000..000554}.tar
表示数据集中包含有554个shard,每个shard分片中有1G的图像数据。在WebDataset中,这种ShardList
字符串的解析是通过braceexpand
库来进行的。以下两种表示是等价的:
dataset = wds.WebDataset(["dataset-000.tar", "dataset-001.tar", "dataset-002.tar", "dataset-003.tar"])
dataset = wds.WebDataset("dataset-{000..003}.tar")
WebDataset基本使用方式如下:
import webdataset as wds
dataset = wds.WebDataset(url).shuffle(1000).decode("torchrgb").to_tuple("jpg;png", "json")
dataloader = torch.utils.data.DataLoader(dataset, num_workers=4, batch_size=16)
for inputs, outputs in dataloader:
...
webdataset.Webdataset
使用方法简单, 仅用一行代码, 初始化会自动按node数和worker数对shard进行切分:
dataset = webdataset.Webdataset(urls)
等价于如下的写法,内部处理对应的类是ShardList
,在示例中使用nodesplitter
和splitter
两个函数将URLs切分为多组shard:
urls = list(braceexpand.braceexpand("dataset-{000000..000999}.tar"))
dataset = wds.ShardList(urls, splitter=wds.split_by_worker, nodesplitter=wds.split_by_node, shuffle=False)
dataset = wds.Processor(dataset, wds.url_opener)
dataset = wds.Processor(dataset, wds.tar_file_expander)
dataset = wds.Processor(dataset, wds.group_by_keys)
def my_split_by_worker(urls):
wi = torch.utils.data.get_worker_info()
if wi is None:
return urls
else:
return urls[wi.id::wi.num_workers]
def my_split_by_node(urls):
node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
return urls[node_id::node_count]
最简单示例如下,使用resample+with_epoch
dataset = wds.WebDataset(url, resampled=True).shuffle(1000).decode("rgb").to_tuple("png", "json").map(preprocess).with_epoch(10000)
sample = next(iter(dataset))
复杂的pipeline示例:
dataset = wds.DataPipeline(
wds.ResampledShards(url),
# at this point we have an iterator over all the shards
wds.tarfile_to_samples(),
wds.shuffle(1000),
wds.decode("torchrgb"),
# at this point, we have an list of decompressed training samples from each shard in this worker in sequence
get_patches, # note that can put iterator->iterator functions into the pipeline directly
wds.shuffle(10000),
wds.to_tuple("big.jpg", "json"),
wds.batched(16)
).with_epoch(10000)
batch = next(iter(loader))
batch[0].shape, batch[1].shape
还有一个with_length可以配合使用,用于指定数据集的总长度
两个实际中使用的例子:
在WebDataset文档中还介绍了使用ddp_equalize用于Multinode训练,但这种方式已经废弃, 底层实际还是采用with_epoch
和with_length
来实现,参考:ddp_equalize #194、IGNORE_test_ddp_equalize、ddp fixes