3D点云基本网络模块(一):Spatial Transformer Networks(STN)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import utils

class STN(nn.Module):
    """Spatial transformer network that predicts one transform per point cloud.

    Per-point features are extracted with three 1x1 convolutions, reduced
    over the points of each scale with max pooling, and regressed through
    fully connected layers to either a ``dim x dim`` matrix or a quaternion
    that is converted into a 3x3 rotation matrix.

    Args:
        num_scales: Number of patches (scales) concatenated along the point
            axis of the input. Each scale is pooled separately.
        num_points: Number of points per scale; also the max-pool window.
        dim: Number of input channels and side length of the predicted
            matrix when ``quaternion`` is False.
        sym_op: Stored on the instance but not used by this class; pooling
            is always max pooling.
        quaternion: If True, regress 4 quaternion components and return the
            rotation matrix produced by ``utils.batch_quat_to_rotmat``.
    """

    def __init__(self, num_scales=1, num_points=500, dim=3, sym_op='max', quaternion=False):
        super(STN, self).__init__()
        self.quaternion = quaternion
        self.dim = dim
        self.sym_op = sym_op
        self.num_scales = num_scales
        self.num_points = num_points

        # Shared per-point MLP implemented as kernel-size-1 convolutions.
        self.conv1 = torch.nn.Conv1d(self.dim, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # Window equals num_points, so one scale collapses to a single column.
        self.mp1 = torch.nn.MaxPool1d(num_points)

        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        if not quaternion:
            self.fc3 = nn.Linear(256, self.dim * self.dim)
        else:
            self.fc3 = nn.Linear(256, 4)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        if self.num_scales > 1:
            # Fuses the concatenated per-scale features back down to 1024.
            self.fc0 = nn.Linear(1024 * self.num_scales, 1024)
            self.bn0 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Predict a transform for every point cloud in the batch.

        Args:
            x: Tensor of shape ``(batch, dim, num_scales * num_points)``;
                the channel count must match ``conv1`` and each scale
                occupies a contiguous run of ``num_points`` columns.

        Returns:
            ``(batch, dim, dim)`` tensor (predicted matrix plus identity)
            when ``quaternion`` is False; otherwise whatever
            ``utils.batch_quat_to_rotmat`` returns for the predicted
            quaternions (presumably ``(batch, 3, 3)`` rotations - confirm
            against that helper).
        """
        batchsize = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # Symmetric operation over all points of each scale.
        if self.num_scales == 1:
            x = self.mp1(x)
        else:
            # Pool each scale on its own and stack along the channel axis.
            # torch.cat keeps the dtype/device of x, unlike a pre-allocated
            # float32 buffer that breaks non-float32 or non-default-GPU use.
            pooled = [
                self.mp1(x[:, :, s * self.num_points:(s + 1) * self.num_points])
                for s in range(self.num_scales)
            ]
            x = torch.cat(pooled, dim=1)

        x = x.view(-1, 1024 * self.num_scales)

        if self.num_scales > 1:
            x = F.relu(self.bn0(self.fc0(x)))

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        if not self.quaternion:
            # Add the identity so a zero prediction leaves the points as-is.
            # Built directly on x's device/dtype instead of via NumPy+.cuda().
            iden = torch.eye(self.dim, dtype=x.dtype, device=x.device)
            iden = iden.view(1, self.dim * self.dim).expand(batchsize, -1)
            x = x + iden
            x = x.view(-1, self.dim, self.dim)
        else:
            # Add identity quaternion (so the network can output 0 to leave
            # the point cloud identical).
            x = x + x.new_tensor([1, 0, 0, 0])

            # Convert quaternion to rotation matrix. `trans` is an
            # uninitialised output buffer handed to the helper; it is always
            # 3x3 regardless of self.dim.
            trans = x.new_empty(batchsize, 3, 3)
            x = utils.batch_quat_to_rotmat(x, trans)
        return x

作用:加入了一个空间变换网络(STN,Spatial Transformer Network),可选用四元数来参数化旋转。网络输入点云,输出位姿变换的参数(dim×dim 的变换矩阵;使用四元数时,为由四元数转换得到的 3×3 旋转矩阵)。通过这个操作,将原始点云变换到一个新的、有利于网络学习的位姿状态,再使用若干层全连接层或一维卷积提取并处理逐点特征;当然,在输出得到的特征时,也可以对逐点的特征向量做类似的逆变换,使其达到一个新的位姿。
来源:https://github.com/mrakotosaon/pointcleannet/blob/master/noise_removal/pcpnet.py

你可能感兴趣的:(python,pytorch,3d,transformer,深度学习)