PyTorch Training Acceleration

PyTorch DataLoader Acceleration

 

Reference source code:

https://github.com/NVIDIA/apex/blob/f5cd5ae937f168c763985f627bbf850648ea5f3f/examples/imagenet/main_amp.py#L256

import torch

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()  # side stream for async host-to-device copies
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # Copies overlap with compute on the default stream because the
            # source tensors come from a pin_memory=True DataLoader.
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)
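The class as quoted above is incomplete: the training loop below also calls a next() method. A sketch of it, consistent with the apex example (the exact body at the linked commit may differ slightly): the compute stream first waits for the copy stream, and record_stream keeps the prefetched tensors alive until the default stream has consumed them:

    def next(self):
        # Make the current (compute) stream wait until the async copies
        # issued in preload() have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        # Kick off the copy for the next batch while this one is being used.
        self.preload()
        return input, target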

 

We can see that NVIDIA prefetches the data needed for the next iteration at the moment each batch is handed to the network. To apply this to our own training code, only the following changes are needed:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,  # page-locked host memory, required for async copies
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # training code
After the change:
prefetcher = data_prefetcher(training_data_loader)
data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # training code
    data, label = prefetcher.next()

There are other ways to speed things up as well; sketches for each follow the list:

1. Mount part of RAM as a filesystem (tmpfs) and put the data to be read into it, which speeds up I/O (workflow sketch below). Example command: mount -t tmpfs -o size=xxG tmpfs /your_path

2. Use NVIDIA's DALI library to move some of the preprocessing onto the GPU, which gives a substantial speedup (pipeline sketch below).
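A minimal sketch of the tmpfs workflow. The mount point /mnt/ramdisk, the 64G size, and the dataset paths are all hypothetical; the size must fit both your RAM and your dataset:

# Run once as root:
#   mount -t tmpfs -o size=64G tmpfs /mnt/ramdisk
#   cp -r /path/to/imagenet /mnt/ramdisk/
from torchvision import datasets, transforms

# Reads now hit RAM instead of disk, so DataLoader workers
# spend far less time on I/O.
train_dataset = datasets.ImageFolder(
    '/mnt/ramdisk/imagenet',
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ]),
)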
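And a minimal DALI pipeline sketch, assuming the nvidia-dali package is installed (the API below matches recent DALI releases; the file_root path and batch settings are hypothetical). JPEG decoding, resizing, and normalization all run on the GPU instead of in DataLoader worker processes:

from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch import DALIGenericIterator

@pipeline_def(batch_size=64, num_threads=4, device_id=0)
def train_pipe(data_dir):
    jpegs, labels = fn.readers.file(file_root=data_dir, random_shuffle=True, name="Reader")
    images = fn.decoders.image(jpegs, device="mixed")       # decode JPEGs on the GPU
    images = fn.resize(images, resize_x=224, resize_y=224)  # GPU resize
    images = fn.crop_mirror_normalize(                      # GPU normalize, HWC -> CHW
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )
    return images, labels

pipe = train_pipe(data_dir="/path/to/imagenet/train")
pipe.build()
loader = DALIGenericIterator(pipe, ["data", "label"], reader_name="Reader")
for batch in loader:
    images, labels = batch[0]["data"], batch[0]["label"]
    # training code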

 

How it works (pinned host memory makes cuda(non_blocking=True) an asynchronous copy, and issuing that copy on a separate CUDA stream lets it overlap with the compute of the current iteration): https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789
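The core pattern, as a standalone sketch (the tensor shape is arbitrary):

import torch

copy_stream = torch.cuda.Stream()
cpu_batch = torch.randn(64, 3, 224, 224).pin_memory()  # page-locked host tensor

# The copy is issued on the side stream and returns immediately,
# so it can overlap with kernels running on the default stream.
with torch.cuda.stream(copy_stream):
    gpu_batch = cpu_batch.cuda(non_blocking=True)

# Before touching gpu_batch, make the default stream wait for the copy.
torch.cuda.current_stream().wait_stream(copy_stream)
print(gpu_batch.shape)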
