When training small models with PyTorch, you will often find that GPU utilization is low and training is slow. Profiling usually shows that preprocessing is the bottleneck: the GPU spends much of its time waiting for data from the CPU, which is a serious waste. DALI performs preprocessing on the GPU and can therefore improve training efficiency dramatically. Install it with pip (the URL below pins the CUDA 10.0 build; pick the directory that matches your CUDA version):
pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/cuda/10.0 nvidia-dali
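A quick sanity check that the wheel installed correctly (a minimal sketch; the exact version printed will depend on your install):

import nvidia.dali
print(nvidia.dali.__version__)  # should print the installed DALI version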
DALI (NVIDIA Data Loading Library) is a highly optimized execution engine for accelerating the data pipelines of computer vision deep learning applications. Typical deep learning frameworks offer two kinds of preprocessing pipelines:
1. Fast but inflexible: written in C++ and exposed through Python bindings, with preprocessing exported only for a handful of standard datasets.
2. Slow but flexible: written in C++ or Python and able to compose arbitrary input pipelines, at the cost of speed. The biggest problem is Python's global interpreter lock (GIL), which forces developers to fall back on multiprocessing and makes the input pipeline considerably more complex.
DALI supports both modes: developers get a simple, efficient engine, can still define custom processing, and can offload the whole pipeline to the GPU.
It is recommended to organize the dataset in ImageNet layout, i.e. a two-level directory structure: the first level of subdirectories holds one folder per class to recognize, and each class folder in turn stores the images of that class, as in the sketch below.
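A hypothetical layout (class and file names are purely illustrative):

imagenet/
    train/
        n01440764/
            img_0001.JPEG
            img_0002.JPEG
        n01443537/
            ...
    val/
        n01440764/
            ...

The loader further down passes image_dir + '/train' and image_dir + '/val' as file_root, so the class folders must sit directly under train/ and val/.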
The most important class in DALI is Pipeline. It holds all the necessary state plus the functions for defining, building, and running a pipeline, and every pipeline must inherit from it. The two pipelines below, one for training and one for validation, subclass it.
import time

import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
class HybridTrainPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop,
                 dali_cpu=False, local_rank=0, world_size=1):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        dali_device = "gpu"
        # Read (image, label) pairs from the two-level directory tree; shard the
        # file list across workers for distributed training.
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank,
                                    num_shards=world_size, random_shuffle=True)
        # "mixed" decodes JPEGs using both CPU and GPU (nvJPEG).
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.RandomResizedCrop(device="gpu", size=crop, random_area=[0.08, 1.25])
        # Fused crop/mirror/normalize on the GPU; outputs NCHW float tensors
        # normalized with the usual ImageNet mean and std (scaled to [0, 255]).
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            image_type=types.RGB,
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)  # random horizontal flip
        print('DALI "{0}" variant'.format(dali_device))

    def define_graph(self):
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images, mirror=rng)
        return [output, self.labels]
class HybridValPipe(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size,
                 local_rank=0, world_size=1):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        # Validation data is read in a fixed order (no shuffling).
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank,
                                    num_shards=world_size, random_shuffle=False)
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        # Resize the shorter side to `size`, then center-crop to `crop` below.
        self.res = ops.Resize(device="gpu", resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    def define_graph(self):
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]
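Before wrapping a pipeline in an iterator, you can build and run it directly, which is handy for sanity-checking the graph. A minimal sketch (the path and batch size are hypothetical):

pipe = HybridTrainPipe(batch_size=8, num_threads=2, device_id=0,
                       data_dir='/path/to/imagenet/train', crop=224)
pipe.build()
images, labels = pipe.run()  # each output is a DALI TensorList holding one batch
print(len(images), len(labels))  # 8 samples each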
def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id, num_gpus, crop,
                           val_size=256, world_size=1, local_rank=0):
    # In this single-process setup local_rank doubles as the device id;
    # device_id and num_gpus are kept in the signature but not used here.
    if type == 'train':
        pip_train = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,
                                    data_dir=image_dir + '/train',
                                    crop=crop, world_size=world_size, local_rank=local_rank)
        pip_train.build()
        # `size` tells the iterator how many samples make up one epoch (per shard).
        dali_iter_train = DALIClassificationIterator(pip_train, size=pip_train.epoch_size("Reader") // world_size)
        return dali_iter_train
    elif type == 'val':
        pip_val = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=local_rank,
                                data_dir=image_dir + '/val',
                                crop=crop, size=val_size, world_size=world_size, local_rank=local_rank)
        pip_val.build()
        dali_iter_val = DALIClassificationIterator(pip_val, size=pip_val.epoch_size("Reader") // world_size)
        return dali_iter_val
if __name__ == '__main__':
    train_loader = get_imagenet_iter_dali(type='train', image_dir='/userhome/memory_data/imagenet',
                                          batch_size=256, num_threads=4, crop=224, device_id=0, num_gpus=1)
    print('start iterate')
    start = time.time()
    for i, data in enumerate(train_loader):
        # The iterator yields a list with one dict per pipeline.
        images = data[0]["data"].cuda(non_blocking=True)
        labels = data[0]["label"].squeeze().long().cuda(non_blocking=True)
    end = time.time()
    print('end iterate')
    print('dali iterate time: %fs' % (end - start))
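One caveat when using this in a real training loop: unlike a PyTorch DataLoader, a DALI iterator must be reset at the end of every epoch before it can be iterated again. A minimal multi-epoch skeleton (the epoch count and train-step body are placeholders):

epochs = 90  # hypothetical
for epoch in range(epochs):
    for data in train_loader:
        images = data[0]["data"]
        labels = data[0]["label"].squeeze().long().cuda(non_blocking=True)
        # forward / backward / optimizer step goes here
    train_loader.reset()  # rewind the iterator so the next epoch can start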