PointNet

PointNet介绍

参考博客:https://www.jianshu.com/p/6a0fc51187c1
https://blog.csdn.net/qq_15332903/article/details/80224387

背景

在以往的深度学习处理点云时,往往将其转换成为特定视角下的深度图像或者体素。但是这样在遇到点云数据规模较大的情况时,计算量是极大的。因此论文提出一种新的方案,能够直接对输入点云进行处理。

解决的问题

  1. 问题一:点云的无序性
    虽然输入的点云是有顺序的,但是顺序不应该影响最终的识别/切割结果。
    解决方案:论文提出了三种解决思路,第一种思路是将点云进行一定的顺序进行排序;第二种思路用数据所有的排序进行数据增强然后使用RNN模型;第三种思路是用对称函数来保证排列不变性。
    但是对于第一种方案,当处于高维空间中,很难找到一个比较稳定的排序使得网络能够学的从输入到输出的一致映射;对于第二种方案,当点云规模变大时,使用RNN进行训练不但耗费的资源较大,并且秩序的鲁棒性是较差的。
    因此作者选择第三种方案,精简的使用maxpooling这个对称函数来提取点云数据特征。本文利用了多个MLP来进行函数拟合,最后使用maxpooling(函数g)进行最终拟合表示,保证了模型对特定空间转换的不变性。
    PointNet_第1张图片

  2. 问题二:点云局部和全局信息
    点云的数据不应该是杂乱无序的,点云点与点之间的排序是包含点云信息的,即存在空间关系。
    解决方案:论文提出将局部特征和全局特征进行串联来进行信息的聚合。

  3. 问题三:点云的不变性
    点云数据所代表的的目标对于空间转换应该具备有不变性特征,比如旋转不变性或者平移不变性等。
    解决方案:论文提出在进行特征提取之前,先对点云数据进行对其的方式来保证不变性。因此引入一个小型网络T-Net得到相应的转换矩阵,使它与输入点云数据相乘来实现。

网络架构

PointNet_第2张图片

  1. T-Net:根据论文后面对小型网络T-Net的信息补充可以知道,T-Net由MLP(64,128,1024),然后通过一层maxpooling池化,最后通过两层全连接层(512,256)进行连接。其中,除去最后一层全连接层没有Relu激活和批量标准正交化,其余都有。
    网络在一共做了两次T-Net的转置变化,第一次是在空间上对点云的顺序进行调整,旋转出一个更有利于欧分类或者分割的角度,第二次是对提取处理的64维特征进行对齐,即在特征层面上对点云进行变换。
    特征孔径的变换矩阵维数远高于空间变换矩阵,大大增加优化收敛的难度。因此对特征矩阵进行进一步的约束,将变换矩阵约束为接近正交矩阵:
    正交变换
  2. MLP:MLP的第一层卷积核大小为1x3,之后每一层的卷积核的大小都是1x1。
  3. 局部和全局信息的整合:见图中橙色部分的网络,对应为分割的部分网络图。首先采用之前的nx64的部分网络,并且使卷积池化完成得到的1x1024数据进行扩充成nx1024,并且合并在nx64的数据的尾部,组成新的特征。这样便将局部特征和全局特征进行串联以来聚合信息。

进一步的理论分析

  1. 定理一:
    PointNet_第3张图片
    Hausdorff距离介绍:https://zhuanlan.zhihu.com/p/351921396
    证明了PointNet的网络结构能够拟合任意的连续集合函数。其作用类似证明神经网络能够拟合任意连续函数一样。同时,作者发现PointNet模型的表征能力和maxpooling操作输出的数据维度(K)相关,K值越大,模型的表征能力越强。

  2. 定理二:
    PointNet_第4张图片
    定理(a):说明对于任何输入数据集S,都存在一个最小集Cs和一个最大集Ns,使得对Cs和Ns之间的任何集合T,其网络输出都和S一样。这也就是说,模型对输入数据在有噪声(引入额外的数据点,趋于Ns)和有数据损坏(缺少数据点,趋于Cs)的情况都是鲁棒的。
    定理(b):说明了最小集Cs的数据多少由maxpooling操作输出数据的维度K给出上界。换个角度来讲,PointNet能够总结出表示某类物体形状的关键点,基于这些关键点PointNet能够判别物体的类别。这样的能力决定了PointNet对噪声和数据缺失的鲁棒性。

代码分析

参考开源代码:https://github.com/fxia22/pointnet.pytorch
本次开源代码使用的分类数据集为ModelNet40,分割数据集为ShapeNet。具体数据集的信息参考博客。

数据预处理

  • ShapeNet
class ShapeNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')	#ShapeNet里面各个文件名字存储编号,例如:Airplane	02691156
        self.cat = {}
        self.data_augmentation = data_augmentation
        self.classification = classification
        self.seg_classes = {}
        
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]			#名称与编号文件夹对应
        #print(self.cat)
        if not class_choice is None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        self.id2cat = {v: k for k, v in self.cat.items()}	#反向编号对应名称

        self.meta = {}
        splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
        #from IPython import embed; embed()
        filelist = json.load(open(splitfile, 'r'))
        for item in self.cat:				#文件编号录入
            self.meta[item] = []

        for file in filelist:
            _, category, uuid = file.split('/')
            if category in self.cat.values():				#同名称数据汇总
                self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),
                                        os.path.join(self.root, category, 'points_label', uuid+'.seg')))

        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))	#名称,数据,label

        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        print(self.classes)
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:				#每种模型切割块数量信息获取
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes, self.num_seg_classes)

    def __getitem__(self, index):
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt(fn[2]).astype(np.int64)
        #print(point_set.shape, seg.shape)

        choice = np.random.choice(len(seg), self.npoints, replace=True)
        #resample
        point_set = point_set[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)
        point_set = point_set / dist #scale

        if self.data_augmentation:
            theta = np.random.uniform(0,np.pi*2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
            point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set, cls
        else:
            return point_set, seg

    def __len__(self):
        return len(self.datapath)

ModelNet40

class ModelNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        self.fns = []
        with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:
            for line in f:
                self.fns.append(line.strip())

        self.cat = {}
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = int(ls[1])

        print(self.cat)
        self.classes = list(self.cat.keys())

    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]
        with open(os.path.join(self.root, fn), 'rb') as f:
            plydata = PlyData.read(f)
        pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls


    def __len__(self):
        return len(self.fns)

T-Net (MLP(64,128,1024)+maxpooling+full connect(512,256))

class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x

Cls与Seg前半段共同的网络

class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)	#MLP(64,128,1024)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat

Cls网络

class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat

Seg网络

class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat

你可能感兴趣的:(3D目标检测论文,深度学习,pytorch)