Improving GPU Utilization with mmlab

Question: for code built on the mmlab framework, what can be changed to improve GPU utilization?

Taking mmpretrain as an example:

1. Replace CPU data processing with GPU versions

Example: override the data_preprocessor.
The code below overrides SelfSupDataPreprocessor from mmpretrain: the image transforms that the config would otherwise run on the CPU (flip, ColorJitter, GaussianBlur, etc.) are reimplemented with torchvision.transforms, which works on GPU tensors and on whole batches.

import torch
from torchvision import transforms as T

from mmpretrain.models.utils import SelfSupDataPreprocessor
from mmpretrain.registry import MODELS


@MODELS.register_module()
class SelfSupDataGPUPreprocessor(SelfSupDataPreprocessor):
    """Image pre-processor for operations, like normalization and bgr to rgb.

    Compared with the :class:`mmengine.ImgDataPreprocessor`, this module
    supports ``inputs`` as torch.Tensor or a list of torch.Tensor.
    """

    def transform(self) -> torch.nn.Sequential:
        """Build the batched, GPU-capable augmentation pipeline."""
        trans = torch.nn.Sequential(
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=(3, 7)),
            T.ConvertImageDtype(torch.float),
        )
        if self._enable_normalize:
            trans.add_module('norm',
                             T.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225)))  # rgb
        return trans

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, BGR-to-RGB conversion and GPU batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation.
                Subclasses can use this flag to apply different preprocessing
                strategies for training and testing.

        Returns:
            dict: Data in the same format as the model input, with keys
            ``'inputs'`` and ``'data_samples'``.
        """
        assert isinstance(data, dict), (
            'Please use default_collate in the dataloader '
            'instead of pseudo_collate.')

        data = [val for _, val in data.items()]
        batch_inputs, batch_data_samples = self.cast_data(data)

        # Here is what is different from :class:`mmengine.ImgDataPreprocessor`
        # Since there are multiple views for an image for some algorithms,
        # e.g. SimCLR, each item in inputs is a list, containing multi-views
        # for an image.
        # Build the augmentation pipeline once and reuse it for every view;
        # random parameters are still re-drawn on each call.
        trans = self.transform()

        if isinstance(batch_inputs, list):
            # channel conversion (BGR -> RGB)
            if self._channel_conversion:
                batch_inputs = [
                    _input[:, [2, 1, 0], ...] for _input in batch_inputs
                ]
            # batched GPU augmentation, applied to each view independently
            batch_inputs = [trans(_input) for _input in batch_inputs]
        else:
            # channel conversion (BGR -> RGB)
            if self._channel_conversion:
                batch_inputs = batch_inputs[:, [2, 1, 0], ...]
            # batched GPU augmentation
            batch_inputs = trans(batch_inputs)

        return {'inputs': batch_inputs, 'data_samples': batch_data_samples}
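
With the class registered, switching to it is a config change. Below is a minimal sketch with illustrative values: the parameter names (bgr_to_rgb, mean, std) follow mmengine.ImgDataPreprocessor and should be checked against your version, and since the class above hardcodes its own 0-1 T.Normalize, mean/std here mainly serve to turn _enable_normalize on. Remember to strip the now-redundant CPU augmentations from train_pipeline, keeping only cheap steps such as loading and packing.

model = dict(
    # backbone / neck / head omitted
    data_preprocessor=dict(
        type='SelfSupDataGPUPreprocessor',
        mean=[123.675, 116.28, 103.53],  # any non-None value enables _enable_normalize
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True))  # sets _channel_conversion used in forward()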

2. train_dataloader settings

Tune the num_workers, pin_memory, persistent_workers and prefetch_factor fields:

train_dataloader = dict(
    batch_size=512,
    num_workers=16,
    pin_memory=True,  # page-locked host memory; speeds up CPU-to-GPU copies
    persistent_workers=True,  # keep workers alive across epochs, avoiding restart cost at each epoch start
    prefetch_factor=32,  # batches preloaded per worker (PyTorch >= 1.7); num_workers * prefetch_factor batches are buffered in total
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        ann_file='xx.txt',
        pipeline=train_pipeline))
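
Whether these fields actually need raising depends on whether the dataloader is the bottleneck. A quick, framework-agnostic sketch to check (the helper name and the 50-batch sample size are arbitrary choices, not mmlab API):

import time

def time_dataloader(loader, num_batches=50):
    """Average wall-clock seconds spent waiting for the dataloader per batch."""
    it = iter(loader)
    next(it)  # warm-up: worker startup and the first prefetch round
    start = time.perf_counter()
    for _ in range(num_batches):
        next(it)
    return (time.perf_counter() - start) / num_batches

If the result is well below the time of one forward/backward step, loading is already fast enough and the remaining gains come from step 1 (GPU-side preprocessing) instead.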

3. env_cfg settings

env_cfg = dict(
    cudnn_benchmark=True,  # let cuDNN auto-tune convolution algorithms
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=32),  # fork is the usual Linux/Ubuntu default and starts workers faster than spawn
    dist_cfg=dict(backend='nccl'))

cudnn_benchmark=True only pays off when the network structure is not changing dynamically and the input size stays fixed. Otherwise cuDNN may rerun its algorithm search repeatedly during training, which slows things down instead of speeding them up.
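
Under the hood this maps to the standard PyTorch switch, shown here only to make the mechanism explicit (outside an mmengine Runner you would set it yourself):

import torch

# Equivalent of env_cfg's cudnn_benchmark=True: cuDNN benchmarks the available
# convolution algorithms for each new input shape and caches the fastest one.
torch.backends.cudnn.benchmark = True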
