PointNet源码解读

  • 本次源码解读的地址为:https://github.com/yanx27/Pointnet_Pointnet2_pytorch,这一版本的源码易读性高,主要是封装程度较低,注释较全,安装额外的库也比较少。

PointNet源码解读_第1张图片

Pipeline

前向过程:数据加载、数据增强

 for epoch in range(start_epoch, args.epoch):
     log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
     mean_correct = []
     classifier = classifier.train()
     scheduler.step()
     for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader),smoothing=0.9):
         optimizer.zero_grad()
         # b x n x c
         points = points.data.numpy()
         # b x n x c:并不缩减point的数量,而是设定某个阈值,将小于某个阈值的点覆盖为第一个点的信息
         points = provider.random_point_dropout(points)
         # 针对不同的batch进行随机坐标缩放,将某batch所有的pc乘某个缩放系数
         points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
         # 针对不同的batch进行随机坐标缩放,将某batch所有的pc同步进行坐标移动
         points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
         points = torch.Tensor(points)
         points = points.transpose(2, 1) # b,c,n:conv1d是channel first因此要转换
         if not args.use_cpu:
             points, target = points.cuda(), target.cuda()
		# 喂入模型数据
         pred, trans_feat = classifier(points)

model

  • Classification Model和Segmentation Model,区别在于PointNetEncoder的参数global_feature=true | false
  • 这里以Classification Model为例:
class get_model(nn.Module):
    def __init__(self, k=40, normal_channel=True):
        super(get_model, self).__init__()
        if normal_channel:
            channel = 6
        else:
            channel = 3
        self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel)	# classification model
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)	# 这里计算出
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        # 每个点云计算对应的类别损失
        x = F.log_softmax(x, dim=1)	# bs, k
        return x, trans_feat
  • 如果是segmentation model,只需对如上代码做如下修改
    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans_feat

PointNetEncoder

class PointNetEncoder(nn.Module):
    def __init__(self, global_feat=True, feature_transform=False, channel=3):
        super(PointNetEncoder, self).__init__()
        self.stn = STN3d(channel)   # 计算第一个3x3的transformation矩阵
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)	# 计算第二个transformation矩阵

    def forward(self, x):
        B, D, N = x.size()
        trans = self.stn(x) # bs, 3, 3
        x = x.transpose(2, 1)   # bs, n, 3
        if D > 3:
            feature = x[:, :, 3:]
            x = x[:, :, :3]
        x = torch.bmm(x, trans) #
        if D > 3:
            x = torch.cat([x, feature], dim=2)
        x = x.transpose(2, 1)   # bs, c, n
        x = F.relu(self.bn1(self.conv1(x))) # [bs, 3, n]->[bs, 64, n]

        if self.feature_transform:
            trans_feat = self.fstn(x)   # bs, 64, 64
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans_feat)    # bs, n, 64
            x = x.transpose(2, 1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)    # bs, 1024:这里是global feature
        # 如果global_feat=true, 返回的x是global feature
        if self.global_feat:
            return x, trans, trans_feat	# 分别是global feature, 3x3的转换矩阵, transform feature,用于instance-wise的classfication
        # 返回的x为每个点加入global feature之后的特征,用于segmentation
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, N)
            # x:[bs, c, n] + pointfeat:[bs, c, n] -> [bs, c+c, n]
            return torch.cat([x, pointfeat], 1), trans, trans_feat	# 这里是point-wise的classfication

你可能感兴趣的:(点云处理,源码解读,深度学习,pytorch,python)