【Pytorch】nvidia-dali——一种加速数据增强的方法

目的

问题: 当我们使用 PyTorch 训练小模型或者使用较大的 batch size 时,常会发现 GPU 利用率很低、训练周期较长。原因之一是 dataloader 加载数据之后要在 CPU 上做数据增强操作(例如 resize、crop 等),这些操作比较耗时,导致 GPU 很多时候都在等待 CPU 准备数据,造成了严重的资源浪费。
解决: 使用 nvidia-dali 将一些原本在 CPU 上进行的数据预处理操作放到 GPU 上处理,可以极大地提高训练效率。
缺点: DALI 内置的数据读取器(Reader)只支持有限的几种数据格式,例如 ImageNet 数据格式(分类)、COCO 数据集格式(检测);其他格式需要借助 ExternalSource 自定义数据加载(见下文第 2 部分)。

实现

  1. 使用DALI封装的数据加载代码(暂时看不懂,可以先看官方文档的Install、Getting started)
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIterator
 
 
class HybridTrainPipe(Pipeline):
    """DALI training pipeline for an ImageNet-style folder dataset.

    Reads encoded images and labels from ``data_dir`` with ``ops.FileReader``
    (sharded by ``local_rank`` / ``world_size``, shuffled), decodes them, applies
    a random-resized crop and a random horizontal flip, and normalizes to float
    NCHW on the GPU.

    Args:
        batch_size: samples per batch.
        num_threads: CPU worker threads used by DALI.
        device_id: GPU index for the pipeline; also offsets the random seed.
        data_dir: root directory handed to ``ops.FileReader``.
        crop: output size of the random-resized crop.
        dali_cpu: if True, decode and crop on the CPU and copy to the GPU only
            for the final normalization; otherwise decode with the "mixed"
            decoder and crop on the GPU. (Previously accepted but ignored.)
        local_rank: shard index read by this process.
        world_size: total number of shards.
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, dali_cpu=False, local_rank=0, world_size=1):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        if dali_cpu:
            dali_device = "cpu"
            decoder_device = "cpu"
        else:
            dali_device = "gpu"
            decoder_device = "mixed"
        # Remembered so define_graph knows whether an explicit CPU->GPU copy is needed.
        self.dali_device = dali_device
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)
        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        # NOTE(review): an upper random_area of 1.25 differs from the usual (0.08, 1.0)
        # RandomResizedCrop range; kept unchanged here.
        self.res = ops.RandomResizedCrop(device=dali_device, size=crop, random_area=[0.08, 1.25])
        # Normalization always runs on the GPU; mean/std are ImageNet statistics scaled to 0..255.
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            image_type=types.RGB,
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        if self.dali_device == "cpu":
            # CropMirrorNormalize is a GPU op, so CPU-side results must be copied over first.
            images = images.gpu()
        # mirror=rng flips each sample horizontally when its coin flip is 1.
        output = self.cmnp(images, mirror=rng)
        return [output, self.labels]
 
 
class HybridValPipe(Pipeline):
    """DALI validation pipeline: decode, resize the shorter side, center-crop, normalize.

    Reads images in a fixed order (no shuffling) from ``data_dir`` for shard
    ``local_rank`` of ``world_size``. Images are decoded with the "mixed"
    decoder and resized so their shorter side equals ``size``. They are then
    cropped to ``crop x crop`` and normalized to float NCHW, all on the GPU.
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, local_rank=0, world_size=1):
        super().__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        # ImageNet channel statistics, rescaled from 0..1 to the decoder's 0..255 range.
        mean_255 = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std_255 = [0.229 * 255, 0.224 * 255, 0.225 * 255]

        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            random_shuffle=False,
        )
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(
            device="gpu",
            resize_shorter=size,
            interp_type=types.INTERP_TRIANGULAR,
        )
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=mean_255,
            std=std_255,
        )

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        normalized = self.cmnp(self.res(self.decode(self.jpegs)))
        return [normalized, self.labels]
 
 
def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id, num_gpus, crop, val_size=256,
                           world_size=1,
                           local_rank=0):
    """Build a DALI classification iterator over an ImageNet-style directory.

    Args:
        type: ``'train'`` (reads ``image_dir + '/train'`` with HybridTrainPipe)
            or ``'val'`` (reads ``image_dir + '/val'`` with HybridValPipe). The
            name shadows the builtin but is kept for keyword-argument callers.
        image_dir: dataset root containing ``train`` and ``val`` sub-directories.
        batch_size: samples per batch.
        num_threads: CPU threads for DALI.
        device_id: currently unused; the pipeline is placed on GPU ``local_rank``.
        num_gpus: currently unused; sharding is controlled by ``world_size``.
        crop: crop size for both pipelines.
        val_size: shorter-side resize used by the validation pipeline.
        world_size: number of shards; the iterator size is divided by it.
        local_rank: shard index and GPU index for this process.

    Returns:
        A built ``DALIClassificationIterator``.

    Raises:
        ValueError: if ``type`` is not ``'train'`` or ``'val'`` (previously the
            function silently returned ``None``).
    """
    if type not in ('train', 'val'):
        raise ValueError("type must be 'train' or 'val', got {!r}".format(type))

    if type == 'train':
        pipe = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,
                               data_dir=image_dir + '/train',
                               crop=crop, world_size=world_size, local_rank=local_rank)
    else:
        pipe = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,
                             data_dir=image_dir + '/val',
                             crop=crop, size=val_size, world_size=world_size, local_rank=local_rank)
    pipe.build()
    # epoch_size("Reader") counts the whole dataset; each shard sees 1/world_size of it.
    return DALIClassificationIterator(pipe, size=pipe.epoch_size("Reader") // world_size)
if __name__ == '__main__':
    # Benchmark: time one pass over the DALI training iterator.
    # `time` was never imported in this script, so the original crashed with NameError.
    import time

    import torch

    train_loader = get_imagenet_iter_dali(type='train', image_dir='/userhome/memory_data/imagenet', batch_size=256,
                                          num_threads=4, crop=224, device_id=0, num_gpus=1)
    print('start iterate')
    start = time.perf_counter()
    for i, data in enumerate(train_loader):
        images = data[0]["data"].cuda(non_blocking=True)
        labels = data[0]["label"].squeeze().long().cuda(non_blocking=True)
    # CUDA work is asynchronous; wait for it so the measured interval covers all of it.
    torch.cuda.synchronize()
    end = time.perf_counter()
    print('end iterate')
    print('dali iterate time: %fs' % (end - start))
  2. 使用DALI自定义数据加载类代码(参考官方文档的 Tutorials/Data Loading/ExternalSource operator)
from __future__ import division
import torch
import types
import joblib
import collections
import numpy as np
import pandas as pd
from random import shuffle
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.plugin.pytorch as dalitorch
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator
 
 
 
def grid2x2(img):
    """Split an image into its four quadrants.

    Used as the body of ``ops.PythonFunction(num_outputs=4)``. ``img.shape`` is
    unpacked as ``h, w, c``, so the input must be a 3-D array. It is presumably
    HWC as produced by the CPU ``ImageDecoder`` (TODO confirm).

    Returns:
        ``(left_top, right_top, left_bottom, right_bottom)``. For odd heights or
        widths, the bottom/right halves get the extra row/column.
    """
    h, w, _ = img.shape
    mid_h, mid_w = h // 2, w // 2
    left_top = img[:mid_h, :mid_w, :]
    right_top = img[:mid_h, mid_w:, :]
    left_bottom = img[mid_h:, :mid_w, :]
    right_bottom = img[mid_h:, mid_w:, :]
    # Bug fix: the fourth output used to be left_bottom a second time.
    return left_top, right_top, left_bottom, right_bottom
 
 
 
class ExternalInputIterator(object):
    """Iterator that feeds raw encoded images and labels to a DALI ExternalSource.

    Reads a list file in which every non-blank line is ``<relative_path>,<label>``.
    It keeps only the contiguous shard belonging to ``device_id`` out of
    ``num_gpus``, and yields ``(batch, labels)`` tuples:
    ``batch`` is a list of 1-D ``uint8`` arrays holding the raw file bytes, and
    ``labels`` is a list of 1-element ``int32`` arrays.

    Args:
        images_dir: prefix prepended verbatim to each relative path, so it
            should end with a path separator.
        txt_path: path of the list file.
        batch_size: samples per batch.
        device_id: index of this shard.
        num_gpus: total number of shards.
    """

    def __init__(self, images_dir, txt_path, batch_size, device_id, num_gpus):
        self.images_dir = images_dir
        self.batch_size = batch_size
        with open(txt_path, 'r') as f:
            # Strip first, then drop blank lines. The old `line is not ''` was an identity
            # test on the raw line (newline included), so blank lines slipped through and
            # later crashed on split.
            self.files = [entry for entry in (line.rstrip() for line in f) if entry]

        # whole data set size
        self.data_set_len = len(self.files)
        # Based on device_id and the total number of GPUs (world size), keep this
        # process's contiguous shard.
        self.files = self.files[self.data_set_len * device_id // num_gpus:
                                self.data_set_len * (device_id + 1) // num_gpus]
        self.n = len(self.files)

    def __iter__(self):
        self.i = 0
        shuffle(self.files)
        return self

    def __next__(self):
        batch = []
        labels = []

        # NOTE(review): because self.i wraps modulo self.n below, this only triggers when
        # the shard is empty. Otherwise the iterator is endless, and the epoch length is
        # governed by the consumer (e.g. the DALI iterator's `size`).
        if self.i >= self.n:
            raise StopIteration

        for _ in range(self.batch_size):
            # rsplit so that file names containing commas still parse.
            jpeg_filename, label = self.files[self.i].rsplit(',', 1)
            with open(self.images_dir + jpeg_filename, 'rb') as f:
                batch.append(np.frombuffer(f.read(), dtype=np.uint8))
            # int32 rather than uint8: uint8 wraps (or raises on NumPy 2) for labels >= 256.
            labels.append(np.array([int(label)], dtype=np.int32))
            # Wrap around so the last batch is filled with samples from the start.
            self.i = (self.i + 1) % self.n
        return (batch, labels)

    @property
    def size(self,):
        """Length of the whole list file (all shards), not just this shard."""
        return self.data_set_len

    # Python 2 style alias; also used by callers that invoke `.next()` directly.
    next = __next__
 
 
class ExternalSourcePipeline(Pipeline):
    """DALI pipeline fed from a Python iterator through ``ops.ExternalSource``.

    ``iter_setup`` pulls one ``(images, labels)`` batch from ``external_data``
    and feeds it into the graph built by ``define_graph``. The graph decodes on
    the CPU, splits each image into four parts with ``grid2x2`` (a Python
    function, hence CPU-only), and resizes the full image plus each part to
    ``resize x resize`` on the GPU.

    Args:
        resize: output width and height of every resized image.
        batch_size, num_threads, device_id: forwarded to ``Pipeline``.
        external_data: iterable whose iterator yields
            ``(list_of_encoded_images, list_of_labels)``.
    """

    def __init__(self, resize, batch_size, num_threads, device_id, external_data):
        # exec_async/exec_pipelined are disabled because DALI's PythonFunction
        # operator does not support asynchronous/pipelined execution.
        super(ExternalSourcePipeline, self).__init__(batch_size,
                                                     num_threads,
                                                     device_id,
                                                     seed=12,
                                                     exec_async=False,
                                                     exec_pipelined=False,
                                                     )
        self.input = ops.ExternalSource()
        self.input_label = ops.ExternalSource()
        self.decode = ops.ImageDecoder(device="cpu", output_type=types.RGB)

        # Custom Python functions (PythonFunction) can only run on the CPU.
        self.grid = ops.PythonFunction(function=grid2x2, num_outputs=4)
        self.resize = ops.Resize(device="gpu",
                                 resize_x=resize,
                                 resize_y=resize,
                                 interp_type=types.INTERP_LINEAR)
        self.external_data = external_data
        self.iterator = iter(self.external_data)

    def define_graph(self):
        self.jpegs = self.input()
        self.labels = self.input_label()
        images = self.decode(self.jpegs)

        images1, images2, images3, images4 = self.grid(images)
        # .gpu() copies the CPU results to the GPU before the GPU resize.
        images = self.resize(images.gpu())
        images1 = self.resize(images1.gpu())
        images2 = self.resize(images2.gpu())
        images3 = self.resize(images3.gpu())
        images4 = self.resize(images4.gpu())
        return (images, images1, images2, images3, images4, self.labels)

    def iter_setup(self):
        try:
            # The next() builtin works for any iterator; `.next()` only worked because
            # ExternalInputIterator defines a `next` alias.
            images, labels = next(self.iterator)
        except StopIteration:
            # Restart the source for the next epoch, then signal the end of this one.
            self.iterator = iter(self.external_data)
            raise
        self.feed_input(self.jpegs, images)
        self.feed_input(self.labels, labels)
 

def create_dataloder(img_dir,
                     txt_path,
                     resize,
                     batch_size,
                     device_id=0,
                     num_gpus=1,
                     num_threads=6):
    """Build a DALI PyTorch iterator over images listed in a text file.

    Args:
        img_dir: prefix prepended to each path in ``txt_path``.
        txt_path: list file with ``<relative_path>,<label>`` lines.
        resize: output side length for all resized images.
        batch_size: samples per batch.
        device_id: GPU used by the pipeline and shard index of this process.
        num_gpus: total number of shards.
        num_threads: CPU threads for DALI.

    Returns:
        A ``DALIGenericIterator`` whose batches are dicts with keys "data0" (full
        image), "data1".."data4" (the four outputs of ``grid2x2``, in order), and
        "label".
    """
    eii = ExternalInputIterator(img_dir,
                                txt_path,
                                batch_size=batch_size,
                                device_id=device_id,
                                num_gpus=num_gpus)
    pipe = ExternalSourcePipeline(resize=resize,
                                  batch_size=batch_size,
                                  num_threads=num_threads,
                                  # Was hard-coded to 0, so every shard ran on GPU 0.
                                  device_id=device_id,
                                  external_data=eii)

    # NOTE(review): eii.size is the full dataset length, not this shard's length;
    # confirm this matches DALIGenericIterator's `size` semantics for multi-GPU runs.
    pii = PyTorchIterator(pipe,
                          output_map=["data0", "data1", "data2", "data3", "data4", "label"],
                          size=eii.size,
                          last_batch_padded=True,
                          fill_last_batch=False)

    return pii
 
 
if __name__ == '__main__':
    # Demo: iterate the custom ExternalSource loader for a few epochs and print shapes.
    batch_size = 32
    num_gpus = 1
    num_threads = 8
    epochs = 1

    # num_gpus/num_threads were defined but never passed, so the loader silently
    # used its own defaults (e.g. 6 threads instead of 8).
    pii = create_dataloder('img_path',
                           resize=224,
                           batch_size=batch_size,
                           txt_path='file_path',
                           num_gpus=num_gpus,
                           num_threads=num_threads,
                           )

    for e in range(epochs):
        for i, data in enumerate(pii):
            imgs = data[0]["data4"]
            labels = data[0]["label"]
            print("epoch: {}, iter {}".format(e, i), imgs.shape, labels.shape)

        # Rewind the DALI iterator so the next epoch can iterate again.
        pii.reset()

注,分类任务的数据准备最好按照ImageNet的数据格式,检测任务的数据准备参考检测数据格式


参考1:DALI官方文档(记录一下常用)
目录:

  • Installation:安装DALI命令

  • Getting started:入门的简单案例(分类任务)

  • Tutorials
     General:
      1. Data Loading:
       1.1 ExternalSource operator:自定义数据加载操作(ExternalInputIterator、ExternalSourcePipeline);
       1.2 COCO Reader:COCO数据格式读取(检测任务)
      2. DALI expressions and arithmetic operations:tensor上自定义+ - * /操作
      3. Multiple GPU support:多GPU数据加载,每块GPU读取数据的一个分片(shard_id:当前分片/显卡的编号, num_shards:将数据分成几份)
      4. Normalize operator:标准化(归一化)
     Image Processing:一些图片处理上的常用操作(Decoder的CPU/Hybrid)
     Use Cases:一些demo(包括用于分类任务和检测任务)

  • Framework integration:DALI在常用框架(PyTorch、TensorFlow)中的使用

  • Supported operations:DALI中封装的所有函数的使用

参考2:nvidia-dali GPU加速预处理
参考3:pytorch 一种加速dataloder的方法
参考4:NVIDIA/DALI的GitHub仓库

你可能感兴趣的:(Pytorch,gpu,数据增强)