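PointGroup is a CVPR 2020 paper on point cloud instance segmentation. It once ranked first on the ScanNet Benchmark with an AP50 of 63.6%. The idea of the paper is simple. A SparseConv + U-Net backbone first predicts, for every point, its semantic label and the center of the instance it belongs to. Clustering and ScoreNet then segment out each instance. For details, please refer to the PointGroup paper and code.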
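I recently wanted to train PointGroup locally on Ubuntu 18.04 + PyTorch 1.2.0 + 16 GB RAM + Titan RTX. ScanNet is fairly large, though, and training easily runs out of memory. When that happens, the Ubuntu desktop freezes and I had to force a reboot. (If you don't want to risk damaging the hardware, you can skip the forced reboot. Once the system kills the process, everything returns to normal, which can take anywhere from a few minutes to a few hours.)

So I studied the PointGroup code and made some changes so the model can be trained locally. I record the problems I hit and how I solved them here, for anyone who runs into the same issue.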
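Loading and managing the ScanNet data happens mainly in the Dataset class in data/scannetv2_inst.py. Its member functions trainLoader()/valLoader()/testLoader() load the data and create the DataLoaders. PointGroup's input is highly customized, so the original code passes custom collate_fn functions (batch-building functions) when it creates the DataLoaders: trainMerge()/valMerge()/testMerge(). These functions also perform data augmentation.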
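Here is how the collate_fn mechanism works in practice. In the modified trainLoader() later in this post, the dataset handed to DataLoader is just a list of indices (train_set). In that case, collate_fn receives each batch as a plain Python list of ints and is free to build whatever input it likes from them.

The following is a minimal sketch with hypothetical toy "scenes". merge() is an illustrative name, not PointGroup code.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy "scenes": scene i has i + 1 points with 3 coordinates each.
scenes = [torch.rand(i + 1, 3) for i in range(6)]

def merge(ids):
    # The dataset is a list of ints, so collate_fn receives e.g. [0, 1];
    # turning those indices into model input is entirely up to this function.
    return {"id": ids, "points": torch.cat([scenes[i] for i in ids], 0)}

loader = DataLoader(list(range(len(scenes))), batch_size=2,
                    collate_fn=merge, shuffle=False, drop_last=True)
batches = list(loader)
print(batches[0]["id"], tuple(batches[0]["points"].shape))  # [0, 1] (3, 3)
```

trainMerge() plays the role of merge() here, except that it also augments, crops and voxelizes the scenes.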
```python
class Dataset:
    ...

    def trainLoader(self):
        ...

    def valLoader(self):
        ...

    def testLoader(self):
        ...

    ...

    def trainMerge(self, id):
        ...

    def valMerge(self, id):
        ...

    def testMerge(self, id):
        ...
```
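The original code loads all of the data into memory at once and then builds the network input with collate_fn. With the original code as-is, you need at least 32 GB of RAM to train the network normally.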
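To reduce PointGroup's memory usage, I considered two approaches:

- The first minimizes memory usage at the cost of training speed (about 4 GB needed).
- The second is a compromise between memory consumption and speed (about 14 GB needed).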
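Each training step only needs one batch worth of data. So for the first approach, I modify trainLoader()/valLoader()/testLoader() so that the DataLoader they create does not load any data at all; it only produces batches of indices. Taking training as an example, data loading moves into trainMerge(). The training loop calls trainMerge() to load the data and perform data augmentation.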
```python
# Uses os, glob, numpy as np, torch, DataLoader, logger and pointgroup_ops,
# which data/scannetv2_inst.py already imports.
class Dataset:
    ...

    def trainLoader(self):
        # Only collect the file paths; scenes are loaded later, batch by batch.
        self.train_files = sorted(glob.glob(os.path.join(self.data_root,
                                                         self.dataset,
                                                         'train', '*' + self.filename_suffix)))
        logger.info('Training samples: {}'.format(len(self.train_files)))
        train_set = list(range(len(self.train_files)))
        # Identity collate_fn: each batch is just a list of scene indices.
        # trainMerge() is called from the training loop instead.
        self.train_data_loader = DataLoader(train_set,
                                            batch_size=self.batch_size,
                                            collate_fn=lambda x: x,
                                            num_workers=self.train_workers,
                                            shuffle=True,
                                            sampler=None,
                                            drop_last=True,
                                            pin_memory=True)

    def valLoader(self):
        ...

    def testLoader(self):
        ...

    ...

    def trainMerge(self, id):
        locs = []
        locs_float = []
        feats = []
        labels = []
        instance_labels = []
        instance_infos = []  # (N, 9)
        instance_pointnum = []  # (total_nInst), int
        batch_offsets = [0]
        total_inst_num = 0
        for i, idx in enumerate(id):
            # Load this scene from disk only now instead of keeping every scene in RAM
            xyz_origin, rgb, label, instance_label = torch.load(self.train_files[idx])
            ### jitter / flip x / rotation
            xyz_middle = self.dataAugment(xyz_origin, True, True, True)
            ### scale
            xyz = xyz_middle * self.scale
            ### elastic
            xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
            xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)
            ### offset
            xyz -= xyz.min(0)
            ### crop
            xyz, valid_idxs = self.crop(xyz)
            xyz_middle = xyz_middle[valid_idxs]
            xyz = xyz[valid_idxs]
            rgb = rgb[valid_idxs]
            label = label[valid_idxs]
            instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
            ### get instance information
            inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32))
            inst_info = inst_infos["instance_info"]  # (n, 9), (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
            inst_pointnum = inst_infos["instance_pointnum"]  # (nInst), list
            instance_label[np.where(instance_label != -100)] += total_inst_num
            total_inst_num += inst_num
            ### merge the scene to the batch
            batch_offsets.append(batch_offsets[-1] + xyz.shape[0])
            locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
            locs_float.append(torch.from_numpy(xyz_middle))
            feats.append(torch.from_numpy(rgb) + torch.randn(3) * 0.1)
            labels.append(torch.from_numpy(label))
            instance_labels.append(torch.from_numpy(instance_label))
            instance_infos.append(torch.from_numpy(inst_info))
            instance_pointnum.extend(inst_pointnum)
        ### merge all the scenes in the batch
        batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int (B+1)
        locs = torch.cat(locs, 0)  # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
        locs_float = torch.cat(locs_float, 0).to(torch.float32)  # float (N, 3)
        feats = torch.cat(feats, 0)  # float (N, C)
        labels = torch.cat(labels, 0).long()  # long (N)
        instance_labels = torch.cat(instance_labels, 0).long()  # long (N)
        instance_infos = torch.cat(instance_infos, 0).to(torch.float32)  # float (N, 9) (meanxyz, minxyz, maxxyz)
        instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)
        spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None)  # long (3)
        ### voxelize
        voxel_locs, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(locs, self.batch_size, self.mode)
        return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
                'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
                'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
                'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}

    def valMerge(self, id):
        ...

    def testMerge(self, id):
        ...
```
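The "merge the scene to the batch" part of trainMerge() may be easier to follow on a tiny example. The sketch below shows three things:

- Column 0 of locs stores the scene index.
- batch_offsets marks where each scene starts.
- Instance ids are shifted by total_inst_num so they stay unique across the whole batch.

merge_scenes() is a hypothetical helper. It covers only locs, instance labels and offsets, with no augmentation, cropping or voxelization. It assumes, as after getCroppedInstLabel(), that each scene's instance ids run from 0 upward, with -100 marking ignored points.

```python
import numpy as np
import torch

def merge_scenes(scenes):
    """scenes: list of (xyz (N, 3) int ndarray, instance_label (N,) int ndarray)."""
    locs, inst_labels, offsets = [], [], [0]
    total_inst_num = 0
    for i, (xyz, inst) in enumerate(scenes):
        inst = inst.copy()
        valid = inst != -100
        # Number of instances in this scene (ids assumed to be 0..k-1)
        n_inst = int(inst[valid].max()) + 1 if valid.any() else 0
        # Shift ids so that instances from different scenes never collide
        inst[valid] += total_inst_num
        total_inst_num += n_inst
        offsets.append(offsets[-1] + xyz.shape[0])
        # Prepend the scene index as column 0, like locs[:, 0] in trainMerge()
        batch_col = torch.full((xyz.shape[0], 1), i, dtype=torch.long)
        locs.append(torch.cat([batch_col, torch.from_numpy(xyz).long()], 1))
        inst_labels.append(torch.from_numpy(inst).long())
    return (torch.cat(locs, 0), torch.cat(inst_labels, 0),
            torch.tensor(offsets, dtype=torch.int))

a = (np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0]]), np.array([0, 1, -100]))
b = (np.array([[0, 1, 0], [0, 2, 0]]), np.array([0, 0]))
locs, inst, offsets = merge_scenes([a, b])
print(inst.tolist())        # [0, 1, -100, 2, 2]
print(offsets.tolist())     # [0, 3, 5]
print(locs[:, 0].tolist())  # [0, 0, 0, 1, 1]
```

Scene b's single instance becomes id 2 because scene a already used ids 0 and 1.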
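The first method reduces the memory needed for training from ~25 GB to ~4 GB. The catch is speed. The DataLoader builds batches in multiple worker processes, but with this method the data is loaded in a single process, the training loop itself. Training therefore slows down considerably: in my local tests, the time per batch roughly doubled. The code has been uploaded to GitHub.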
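With the first method, the training script has to call trainMerge() itself for every batch of indices. The post does not show that part of train.py, so here is a minimal self-contained sketch of the pattern, assuming scenes stored one per file.

- load_and_merge() and make_index_loader() are hypothetical stand-ins for trainMerge() and trainLoader().
- The files hold random fake data saved as tensors.

```python
import os
import tempfile
import torch
from torch.utils.data import DataLoader

def load_and_merge(files, ids):
    """Hypothetical stand-in for the lazy trainMerge(): reads only this batch's scenes."""
    xyz, labels = [], []
    for idx in ids:
        scene_xyz, scene_label = torch.load(files[idx])  # one file per scene
        xyz.append(scene_xyz)
        labels.append(scene_label)
    return torch.cat(xyz, 0), torch.cat(labels, 0)

def make_index_loader(n_files, batch_size):
    # Identity collate_fn: a batch is only a list of file indices, so neither
    # the DataLoader nor its workers ever hold point-cloud data.
    return DataLoader(list(range(n_files)), batch_size=batch_size,
                      collate_fn=lambda x: x, shuffle=True, drop_last=True)

seen_ids, shapes = [], []
with tempfile.TemporaryDirectory() as tmp:
    files = []
    for i in range(4):  # fake scenes, not real ScanNet files
        path = os.path.join(tmp, "scene_{}.pth".format(i))
        torch.save((torch.rand(100, 3), torch.zeros(100, dtype=torch.long)), path)
        files.append(path)
    for ids in make_index_loader(len(files), batch_size=2):
        # Loading and merging run here, in the training process, once per step;
        # a real loop would now feed the merged batch to the model.
        xyz, labels = load_and_merge(files, ids)
        seen_ids.extend(ids)
        shapes.append(tuple(xyz.shape))
```

Because all disk reads now happen between optimizer steps in the main process, this is where the slowdown described above comes from.

Two practical notes:

- A lambda collate_fn only works with num_workers > 0 when workers are forked, as on Linux. Platforms that spawn workers need a module-level function instead.
- The post's real scene files hold NumPy arrays, as the torch.from_numpy calls show. Recent PyTorch releases default torch.load to weights_only=True, which rejects NumPy arrays, so those files may need weights_only=False there.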
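The second method takes advantage of the DataLoader's multi-process loading. The data is loaded while the DataLoader builds each batch, which speeds up data loading. Modify the code as follows: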
```python
# Uses os, glob, numpy as np, torch, DataLoader, logger and pointgroup_ops,
# which data/scannetv2_inst.py already imports.
class Dataset:
    ...

    def trainLoader(self):
        # Only collect the file paths; scenes are loaded inside trainMerge().
        self.train_files = sorted(glob.glob(os.path.join(self.data_root,
                                                         self.dataset,
                                                         'train', '*' + self.filename_suffix)))
        logger.info('Training samples: {}'.format(len(self.train_files)))
        train_set = list(range(len(self.train_files)))
        # trainMerge() is the collate_fn again, so loading and augmentation
        # run in the DataLoader worker processes.
        self.train_data_loader = DataLoader(train_set,
                                            batch_size=self.batch_size,
                                            collate_fn=self.trainMerge,
                                            num_workers=self.train_workers,
                                            shuffle=True,
                                            sampler=None,
                                            drop_last=True,
                                            pin_memory=True)

    def valLoader(self):
        ...

    def testLoader(self):
        ...

    ...

    def trainMerge(self, id):
        locs = []
        locs_float = []
        feats = []
        labels = []
        instance_labels = []
        instance_infos = []  # (N, 9)
        instance_pointnum = []  # (total_nInst), int
        batch_offsets = [0]
        total_inst_num = 0
        for i, idx in enumerate(id):
            # Load this scene from disk only now instead of keeping every scene in RAM
            xyz_origin, rgb, label, instance_label = torch.load(self.train_files[idx])
            ### jitter / flip x / rotation
            xyz_middle = self.dataAugment(xyz_origin, True, True, True)
            ### scale
            xyz = xyz_middle * self.scale
            ### elastic
            xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
            xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)
            ### offset
            xyz -= xyz.min(0)
            ### crop
            xyz, valid_idxs = self.crop(xyz)
            xyz_middle = xyz_middle[valid_idxs]
            xyz = xyz[valid_idxs]
            rgb = rgb[valid_idxs]
            label = label[valid_idxs]
            instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
            ### get instance information
            inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32))
            inst_info = inst_infos["instance_info"]  # (n, 9), (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
            inst_pointnum = inst_infos["instance_pointnum"]  # (nInst), list
            instance_label[np.where(instance_label != -100)] += total_inst_num
            total_inst_num += inst_num
            ### merge the scene to the batch
            batch_offsets.append(batch_offsets[-1] + xyz.shape[0])
            locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
            locs_float.append(torch.from_numpy(xyz_middle))
            feats.append(torch.from_numpy(rgb) + torch.randn(3) * 0.1)
            labels.append(torch.from_numpy(label))
            instance_labels.append(torch.from_numpy(instance_label))
            instance_infos.append(torch.from_numpy(inst_info))
            instance_pointnum.extend(inst_pointnum)
        ### merge all the scenes in the batch
        batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int (B+1)
        locs = torch.cat(locs, 0)  # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
        locs_float = torch.cat(locs_float, 0).to(torch.float32)  # float (N, 3)
        feats = torch.cat(feats, 0)  # float (N, C)
        labels = torch.cat(labels, 0).long()  # long (N)
        instance_labels = torch.cat(instance_labels, 0).long()  # long (N)
        instance_infos = torch.cat(instance_infos, 0).to(torch.float32)  # float (N, 9) (meanxyz, minxyz, maxxyz)
        instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int)  # int (total_nInst)
        spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None)  # long (3)
        ### voxelize
        voxel_locs, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(locs, self.batch_size, self.mode)
        return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
                'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
                'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
                'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}

    def valMerge(self, id):
        ...

    def testMerge(self, id):
        ...
```
With this second method, training needs ~14 GB of memory, but the training speed is essentially unchanged.
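The description above is fairly rough and assumes some familiarity with PyTorch's DataLoader, in particular how its collate_fn parameter is used. The memory figures in this post are estimates from watching memory usage in System Monitor, not exact measurements, so please keep that in mind.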
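The first method works by keeping as little data resident in memory as possible: it loads only what is needed, trading performance for memory. In the second method, the batches produced by the DataLoader are buffered in memory, but the memory needed by the Dataset class itself is reduced.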
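A likely source of that buffering is the DataLoader's prefetching. This is my assumption; the post did not measure it. Each worker process builds finished batches ahead of the training loop. In PyTorch 1.7 and later, the depth is set by the prefetch_factor argument (default 2), so up to about num_workers × prefetch_factor merged batches can sit in RAM at once.

As far as I know, PyTorch 1.2.0, used in this post, has no prefetch_factor argument and passing it raises a TypeError. There the depth is fixed at two batches per worker, so num_workers is the only knob.

The sketch below uses hypothetical helper names. make_worker_loader() mirrors the second method's trainLoader(), with `merge` standing in for the lazy trainMerge().

```python
from torch.utils.data import DataLoader

def make_worker_loader(n_files, batch_size, merge, num_workers=2, prefetch_factor=2):
    """Second-method style loader: `merge` runs inside the worker processes."""
    kwargs = {}
    if num_workers > 0:
        # prefetch_factor exists since PyTorch 1.7 and is only valid with workers.
        kwargs["prefetch_factor"] = prefetch_factor
    return DataLoader(list(range(n_files)), batch_size=batch_size, collate_fn=merge,
                      num_workers=num_workers, shuffle=True, drop_last=True, **kwargs)

def buffered_batches(num_workers, prefetch_factor=2):
    # Rough count of batches prepared ahead of the training loop
    # (with 0 workers nothing is prefetched; merge runs in the main process).
    return num_workers * prefetch_factor

print(buffered_batches(4))  # 8
```

Under this assumption, lowering num_workers (or prefetch_factor on newer PyTorch) moves the second method toward the first: less RAM for buffered batches, less parallel loading.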
My knowledge is limited, so if you find mistakes or have better ideas for improvement, please let me know.