PointGroup内存消耗太大问题

目录

  • 1. 前言
  • 2. 解决方法
    • 2.1 PointGroup源代码dataset加载方式
    • 2.2 以速度为代价减少内存消耗
      • 2.2.1 极小化降低内存使用方法
      • 2.2.2 平衡内存消耗和速度方法
  • 3. 总结

1. 前言

PointGroup是CVPR2020关于点云实例分割的一篇文章,在ScanNet Benchmark上以AP50 63.6%的分数登上过榜首。论文的思路很简洁,首先由SparseConv+U-Net得到每个点的语义信息和每个实例的中心点,再通过ScoreNet和聚类分割出每个实例,具体内容请参考PointGroup论文和代码。

近来想在本地训练一下PointGroup,环境是Ubuntu18.04 + pytorch 1.2.0 + 16G RAM + Titan RTX,但是ScanNet数据量比较大,在训练过程中很容易内存不足,然后Ubuntu图形界面冻结,不得不强制充其电脑(如果不想损坏硬件,可以不用强制重启,等待系统把进程kill掉就会恢复正常,时间从几分钟到几个小时不等)。所以就研究了一下PointGroup的代码,作了一些修改,让模型可以在本地训练,在这里把遇到的问题和解决方法记录下来,供遇到同样问题的人参考。

2. 解决方法

2.1 PointGroup源代码dataset加载方式

ScanNet数据的加载和管理主要是在data/scannetv2_inst.py的Dataset类中,成员函数trainLoader()/ valLoader()/ testLoader()用来加载数据并生成DataLoader。由于PointGroup的input定制化程度比较高,源代码中在生成dataloader时候使用了自定义的collate_fn(batch生成函数)trainMerge()/ valMerge()/ testMerge(),同时并进行data augmentation。

class Dataset:
    ...
    def trainLoader(self):
        ...
    def valLoader(self):
        ...
    def testLoader(self):
        ...
    ...
    def trainMerge(self, id):
        ...
    def valMerge(self, id):
        ...
    def testMerge(self, id):
        ...

源代码的数据加载方式是一次性把所有data都加载到内存中,通过collate_fn处理数据生成input data,如果直接用源代码,至少需要32G内存才能进行正常的网络训练。

2.2 以速度为代价减少内存消耗

针对PointGroup的内存问题,主要考虑了两种方式来降低内存消耗。第一种是以训练速度为代价极小化降低内存使用(大约需要4G),第二种是在内存消耗和速度之间作了一个折衷(大约需要14G)。

2.2.1 极小化降低内存使用方法

考虑到训练时,每一个batch只需要batch size大小的数据,修改trainLoader()/ valLoader()/ testLoader(),使其生成dataloader时并不实际加载任何数据,只是生成batch。以训练时的过程为例,把数据加载放到trainMerge()中,在训练过程中调用trainMerge()来进行数据加载data augmentation。

class Dataset:
    ...
    def trainLoader(self):
        self.train_files = sorted(glob.glob(os.path.join(self.data_root,
                                           self.dataset, 
                                           'train', '*' + self.filename_suffix)))

        logger.info('Training samples: {}'.format(len(self.train_files)))

        train_set = list(range(len(self.train_files)))
        self.train_data_loader = DataLoader(train_set, 
                                            batch_size=self.batch_size, 
                                            collate_fn=lambda x: x, 
                                            num_workers=self.train_workers,
                                            shuffle=True, 
                                            sampler=None, 
                                            drop_last=True, 
                                            pin_memory=True)
    
    def valLoader(self):
        ...
    def testLoader(self):
        ...
    ...
    def trainMerge(self, id):
        print("Entering collate_fn. id: ", id)
        locs = []
        locs_float = []
        feats = []
        labels = []
        instance_labels = []

        instance_infos = []  # (N, 9)
        instance_pointnum = []  # (total_nInst), int

        batch_offsets = [0]

        total_inst_num = 0
        for i, idx in enumerate(id):
            xyz_origin, rgb, label, instance_label = torch.load(self.train_files[idx])

            ### jitter / flip x / rotation
            xyz_middle = self.dataAugment(xyz_origin, True, True, True)

            ### scale
            xyz = xyz_middle * self.scale

            ### elastic
            xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
            xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)

            ### offset
            xyz -= xyz.min(0)

            ### crop
            xyz, valid_idxs = self.crop(xyz)

            xyz_middle = xyz_middle[valid_idxs]
            xyz = xyz[valid_idxs]
            rgb = rgb[valid_idxs]
            label = label[valid_idxs]
            instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)

            ### get instance information
            inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32))
            inst_info = inst_infos["instance_info"]  # (n, 9), (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
            inst_pointnum = inst_infos["instance_pointnum"]   # (nInst), list

            instance_label[np.where(instance_label != -100)] += total_inst_num
            total_inst_num += inst_num

            ### merge the scene to the batch
            batch_offsets.append(batch_offsets[-1] + xyz.shape[0])

            locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
            locs_float.append(torch.from_numpy(xyz_middle))
            feats.append(torch.from_numpy(rgb) + torch.randn(3) * 0.1)
            labels.append(torch.from_numpy(label))
            instance_labels.append(torch.from_numpy(instance_label))

            instance_infos.append(torch.from_numpy(inst_info))
            instance_pointnum.extend(inst_pointnum)

        ### merge all the scenes in the batchd
        batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int (B+1)

        locs = torch.cat(locs, 0)                                # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
        locs_float = torch.cat(locs_float, 0).to(torch.float32)  # float (N, 3)
        feats = torch.cat(feats, 0)                              # float (N, C)
        labels = torch.cat(labels, 0).long()                     # long (N)
        instance_labels = torch.cat(instance_labels, 0).long()   # long (N)

        instance_infos = torch.cat(instance_infos, 0).to(torch.float32)       # float (N, 9) (meanxyz, minxyz, maxxyz)
        instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)

        spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None)     # long (3)

        ### voxelize
        voxel_locs, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(locs, self.batch_size, self.mode)

        return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
                'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
                'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
                'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
                
    def valMerge(self, id):
        ...
    def testMerge(self, id):
        ...

这个方法使得训练原本需要的内存从~25GB降低到~4GB,由于dataloader在生成batch时是多线程,因此这种单线程加载方式会大大降低训练速度,在本地测试时,每个batch的训练时间增加大约1倍。代码已上传至GitHub。

2.2.2 平衡内存消耗和速度方法

为了利用dataloader的多线程优势,考虑在dataloader生成batch的时候加载数据,提高数据加载速度。修改代码如下:

class Dataset:
    ...
    def trainLoader(self):
        self.train_files = sorted(glob.glob(os.path.join(self.data_root,
                                           self.dataset, 
                                           'train', '*' + self.filename_suffix)))

        logger.info('Training samples: {}'.format(len(self.train_files)))

        train_set = list(range(len(self.train_files)))
        self.train_data_loader = DataLoader(train_set, 
                                            batch_size=self.batch_size, 
                                            collate_fn=self.trainMerge, 
                                            num_workers=self.train_workers,
                                            shuffle=True, 
                                            sampler=None, 
                                            drop_last=True, 
                                            pin_memory=True)
    
    def valLoader(self):
        ...
    def testLoader(self):
        ...
    ...
    def trainMerge(self, id):
        print("Entering collate_fn. id: ", id)
        locs = []
        locs_float = []
        feats = []
        labels = []
        instance_labels = []

        instance_infos = []  # (N, 9)
        instance_pointnum = []  # (total_nInst), int

        batch_offsets = [0]

        total_inst_num = 0
        for i, idx in enumerate(id):
            xyz_origin, rgb, label, instance_label = torch.load(self.train_files[idx])

            ### jitter / flip x / rotation
            xyz_middle = self.dataAugment(xyz_origin, True, True, True)

            ### scale
            xyz = xyz_middle * self.scale

            ### elastic
            xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
            xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)

            ### offset
            xyz -= xyz.min(0)

            ### crop
            xyz, valid_idxs = self.crop(xyz)

            xyz_middle = xyz_middle[valid_idxs]
            xyz = xyz[valid_idxs]
            rgb = rgb[valid_idxs]
            label = label[valid_idxs]
            instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)

            ### get instance information
            inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32))
            inst_info = inst_infos["instance_info"]  # (n, 9), (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
            inst_pointnum = inst_infos["instance_pointnum"]   # (nInst), list

            instance_label[np.where(instance_label != -100)] += total_inst_num
            total_inst_num += inst_num

            ### merge the scene to the batch
            batch_offsets.append(batch_offsets[-1] + xyz.shape[0])

            locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
            locs_float.append(torch.from_numpy(xyz_middle))
            feats.append(torch.from_numpy(rgb) + torch.randn(3) * 0.1)
            labels.append(torch.from_numpy(label))
            instance_labels.append(torch.from_numpy(instance_label))

            instance_infos.append(torch.from_numpy(inst_info))
            instance_pointnum.extend(inst_pointnum)

        ### merge all the scenes in the batchd
        batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int (B+1)

        locs = torch.cat(locs, 0)                                # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
        locs_float = torch.cat(locs_float, 0).to(torch.float32)  # float (N, 3)
        feats = torch.cat(feats, 0)                              # float (N, C)
        labels = torch.cat(labels, 0).long()                     # long (N)
        instance_labels = torch.cat(instance_labels, 0).long()   # long (N)

        instance_infos = torch.cat(instance_infos, 0).to(torch.float32)       # float (N, 9) (meanxyz, minxyz, maxxyz)
        instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)

        spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None)     # long (3)

        ### voxelize
        voxel_locs, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(locs, self.batch_size, self.mode)

        return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
                'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
                'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
                'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
                
    def valMerge(self, id):
        ...
    def testMerge(self, id):
        ...

此方法训练时所需内存~14GB,但是训练速度基本没有降低。

3. 总结

以上方法讲的比较粗略,需要对pytorch的DataLoader有一定的了解,并且熟悉其参数collate_fn的用法。文中提及的内存数值为通过观察System Monitor中内存变化估算出来的,并不代表准确内存消耗,望注意。
第一个方法的原理是彻底减少data在内存中的驻留,用到什么就加载什么,用性能换取内存。第二种方法中,dataloader加载的batch会缓存在内存中,但是减少了Dataset类所需的内存。
鉴于本人知识水平有限,如果大家发现错误或者有好的改进方法,还请交流指正。

你可能感兴趣的:(PointCloud,点云,分割,pytorch,深度学习,机器学习,神经网络,python)