[3D Point Cloud Recognition] PointNet++: Paper and Code Walkthrough

Notes on understanding the PointNet++ paper and its code

  • What problem does it solve?
  • Contributions
  • Prior work
  • Method
    • Problem definition
    • Method overview
    • Hierarchical Point Set Feature Learning
    • Robust Feature Learning under Non-Uniform Sampling Density
    • Point Feature Propagation for Set Segmentation
  • Code
    • Data loading
    • Utility functions
    • Network architecture
    • Error log
    • Experimental results


What problem does it solve?

  1. How to generate a partitioning of the point set
  2. How to abstract sets of points, i.e. local features, through a local feature learner

Contributions

Neighborhoods at multiple scales are exploited to extract features robustly. With random input dropout during training, the network also learns to adaptively weight patterns detected at different scales and to combine multi-scale features.


Prior work

PointNet cannot capture local context at different scales.


Method

[Figure: PointNet++ architecture overview]

Problem definition

Assume $X = (M, d)$ is a discrete metric space whose metric is inherited from the Euclidean space $\mathbb{R}^n$, where $M \subseteq \mathbb{R}^n$ is the set of points and $d$ is the distance metric. A function $f$ takes $X$ as input and outputs semantic information about $X$.

In effect, $f$ plays the role that the convolution plays in a CNN.
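
For reference, PointNet (the predecessor) approximates such a permutation-invariant set function with a shared per-point network $h$, an elementwise max pool, and a second network $\gamma$:

$$f(x_1, \dots, x_n) \approx \gamma\Big(\max_{i=1,\dots,n} h(x_i)\Big)$$

PointNet++ applies this construction recursively, once per local region at every level.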

Method overview

First, the point set is partitioned into overlapping local regions according to the spatial distance metric. Features are extracted from small neighborhoods and repeatedly aggregated into larger and larger regions, until features of the whole point set have been extracted.

Why do the partitioned regions have to overlap?
Is this like the idea behind CNNs, where a $3 \times 3$ convolution necessarily has overlapping cells covered by the kernel? The overlap here should serve the same purpose.
Answer: yes, exactly; each ball contains multiple points (32).

FPS is used to find the centroids, ball query then determines the points around each centroid, and the features of those points are pooled into a single feature per centroid. After several such stages, interpolation upsamples back to the original resolution.


Hierarchical Point Set Feature Learning

[Figure: hierarchical point set feature learning]
Input and output
The input is $N \times (d + C)$, where $N$ is the number of points, $d$ is the dimension of the coordinate space, and $C$ is the dimension of the point features.
The output is $N' \times (d + C')$, where $N'$ is the number of sampled points in $d$-dimensional coordinates and each $C'$-dimensional feature vector summarizes local context.

Choice of grouping method
In a convolutional neural network, a pixel's local region consists of the pixels within a Manhattan distance (the kernel size) of its array index; here, a point's neighborhood is determined by metric distance. For selecting the points of each group, the authors consider both kNN and ball query. kNN fixes the number of neighbors, but ball query fixes the region size, which makes local region features more generalizable across space and is better for tasks requiring local pattern recognition (e.g. semantic point labeling).
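
A minimal illustration of the difference, with made-up shapes and torch.cdist standing in for the distance function implemented later:

import torch

xyz = torch.randn(1, 100, 3)         # all points
centers = torch.randn(1, 5, 3)       # query centroids
d2 = torch.cdist(centers, xyz) ** 2  # squared pairwise distances, [1, 5, 100]

# kNN: always exactly 16 neighbors, but their spatial extent varies per centroid
knn_idx = d2.topk(16, dim=-1, largest=False).indices  # [1, 5, 16]

# ball query: fixed radius 0.5, so the neighbor count varies per centroid
in_ball = d2 < 0.5 ** 2
print(in_ball.sum(dim=-1))  # a different count for each centroid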

Grouping is done hierarchically, extracting larger and larger local regions level by level. Each level is composed of three layers:

Sampling layer
Selects a subset of points that define the centroids of the local regions.
Given $n$ input points, farthest point sampling picks $m$ of them.
Compared with CNNs, this generates receptive fields in a data-dependent manner.

Grouping layer
Builds local regions by finding the points near each centroid.
The input is a point set of size $N \times (d + C)$ together with the coordinates of the centroids, of size $N' \times d$.

So the points are divided into multiple groups, and $N'$ is the number of groups?
Answer: yes.

The output is groups of point sets, of size $N' \times K \times (d + C)$.
Each group is a local region, and $K$ is the number of points around its centroid.
$K$ differs from group to group, but the PointNet layer converts every group into a fixed-length local region feature vector.

Why does $K$ differ? Is it because farthest point sampling chooses the centroids and then the points around each centroid are gathered within a radius?
Answer: yes.

PointNet layer
Encodes each local region into a feature vector.
The input is $N' \times K \times (d + C)$ and the output is $N' \times (d + C')$.
The coordinates in every input group are normalized relative to the centroid: $x_i^{(j)} = x_i^{(j)} - \hat{x}^{(j)}$ for $i = 1, \dots, K$ and $j = 1, \dots, d$, where $\hat{x}$ is the centroid.
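
In code the centering is a single broadcast subtraction; a minimal sketch with made-up shapes ($B = 2$ batches, $N' = 4$ regions, $K = 8$ points per region):

import torch

group_point = torch.randn(2, 4, 8, 3)  # K points gathered for each region
centroid = torch.randn(2, 4, 3)        # one centroid per region (from FPS)
# subtract each region's centroid from its points; None inserts the K axis
group_point_norm = group_point - centroid[:, :, None, :]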


Robust Feature Learning under Non-Uniform Sampling Density

Features learned on dense data may not apply to sparsely sampled regions. Ideally we would inspect dense regions as closely as possible to capture the finest details, but doing the same in sparse regions is not viable, since sampling there may destroy the local structure.

In other words: capturing more detail in dense regions means extracting features at very small spacing, and in sparse regions this breaks down because there are too few points?

So the paper proposes density-adaptive PointNet layers, which combine features from regions of different sizes when the sampling density varies.
[Figure: density-adaptive grouping]
(a) Multi-scale grouping (MSG); (b) Multi-resolution grouping (MRG)

Two ways of grouping and combining features from different scales:
Multi-scale grouping (MSG):
Grouping layers with different scales are applied to the same centroids and their features are concatenated. To optimize the combination of the multi-scale features, input points are randomly dropped out during training.

Is this so that the multi-scale features complement each other, rather than being naively concatenated?
Answer: more or less.

In the experiments the maximum dropout ratio is set to 0.95; a dropout ratio is sampled uniformly from [0, 0.95] for each training cloud.
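
A sketch of this augmentation, based on my reading of the paper; replacing dropped points with the first point is one common way to keep the tensor shape fixed, not something the paper prescribes:

import torch

def random_input_dropout(points, max_dropout=0.95):
    # points: [N, C]; theta ~ U(0, max_dropout), resampled per training cloud
    theta = torch.rand(1).item() * max_dropout
    drop = torch.rand(points.shape[0]) < theta
    points = points.clone()
    points[drop] = points[0]  # keep the shape fixed by duplicating one point
    return points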

Multi-resolution grouping (MRG)
MSG is computationally expensive; MRG instead exploits the fact that different levels have different numbers of points.
A region's feature is the concatenation of two vectors: one summarizing the features of its sub-regions at the lower level $L_{i-1}$, and one extracted directly from all raw points in the local region.
When the points are sparse, the first vector becomes unreliable and the second vector should be weighted more heavily; when the points are dense, the first vector supplies finer detail.


Point Feature Propagation for Set Segmentation

Propagates the features of the subsampled points back to the original points.

Interesting. Perhaps it is done per group? Or by metric distance, with points within some distance sharing the same feature?

[Figure: feature propagation]
A hierarchical propagation strategy with distance-based interpolation and cross-level skip connections is used: the feature values $f$ of the $N_l$ points are propagated by interpolating them at the coordinates of the $N_{l-1}$ points.
The interpolation is an inverse-distance-weighted average over the $k$ nearest neighbors (by default $p = 2$, $k = 3$):

$$f^{(j)}(x) = \frac{\sum_{i=1}^{k} w_i(x)\, f_i^{(j)}}{\sum_{i=1}^{k} w_i(x)}, \qquad w_i(x) = \frac{1}{d(x, x_i)^p}, \qquad j = 1, \dots, C$$

The interpolated features are then concatenated with skip-connected features from the corresponding set abstraction level and passed through a "unit pointnet" (similar to a $1 \times 1$ convolution), followed by fully connected layers and ReLU; this is repeated until the features have been propagated back to the original set of points.


Code

The code is adapted from: link; I also wrote a simple implementation with some added comments: link.
Only the MSG version is shown here; the vanilla version is omitted and can be found at the link above.

Data loading

# imports shared by all code blocks in this post
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset


class S3DISDataset(Dataset):
    def __init__(self, num_points=4096,
                 split='train',
                 root_dir=None,
                 test=6,
                 block_size=1.0,
                 sample_rate=1.0,
                 transform_control=None,
                 num_class=13
                 ):
        super(S3DISDataset, self).__init__()
        self.rooms_point = []
        self.rooms_label = []
        self.rooms_min_coor = []
        self.rooms_max_coor = []
        self.rooms_index = []
        self.block_size = block_size
        self.num_points = num_points
        self.transform_control = transform_control
        num_points_all = []

        rooms_dir = os.listdir(root_dir)
        # split rooms by area; `test` selects the held-out area
        if split == 'train':
            rooms = [room for room in rooms_dir if 'Area_{}'.format(test) not in room]
        else:
            rooms = [room for room in rooms_dir if 'Area_{}'.format(test) in room]
        # 13 classes in total
        label_num = np.zeros(num_class)

        # load each room: points, labels, bounding box, class histogram
        for i in range(len(rooms)):
            points = np.load(os.path.join(root_dir, rooms[i]))
            self.rooms_point.append(points[:, :6])
            self.rooms_label.append(points[:, -1])
            self.rooms_min_coor.append(np.amin(points[:, :3], axis=0))
            self.rooms_max_coor.append(np.amax(points[:, :3], axis=0))
            # Error 1: 'bins=range(14)' is equivalent to 'bins=13' here
            number_labels, _ = np.histogram(points[:, -1], bins=num_class)
            label_num += number_labels.astype(np.float32)
            # count the total number of points in each room
            num_points_all.append(points.shape[0])

        # per-class label weights (rarer classes get larger weights)
        self.label_weight = label_num / np.sum(label_num)
        # Error 2: np.sqrt(x, 1/3) raises "TypeError: return arrays must be of
        # ArrayType" -- the second argument of np.sqrt is the output array,
        # not an exponent; the cube root needs np.power
        self.label_weight = np.power((np.amax(self.label_weight) / self.label_weight), 1 / 3)

        # fraction of all points contained in each room
        sample_per_room = num_points_all / np.sum(num_points_all).astype(np.float64)
        # total number of blocks that can be sampled across the dataset
        number_group_per_room_sample = (np.sum(num_points_all) * sample_rate) / num_points
        # assign each room a number of blocks proportional to its point count
        rooms_index = []
        for i in range(len(num_points_all)):
            rooms_index.extend([i] * int(round(sample_per_room[i] * number_group_per_room_sample)))
        self.rooms_index = np.array(rooms_index)

        print('Load dataset over! Total rooms: {} , Total points: {}'.format(len(num_points_all), np.sum(num_points_all)))


    def __getitem__(self, index):
        room_index = self.rooms_index[index]
        points = self.rooms_point[room_index]
        labels = self.rooms_label[room_index]
        num_points = points.shape[0]
        # pick a random center and sample a block around it
        while True:
            center = points[np.random.choice(num_points), :3]
            # Error 3 (careless): the block extends block_size/2 in x and y only
            block_points_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0]
            block_points_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0]
            # Error 4: "ufunc 'bitwise_and' not supported for the input types" --
            # each comparison around & must be wrapped in parentheses;
            # np.where on a (n,) array returns a 1-tuple, hence the trailing [0]
            select_points_block_index = np.where((points[:, 0] <= block_points_max[0]) & (points[:, 0] > block_points_min[0])
                                                 & (points[:, 1] <= block_points_max[1]) & (points[:, 1] >= block_points_min[1]))[0]
            # resample if the block is too sparse (why 1024? an empirical threshold)
            if len(select_points_block_index) > 1024:
                break

        # sample num_points points inside the selected block
        if len(select_points_block_index) > self.num_points:
            select_points_index = np.random.choice(select_points_block_index, self.num_points, replace=False)
        else:
            select_points_index = np.random.choice(select_points_block_index, self.num_points, replace=True)

        select_points = points[select_points_index, :]
        select_labels = labels[select_points_index]

        # 9 features per point: centered XYZ, RGB in [0, 1], XYZ normalized by room extent
        feature_points = np.zeros([select_points.shape[0], 9])
        # Error 5: self.rooms_max_coor is stored per room, so index it with room_index
        feature_points[:, 6:9] = select_points[:, :3] / self.rooms_max_coor[room_index]
        feature_points[:, 3:6] = select_points[:, -3:] / 255.0
        feature_points[:, :3] = select_points[:, :3] - center

        if self.transform_control:
            feature_points, select_labels = self.transform(feature_points, select_labels)

        return feature_points, select_labels



    def transform(self, points, labels):
        return points, labels


    def __len__(self):
        return len(self.rooms_index)
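
A hypothetical usage sketch; the directory name and file layout are assumptions (preprocessed .npy room files, each an array of [x, y, z, r, g, b, label] rows):

dataset = S3DISDataset(num_points=4096, split='train', root_dir='data/s3dis_npy')
points, labels = dataset[0]
print(points.shape, labels.shape)  # (4096, 9) (4096,)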

Utility functions

The code is not complicated; first, some utility functions:
FPS (farthest point sampling)

def farthest_point_sample(xyz, num_sample):
    device = xyz.device
    B, N, _ = xyz.size()
    # pick the first point at random
    # Error 6: shape (B,) differs from (B, 1) -- only a (B,) tensor can be
    # assigned into centroid_index[:, i]
    first_point = torch.randint(0, N, (B, ), dtype=torch.long).to(device)
    farthest_index = first_point
    # indices of the selected centroids
    centroid_index = torch.zeros(B, num_sample, dtype=torch.long).to(device)
    # running minimum distance from each point to the already-selected set
    dist_all = torch.ones(B, N).to(device) * 1e10
    # indexing with this pairs each batch element with its own farthest_index
    # (one-to-one) instead of every batch with every index -- be careful here
    batch_list = torch.arange(B, dtype=torch.long).to(device)
    for i in range(num_sample):
        centroid_index[:, i] = farthest_index
        centroid_point = xyz[batch_list, farthest_index, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid_point) ** 2, dim=-1)
        # Error 7: the comparison needs float32 inputs
        mask = dist < dist_all
        dist_all[mask] = dist[mask]
        farthest_index = torch.argmax(dist_all, dim=-1)

    return centroid_index
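
A quick shape check on random data:

xyz = torch.randn(2, 1024, 3)
idx = farthest_point_sample(xyz, 128)
print(idx.shape)  # torch.Size([2, 128]) -- indices of 128 centroids per batch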

Ball query

def ball_group(xyz, center_point, radius, n_sample):
    B, N, _ = xyz.size()
    _, S, _ = center_point.shape
    matrix_dist = get_distance(xyz, center_point)
    matrix_dist_sorted, matrix_dist_sorted_idx = torch.sort(matrix_dist, dim=-1, descending=False)
    # keep the n_sample nearest candidates per centroid
    group = matrix_dist_sorted[:, :, :n_sample]
    group_idx = matrix_dist_sorted_idx[:, :, :n_sample]
    # if there are not enough points within the radius, pad with the nearest
    # point, i.e. the first value after sorting
    # (get_distance returns squared distances, so compare against radius ** 2)
    mask = group > radius ** 2
    # Error 8: the shapes don't broadcast as-is; expand the first column to
    # [B, S, n_sample] before the masked assignment
    group_idx[mask] = group_idx[:, :, 0].view(B, S, 1).repeat(1, 1, n_sample)[mask]

    return group_idx

Ball query needs distances between points. Unlike in FPS, every centroid must be compared against every point, so a matrix formulation is best; it uses the expansion $\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2\,a \cdot b$:

def get_distance(xyz, center_point):
    # note: the output dims N and M are transposed relative to the inputs
    B, M, _ = xyz.size()
    _, N, _ = center_point.size()
    # build a B x N x M matrix of squared distances, where N is the number of
    # centroids and M is the number of points
    sum_2_xyz = torch.sum(xyz ** 2, dim=-1).view(B, -1, M)
    sum_2_center = torch.sum(center_point ** 2, dim=-1).view(B, N, -1)
    sum_each = torch.bmm(center_point, xyz.permute(0, 2, 1))
    matrix_dist = sum_2_xyz + sum_2_center - 2 * sum_each

    return matrix_dist
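
A sanity check of the expansion against torch.cdist (with a float tolerance, since the expansion is less numerically stable):

xyz = torch.randn(2, 5, 3)
centers = torch.randn(2, 7, 3)
ref = torch.cdist(centers, xyz) ** 2        # [2, 7, 5]
out = get_distance(xyz, centers)            # [2, 7, 5]
print(torch.allclose(ref, out, atol=1e-4))  # True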

Both FPS and ball query return indices; turning indices back into values is done by the function below, which also works when the index tensor has extra dimensions:

def idx2point(xyz, idx):
    device = xyz.device
    B, N, _ = xyz.size()
    # idea: build a batch-index tensor with the same shape as idx, so batch
    # and point indices pair up one-to-one; extra dimensions in idx are
    # handled by the repeat below
    # Error 9: "'torch.Size' object does not support item assignment" --
    # idx.shape is a torch.Size, so convert it to a list before editing
    batch_view = list(idx.shape)
    batch_view[1:] = [1] * (len(batch_view) - 1)
    repeat_times = list(idx.shape)
    repeat_times[0] = 1
    batch = torch.arange(B, dtype=torch.long).to(device).view(batch_view).repeat(repeat_times)
    point = xyz[batch, idx, :]

    return point
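
For example, gathering 3-D coordinates with a [B, S, K] index tensor:

xyz = torch.randn(2, 100, 3)
idx = torch.randint(0, 100, (2, 16, 32))
print(idx2point(xyz, idx).shape)  # torch.Size([2, 16, 32, 3])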

The function below combines the pieces above to perform the sampling and grouping step:

def sample_and_group(xyz, feature, n_group, radius, n_sample, return_fps=False):
    '''
    :param xyz: [B, num of point, 3]
    :param feature: [B, num of point, channel]
    :param n_group: number of centroids sampled by FPS
    :param radius: ball query radius
    :param n_sample: number of points per ball
    :param return_fps: also return the grouped points and FPS indices
    :return:
    '''
    # FPS finds the centroids
    fps_idx = farthest_point_sample(xyz, n_group)
    fps_point = idx2point(xyz, fps_idx)
    # ball query finds the points around each centroid
    group_idx = ball_group(xyz, fps_point, radius, n_sample)
    # Error 10 (careless): gather with group_idx here, not fps_idx
    group_point = idx2point(xyz, group_idx)
    # Error 11: the shapes don't broadcast as-is; None inserts the missing axis
    group_point_norm = group_point - fps_point[:, :, None, :]

    if feature is not None:
        group_point_feature = idx2point(feature, group_idx)
        new_feature = torch.cat((group_point_norm, group_point_feature), dim=-1)
    else:
        new_feature = group_point_norm

    if return_fps:
        return fps_point, new_feature, group_point, fps_idx
    else:
        return fps_point, new_feature

That covers the basic functions; next comes the network architecture.


Network architecture

This is the MSG variant, written bottom-up. The first module extracts features per group, corresponding to the left half of the architecture figure.
SetAbstract

class SetAbstractMSG(nn.Module):
    def __init__(self, n_group, radius, n_sample, in_channel, mlp_list, group_all=False):
        super(SetAbstractMSG, self).__init__()
        self.n_group = n_group
        self.n_sample = n_sample
        self.radius = radius
        self.group = group_all
        self.bn_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        # Error 12: last_channel must be reset inside the outer loop --
        # each scale's MLP starts again from in_channel + 3
        for i in range(len(mlp_list)):
            conv = nn.ModuleList()
            bn = nn.ModuleList()
            last_channel = in_channel + 3
            for j in range(len(mlp_list[i])):
                conv.append(nn.Conv2d(last_channel, mlp_list[i][j], 1))
                bn.append(nn.BatchNorm2d(mlp_list[i][j]))
                last_channel = mlp_list[i][j]
            self.conv_list.append(conv)
            self.bn_list.append(bn)

    def forward(self, xyz, point_feature):
        '''
        :param xyz: [B, num of point, 3]
        :param point_feature: [B, num of point, channel]
        '''
        fps_index = farthest_point_sample(xyz, self.n_group)
        fps_point = idx2point(xyz, fps_index)  # [B, num_group, 3]
        new_feature_list = []
        # one grouping + PointNet pass per scale (radius)
        for i in range(len(self.radius)):
            ball_sample_num = self.n_sample[i]
            ball_sample_index = ball_group(xyz, fps_point, self.radius[i], ball_sample_num)
            ball_sample_point = idx2point(xyz, ball_sample_index)  # [B, group_num, sample_num, 3]
            ball_sample_point_norm = ball_sample_point - fps_point[:, :, None, :]
            ball_sample_feature = idx2point(point_feature, ball_sample_index)
            new_feature = torch.cat((ball_sample_feature, ball_sample_point_norm), dim=-1)  # [B, group, sample, C]
            new_feature = new_feature.permute(0, 3, 1, 2)  # [B, C, group, sample]
            for j in range(len(self.conv_list[i])):
                new_feature = self.conv_list[i][j](new_feature)
                new_feature = self.bn_list[i][j](new_feature)
                new_feature = F.relu(new_feature)
            new_feature = torch.max(new_feature, dim=-1)[0]  # [B, C, group]
            new_feature = new_feature.permute(0, 2, 1)  # [B, group, C]
            new_feature_list.append(new_feature)

        # concatenate the features from all scales along the channel dim
        new_feature_list = torch.cat(new_feature_list, dim=-1)

        return fps_point, new_feature_list
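
A smoke test with made-up sizes; with two scales the output feature width is 32 + 64 = 96, and in_channel must match the width of point_feature (the grouped features gain 3 extra coordinate channels internally):

sa = SetAbstractMSG(n_group=512, radius=[0.1, 0.2], n_sample=[16, 32],
                    in_channel=9, mlp_list=[[16, 16, 32], [32, 32, 64]])
xyz = torch.randn(2, 1024, 3)
feat = torch.randn(2, 1024, 9)
fps_point, feature = sa(xyz, feat)
print(fps_point.shape, feature.shape)  # [2, 512, 3] [2, 512, 96]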


Next is the feature propagation by interpolation, which the paper describes quite clearly.
Feature Propagation

class FeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp_list):
        super(FeaturePropagation, self).__init__()
        self.in_channel = in_channel
        self.conv_list = nn.ModuleList()
        self.bn_list = nn.ModuleList()
        for i in range(len(mlp_list)):
            self.conv_list.append(nn.Conv1d(in_channel, mlp_list[i], 1))
            self.bn_list.append(nn.BatchNorm1d(mlp_list[i]))
            in_channel = mlp_list[i]

    def forward(self, inter_to_xyz, inter_from_xyz, inter_to_feature, inter_from_feature):
        '''
        :param inter_to_xyz: coordinates interpolated to (the denser level)
        :param inter_from_xyz: coordinates interpolated from (the sparser level)
        :param inter_to_feature: features at the denser level (skip connection)
        :param inter_from_feature: features at the sparser level
        :return:
        '''
        B, N, _ = inter_to_xyz.size()
        _, S, _ = inter_from_xyz.size()

        if S == 1:
            inter_get = inter_from_feature.repeat(1, N, 1)
        else:
            matrix_dist = get_distance(inter_from_xyz, inter_to_xyz)  # [B, N, S]
            dist_sort, dist_sort_idx = torch.sort(matrix_dist, dim=-1)
            dist_sort = dist_sort[:, :, :3]  # [B, N, 3]
            dist_sort_idx = dist_sort_idx[:, :, :3]  # [B, N, 3]
            weight = 1.0 / (dist_sort + 1e-8)
            # Error 13: keepdim=True keeps the summed dim so the division broadcasts
            weight_norm = weight / torch.sum(weight, dim=-1, keepdim=True)  # [B, N, 3]
            feature_get = idx2point(inter_from_feature, dist_sort_idx)  # [B, N, 3, C]
            inter_get = torch.sum(feature_get * weight_norm[:, :, :, None], dim=2)  # [B, N, C]

        if inter_to_feature is not None:
            new_feature = torch.cat((inter_to_feature, inter_get), dim=-1)
        else:
            new_feature = inter_get

        new_feature = new_feature.permute(0, 2, 1)  # [B, N, C] -> [B, C, N]

        for i in range(len(self.conv_list)):
            new_feature = self.conv_list[i](new_feature)
            new_feature = self.bn_list[i](new_feature)
            # Error 14: features became nan after relu -- fixed by adding the
            # 1e-8 to the divisor in the weight computation above
            new_feature = F.relu(new_feature)
        new_feature = new_feature.permute(0, 2, 1)  # [B, C, N] -> [B, N, C]

        return new_feature
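
A smoke test; note that in_channel must equal the skip-connection width plus the interpolated feature width (here 32 + 64):

fp = FeaturePropagation(in_channel=32 + 64, mlp_list=[64, 64])
to_xyz = torch.randn(2, 256, 3)    # denser level (interpolated to)
from_xyz = torch.randn(2, 64, 3)   # sparser level (interpolated from)
to_feat = torch.randn(2, 256, 32)
from_feat = torch.randn(2, 64, 64)
print(fp(to_xyz, from_xyz, to_feat, from_feat).shape)  # [2, 256, 64]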

Finally, everything is assembled into the full model, which outputs the per-point predictions:

class PointNet2Sem_seg_MSG(nn.Module):
    def __init__(self, num_class):
        super(PointNet2Sem_seg_MSG, self).__init__()
        self.abstract_1 = SetAbstractMSG(1024, [0.05, 0.1], [16, 32], 9, [[16, 16, 32], [32, 32, 64]])
        self.abstract_2 = SetAbstractMSG(256, [0.1, 0.2], [16, 32], 32+64, [[64, 64, 128], [64, 96, 128]])
        self.abstract_3 = SetAbstractMSG(64, [0.2, 0.4], [16, 32], 128+128, [[128, 196, 256], [128, 196, 256]])
        self.abstract_4 = SetAbstractMSG(16, [0.4, 0.8], [16, 32], 256+256, [[256, 256, 512], [256, 384, 512]])
        self.propaga_1 = FeaturePropagation(512+512+256+256, [256, 256])
        self.propaga_2 = FeaturePropagation(128+128+256, [256, 256])
        self.propaga_3 = FeaturePropagation(32+64+256, [256, 128])
        self.propaga_4 = FeaturePropagation(128, [128, 128, 128])

        self.conv_1 = nn.Conv1d(128, 128, 1)
        self.bn_1 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(p=0.5)
        self.conv_2 = nn.Conv1d(128, num_class, 1)

    def forward(self, x):
        # x: [B, N, C] with C = 9 input features
        xyz = x[:, :, :3]
        l_point_1, l_feature_1 = self.abstract_1(xyz, x)
        l_point_2, l_feature_2 = self.abstract_2(l_point_1, l_feature_1)
        l_point_3, l_feature_3 = self.abstract_3(l_point_2, l_feature_2)
        l_point_4, l_feature_4 = self.abstract_4(l_point_3, l_feature_3)

        inter_feat_1 = self.propaga_1(l_point_3, l_point_4, l_feature_3, l_feature_4)
        inter_feat_2 = self.propaga_2(l_point_2, l_point_3, l_feature_2, inter_feat_1)
        inter_feat_3 = self.propaga_3(l_point_1, l_point_2, l_feature_1, inter_feat_2)
        inter_feat_4 = self.propaga_4(xyz, l_point_1, None, inter_feat_3)  # [B, N, C]

        new_feature = inter_feat_4.permute(0, 2, 1)
        new_feature = self.conv_1(new_feature)
        new_feature = self.bn_1(new_feature)
        # Error 15: this relu was missing at first
        new_feature = F.relu(new_feature)
        new_feature = self.dropout(new_feature)
        # Error 16: conv_2 outputs [B, num_class, N], so log_softmax is over dim=1
        new_feature = self.conv_2(new_feature)
        result_seg = F.log_softmax(new_feature, dim=1)
        result_seg = result_seg.permute(0, 2, 1)  # [B, N, num_class]

        return result_seg
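
And a forward-pass shape check of the whole network on random input:

model = PointNet2Sem_seg_MSG(num_class=13)
x = torch.randn(2, 4096, 9)  # [B, N, 9] features produced by the dataset
print(model(x).shape)        # [2, 4096, 13] log-probabilities per point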


Error log

The numbers match the "Error N" comments in the code above (some errors are only explained inline there); Errors 17-20 came from the training script, which is not shown.

Error 1: 'bins=range(14)' is equivalent to 'bins=13'

In np.histogram, range(14) and bins=13 give the same result here: 13 bins have exactly 14 edges. The function returns the bin counts first and the bin edges second.
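
A quick check of the equivalence; it holds for integer labels as long as they span the full 0-12 range (bins=13 spaces its edges over the data's min and max):

labels = np.array([0, 0, 1, 5, 12, 12])
h1, _ = np.histogram(labels, bins=13)
h2, _ = np.histogram(labels, bins=range(14))
print((h1 == h2).all())  # True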

Error 2: TypeError: return arrays must be of ArrayType

np.sqrt(x, 1/3) fails because the second positional argument of np.sqrt is the output array, not an exponent; the cube root needs np.power(x, 1/3).

Error 4: ufunc 'bitwise_and' not supported for the input types

Both sides of & must be wrapped in parentheses, because & binds more tightly than the comparison operators.

Error 6: size (B, 1) is different from (B,)

When randomly initializing the first point, a tensor of shape (B,) can be assigned into [:, i], whereas one of shape (B, 1) cannot.
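
For instance:

centroid_index = torch.zeros(4, 10, dtype=torch.long)
farthest = torch.randint(0, 100, (4,))
centroid_index[:, 0] = farthest  # OK: a (4,) tensor fills the column
# torch.randint(0, 100, (4, 1)) in the same assignment raises a shape error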

Error 7: need float32

Comparing tensors to obtain a True/False mask did not work with float64 data; convert to float32 by calling data.float() on the input at the start.

Error 9: 'torch.Size' object does not support item assignment

After b = idx.shape, b cannot be assigned to; convert it with a = list(b) first, then modify the list.

Error 11: shapes can't broadcast

Broadcasting requires the dimensions to correspond; insert a singleton dimension with view() or None indexing on the tensor that needs to broadcast.

Error 13: use keepdim to support broadcasting

torch.sum(..., keepdim=True) keeps the reduced dimension with size 1, which makes the result broadcast cleanly in subsequent computations.
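
For example, normalizing interpolation weights along the last dimension:

w = torch.rand(2, 5, 3)
w_norm = w / w.sum(dim=-1, keepdim=True)  # [2, 5, 3] / [2, 5, 1] broadcasts
print(w_norm.sum(dim=-1))                 # all ones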

Error 14: nan after relu

Add 1e-8 to the divisor when dividing, otherwise the parameters can become nan.

The remaining errors came up in the training script:

Error 17: RuntimeError: view size is not compatible with input tensor's size and stride

Call data.contiguous() to make the tensor contiguous before view().

Error 18: TypeError: nll_loss(): argument 'weight' (position 3) must be Tensor, not numpy.ndarray

The weight argument of F.nll_loss() must be a tensor as well.

Error 19: RuntimeError: multi-target not supported

The target of F.nll_loss() must be 1-D; use target[:, 0] to convert it.

Error 20: ZeroDivisionError: integer division or modulo by zero

Zero cannot be used as a divisor; guard against it.


Experimental results

I did not run many experiments; I just trained the vanilla version and the MSG version for 40 epochs each and looked at the test-set performance.

        vanilla   MSG
score   0.791     0.796

Training accuracy was 0.95-0.99, so it drops quite a bit on the test set; insufficient training time is certainly part of the reason.
