PointNet系列代码复现详解(2)—PointNet_part_seg

目录

train_partseg.py


这次是PointNet_part_seg的代码,与PointNet分类代码一样的部分就不在提及了。

 

train_partseg.py

映射关系改变

seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat

 将神经网络模型中的 ReLU 激活函数设置为原地(inplace)操作

原地操作表示将会进行原地操作(即直接在原有的内存空间上修改数据),从而减少内存的使用,提高模型的运行效率。

def inplace_relu(m):
    classname = m.__class__.__name__
    if classname.find('ReLU') != -1:
        m.inplace = True

将整数标签转换为独热编码(one-hot encoding) 

def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
    if (y.is_cuda):
        return new_y.cuda()
    return new_y

命令行参数设置,也是整个网络运行的基本设置,具体含义见后。 与分类代码基本一样

def parse_args():
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--model', type=str, default='pointnet_part_seg', help='model name')
    parser.add_argument('--batch_size', type=int, default=16, help='batch Size during training')
    parser.add_argument('--epoch', default=251, type=int, help='epoch to run')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--gpu', type=str, default='0', help='specify GPU devices')
    parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD')
    parser.add_argument('--log_dir', type=str, default=None, help='log path')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--npoint', type=int, default=2048, help='point Number')
    parser.add_argument('--normal', action='store_true', default=False, help='use normals')
    parser.add_argument('--step_size', type=int, default=20, help='decay step for lr decay')
    parser.add_argument('--lr_decay', type=float, default=0.5, help='decay rate for lr decay')

    return parser.parse_args()
  • --model:模型名称,类型为字符串,默认值为 'pointnet_part_seg',用于指定使用哪个模型进行训练。
  • --batch_size:批次大小,类型为整数,默认值为 16,表示在训练过程中每次传入模型的数据样本数。
  • --epoch:训练轮数,类型为整数,默认值为 251,表示训练过程中总共需要迭代多少次。
  • --learning_rate:初始学习率,类型为浮点数,默认值为 0.001,表示在训练过程中初始的学习率。
  • --gpu:指定 GPU 设备,类型为字符串,默认值为 '0',表示在训练过程中使用哪一块 GPU 设备进行计算,可以同时指定多块 GPU 设备,例如 '0,1,2'
  • --optimizer:优化器类型,类型为字符串,默认值为 'Adam',表示在训练过程中使用哪种优化器进行模型参数的更新,可以选择 'Adam''SGD'
  • --log_dir:日志路径,类型为字符串,默认值为 None,表示训练过程中保存日志文件的路径。
  • --decay_rate:权重衰减,类型为浮点数,默认值为 1e-4,表示在训练过程中的权重衰减率。
  • --npoint:点云采样数,类型为整数,默认值为 2048,表示在训练过程中每个点云模型需要采样多少个点。
  • --normal:是否使用法向量,类型为布尔值,默认值为 False,表示在训练过程中是否使用点云模型的法向量信息。
  • --step_size:学习率衰减步长,类型为整数,默认值为 20,表示在训练过程中学习率需要下降的迭代步数。
  • --lr_decay:学习率衰减倍数,类型为浮点数,默认值为 0.5,表示在每个学习率衰减步长结束后学习率需要下降的倍数。

main(args)一开始与分类训练代码中一样,都是一些基本参数的设置,训练信息保存的路径,训练数据的加载,模型读取。 

这段代码定义了一个名为 weights_init 的函数,用于对模型参数进行初始化。具体来说,当模型的某个子模块是 Conv2dLinear 类型时,将其权重矩阵使用 Xavier 初始化方法进行初始化,将其偏置向量全部初始化为 0。

而Xavier 初始化方法是一种常用的权重初始化方法,它可以使模型在训练过程中更快地收敛,并提高模型的泛化性能。在这里,使用该函数对模型参数进行初始化,可以提高模型的训练效果。

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        torch.nn.init.xavier_normal_(m.weight.data)
        torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_normal_(m.weight.data)
        torch.nn.init.constant_(m.bias.data, 0.0)

之后就是查看是否有预训练模型,然后设置优化器参数。

这段代码定义了一个名为 bn_momentum_adjust 的函数,用于调整 Batch Normalization 层的动量参数。具体来说,当模型的某个子模块是 BatchNorm2dBatchNorm1d 类型时,将其动量参数设置为函数输入的 momentum 值。

Batch Normalization 层是一种常用的神经网络层,它可以加速模型的训练过程,并提高模型的泛化性能。其中,动量参数 momentum 用于平滑 Batch Normalization 层中均值和方差的计算过程。在这里,使用该函数对模型中的 Batch Normalization 层的动量参数进行设置,可以进一步优化模型的训练效果。

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

学习率裁剪的阈值 LEARNING_RATE_CLIP,原始动量值 MOMENTUM_ORIGINAL,动量衰减的比例 MOMENTUM_DECCAY,以及动量衰减的步数 MOMENTUM_DECCAY_STEP

具体来说,LEARNING_RATE_CLIP 用于限制学习率的最小值,避免学习率过小导致收敛缓慢。MOMENTUM_ORIGINAL 表示模型中 Batch Normalization 层的原始动量值,MOMENTUM_DECCAY 表示动量衰减的比例,用于控制动量在每个训练阶段的变化程度。MOMENTUM_DECCAY_STEP 则表示动量衰减的步数,即每经过 MOMENTUM_DECCAY_STEP 个训练 epoch,动量的值就乘上 MOMENTUM_DECCAY

LEARNING_RATE_CLIP = 1e-5
MOMENTUM_ORIGINAL = 0.1
MOMENTUM_DECCAY = 0.5
MOMENTUM_DECCAY_STEP = args.step_size
# 最佳准确率 best_acc,全局训练轮次 global_epoch,
# 最佳类别平均交并比 best_class_avg_iou,以及最佳实例平均交并比 best_instance_avg_iou。
    best_acc = 0
    global_epoch = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0

这段代码实现了学习率衰减和动量衰减的功能,并对模型中的 Batch Normalization 层的动量参数进行调整。

具体来说,代码首先计算当前的学习率 lr,该学习率通过将初始学习率 args.learning_rate 乘以 args.lr_decayepoch // args.step_size 次方来计算,同时应用了学习率裁剪的阈值 LEARNING_RATE_CLIP

然后,代码将计算得到的学习率 lr 应用到优化器中的参数组上,从而更新模型的参数。

接下来,代码计算当前的动量值 momentum,该动量值通过将原始动量值 MOMENTUM_ORIGINAL 乘以 MOMENTUM_DECCAYepoch // MOMENTUM_DECCAY_STEP 次方来计算。如果计算得到的动量值小于 0.01,则将其调整为 0.01

最后,代码使用 bn_momentum_adjust 函数将计算得到的动量值应用到模型的 Batch Normalization 层中,并将模型设置为训练模式。

lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
log_string('Learning rate:%f' % lr)
# 遍历了优化器中的所有参数组(param_groups),然后将学习率(lr)设置为指定的值
for param_group in optimizer.param_groups:
    param_group['lr'] = lr
momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP))
if momentum < 0.01:
    momentum = 0.01
print('BN momentum updated to: %f' % momentum)
classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
classifier = classifier.train()

示例,假设初始学习率为0.1,衰减因子为0.5,衰减步数为10,学习率裁剪阈值为0.01,当前轮次为20。则,学习率的计算方式如下:

lr = max(0.1 * (0.5 ** (20 // 10)), 0.01) = max(0.1 * 0.5 ** 2, 0.01) = max(0.025, 0.01) = 0.025

因此,当前轮次下的学习率为0.025。随着轮次的增加,学习率会逐渐减小。当轮次为30时,学习率的计算方式如下:

lr = max(0.1 * (0.5 ** (30 // 10)), 0.01) = max(0.1 * 0.5 ** 3, 0.01) = max(0.0125, 0.01) = 0.0125

因此,在轮次为30时,学习率为0.0125。随着轮次的增加,学习率会继续减小,直到达到学习率裁剪的阈值。

momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP))

这个代码使用了一个指数衰减的方法来调整动量。具体来说,我们设置了一个原始的动量值(MOMENTUM_ORIGINAL),然后在每个阶段(MOMENTUM_DECCAY_STEP)中将动量进行衰减(MOMENTUM_DECCAY),直到达到指定的epoch数。这个衰减过程可以使得动量随着训练的进行而逐渐减小,从而使得滑动平均窗口逐渐变大,提高了模型的准确率。  

BN momentum是指在批量归一化(Batch Normalization)中使用的指数加权平均值的更新速度,它决定了历史信息的保留程度。

具体来说,当我们对一个batch中的数据进行归一化时,会计算出这个batch的均值和方差,并使用它们来标准化数据。但是,由于每个batch的数据分布可能不同,因此使用当前batch的均值和方差进行标准化可能会导致模型泛化能力下降。为了解决这个问题,我们引入了指数加权平均值,用于维护所有batch的均值和方差的移动平均值。具体地,对于每个batch中的均值和方差,我们都会维护一个对应的移动平均值,而这些移动平均值会被使用在下一个batch的标准化中。BN momentum就是用来控制这些移动平均值的更新速度的超参数,它表示移动平均值在更新时受到前一次移动平均值的影响程度。例如,如果BN momentum取0.9,则新的移动平均值将以90%的权重加入到原有的移动平均值中,而原有的移动平均值则以10%的权重保留。

适当调整BN momentum可以帮助我们平衡当前batch和之前所有batch的均值和方差,从而提高模型的泛化性能。常见的BN momentum取值范围是0.9到0.999,一般来说,较小的值会使模型更容易过拟合,而较大的值则可能导致模型难以收敛。

之后就是开始训练,与分类代码部分一致先优化器清零,然后数据增强,之后送入网络训练。

具体来说,这个代码中的第一行classifier(points, to_categorical(label, num_classes))表示对一个batch的数据(points)进行前向传播计算,生成模型的预测结果(seg_pred)和转换特征(trans_feat)。其中,to_categorical(label, num_classes)将标签(label)转换为一个one-hot向量,以便在模型中使用。

第二行代码seg_pred.contiguous().view(-1, num_part)将预测结果(seg_pred)进行形状变换,将其变为一个2D矩阵,其中每行表示一个数据点的预测结果,每列表示一个类别的概率值。第三行代码将目标(target)进行形状变换,将其变为一个1D张量,其中每个元素表示一个数据点的目标类别。

第四行代码seg_pred.data.max(1)[1]将预测结果(seg_pred)中每个数据点概率最大的类别作为预测结果(pred_choice),并返回一个1D张量。

seg_pred, trans_feat = classifier(points, to_categorical(label, num_classes))
seg_pred = seg_pred.contiguous().view(-1, num_part)
target = target.view(-1, 1)[:, 0]
pred_choice = seg_pred.data.max(1)[1]

 结果出来,自然就是计算损失和进行反向传播更新模型的权重和偏置。

具体来说,这个代码中的第一行pred_choice.eq(target.data).cpu().sum()用于计算预测结果(pred_choice)与目标(target)相等的数据点个数(correct)。

第二行代码mean_correct.append(correct.item() / (args.batch_size * args.npoint))用于计算当前batch的预测准确率(mean_correct)。(args.batch_size和args.npoint分别表示batch的大小和每个数据点的最大数量。)

第三行代码loss = criterion(seg_pred, target, trans_feat)用于计算当前batch的损失值(loss),其中criterion是一个损失函数(seg_pred和target分别表示模型的预测结果和目标,trans_feat表示转换特征。)

第四行代码loss.backward()用于进行反向传播算法,计算模型参数的梯度。

第五行代码optimizer.step()用于根据计算的梯度更新模型的权重和偏置。

correct = pred_choice.eq(target.data).cpu().sum()
mean_correct.append(correct.item() / (args.batch_size * args.npoint))
loss = criterion(seg_pred, target, trans_feat)
loss.backward()
optimizer.step()

 训练完一轮后就要进行测试集的评估

具体来说,这个代码中的第一行with torch.no_grad()用于设置上下文环境,禁用梯度计算,以加快推理速度。

第二行代码test_metrics = {}用于定义一个字典,用于保存测试集的各项指标。

第三行代码total_correct = 0和total_seen = 0分别用于统计测试集中所有数据点的预测正确数和总数。

第四行代码total_seen_class和total_correct_class分别用于统计测试集中各个类别的数据点总数和预测正确数,其中num_part表示数据点的类别数。

第五行代码shape_ious = {cat: [] for cat in seg_classes.keys()}用于定义一个字典,用于保存每个类别的平均交并比(IOU)。

第六行代码seg_label_to_cat = {}用于定义一个字典,用于将数据点的标签映射到类别名称,其中seg_classes是一个保存类别名称的字典。

with torch.no_grad():
    test_metrics = {}
    total_correct = 0
    total_seen = 0
    total_seen_class = [0 for _ in range(num_part)]
    total_correct_class = [0 for _ in range(num_part)]
    shape_ious = {cat: [] for cat in seg_classes.keys()}
    seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}

PointNet系列代码复现详解(2)—PointNet_part_seg_第1张图片

for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat

具体来说,这个代码中的第一行for i in range(cur_batch_size)用于对当前batch中的每个数据点进行处理。

第二行代码cat = seg_label_to_cat[target[i, 0]]用于获取当前数据点的类别名称。

第三行代码logits = cur_pred_val_logits[i, :, :]用于获取当前数据点的预测结果的logits。

第四行代码cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]用于将logits转换为每个点的类别标签,并将其保存到cur_pred_val数组中。其中,np.argmax()函数用于获取logits中最大值的索引,即预测类别的编号;seg_classes[cat]用于获取当前类别所对应的类别编号列表,再加上[0]是因为我们通常只需要其中任意一个编号即可。最后,将预测类别编号加上当前类别的第一个编号,即可得到当前点的类别标签。

for i in range(cur_batch_size):
    cat = seg_label_to_cat[target[i, 0]]
    logits = cur_pred_val_logits[i, :, :]
    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]

具体来说,这个代码中的第一行correct = np.sum(cur_pred_val == target)用于计算当前batch的预测结果中有多少与目标值相同。其中,cur_pred_val是模型的预测结果,target是目标值,它们都是numpy数组。在这里,cur_pred_val == target会生成一个布尔值数组,其中每个元素表示对应位置的预测值是否等于目标值。np.sum()函数用于计算所有True的数量,即当前batch中预测正确的点的数量。

第二行代码total_correct += correct用于将当前batch中预测正确的点的数量累加到总的正确预测数中。

第三行代码total_seen += (cur_batch_size * NUM_POINT)用于统计当前batch中一共有多少个点,其中cur_batch_size是当前batch的大小,NUM_POINT是每个点云中的最大点数。

correct = np.sum(cur_pred_val == target)
total_correct += correct
total_seen += (cur_batch_size * NUM_POINT)

 这段代码用于统计当前batch中每个类别的预测结果中有多少与目标值相同,并将其累加到每个类别的总的正确预测数中。同时,也统计了当前batch中每个类别一共有多少个点。

具体来说,这个代码中的第一行for l in range(num_part)用于对每个类别进行处理。其中,num_part是数据集中类别的数量。

第二行代码total_seen_class[l] += np.sum(target == l)用于统计当前batch中每个类别一共有多少个点,其中target == l会生成一个布尔值数组,其中每个元素表示对应位置的目标值是否属于第l类。np.sum()函数用于计算所有True的数量,即当前batch中第l类的点的数量。

第三行代码total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))用于统计当前batch中每个类别的预测结果中有多少与目标值相同,其中(cur_pred_val == l) & (target == l)会生成一个布尔值数组,其中每个元素表示对应位置的预测值和目标值是否都属于第l类。np.sum()函数用于计算所有True的数量,即当前batch中第l类预测正确的点的数量。

for l in range(num_part):
    total_seen_class[l] += np.sum(target == l)
    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))

 这段代码用于计算每个形状的每个部分的IoU,并将其保存在一个列表中,最终计算每个形状的平均IoU。

具体来说,这个代码中的第一行for i in range(cur_batch_size)用于对当前batch中的每个点云进行处理。

第二行代码segp = cur_pred_val[i, :]和第三行代码segl = target[i, :]用于获取当前点云的预测结果和目标值。

第四行代码cat = seg_label_to_cat[segl[0]]用于获取当前点云所属的形状的类别。

第五行代码part_ious = [0.0 for _ in range(len(seg_classes[cat]))]用于初始化一个长度为当前形状部分数量的全零列表,用于保存每个部分的IoU。

第六行代码for l in seg_classes[cat]:用于对当前形状的每个部分进行处理。其中,seg_classes是一个字典,用于保存每个类别的部分标签。

第七行代码if (np.sum(segl == l) == 0) and (np.sum(segp == l) == 0):用于判断当前部分是否在目标值和预测结果中都不存在,如果是,则将该部分的IoU设为1.0。

第九行代码else:用于计算当前部分的IoU,具体来说,它先计算当前部分在目标值和预测结果中同时存在的点的数量,然后除以当前部分在目标值和预测结果中存在的点的数量的并集。

第十二行代码shape_ious[cat].append(np.mean(part_ious))用于将当前形状的平均IoU添加到一个列表中,该列表用于保存每个形状的平均IoU。 

for i in range(cur_batch_size):
    segp = cur_pred_val[i, :]
    segl = target[i, :]
    cat = seg_label_to_cat[segl[0]]
    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
    for l in seg_classes[cat]:
        if (np.sum(segl == l) == 0) and (np.sum(segp == l) == 0):  # part is not present, no prediction as well
             part_ious[l - seg_classes[cat][0]] = 1.0
        else:
             part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(np.sum((segl == l) | (segp == l)))
    shape_ious[cat].append(np.mean(part_ious))

下面就是把当前在测试集上的结果与以前的进行比较,保存最好的那个。以及将一些结果打印出来并写入训练日志。

在测试集上计算出的指标值与历史最佳指标值进行比较。如果当前的inctance_avg_iou大于等于历史最佳值best_inctance_avg_iou,则保存当前模型,并更新历史最佳指标值。

        if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'train_acc': train_instance_acc,
                'test_acc': test_metrics['accuracy'],
                'class_avg_iou': test_metrics['class_avg_iou'],
                'inctance_avg_iou': test_metrics['inctance_avg_iou'],
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            log_string('Saving model....')

pointnet_part_seg.py

下面就是具体的神经网络设计了

前面定义的是各个模块,然后就是前向传播

    def __init__(self, part_num=50, normal_channel=True):
        super(get_model, self).__init__()
        if normal_channel:
            channel = 6
        else:
            channel = 3
        self.part_num = part_num
        self.stn = STN3d(channel)
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, 512, 1)
        self.conv5 = torch.nn.Conv1d(512, 2048, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(2048)
        self.fstn = STNkd(k=128)
        self.convs1 = torch.nn.Conv1d(4944, 256, 1)
        self.convs2 = torch.nn.Conv1d(256, 256, 1)
        self.convs3 = torch.nn.Conv1d(256, 128, 1)
        self.convs4 = torch.nn.Conv1d(128, part_num, 1)
        self.bns1 = nn.BatchNorm1d(256)
        self.bns2 = nn.BatchNorm1d(256)
        self.bns3 = nn.BatchNorm1d(128)

    def forward(self, point_cloud, label):
        B, D, N = point_cloud.size()
        trans = self.stn(point_cloud)
        point_cloud = point_cloud.transpose(2, 1)
        # D大于3,则将点云分为坐标信息和特征信息两部分,其中坐标信息是三维的,用于计算STN变换,特征信息可以是任意维度的。
        # 点云分开是为了将变换矩阵trans应用到点云的坐标信息上,实现点云的空间变换。
        if D > 3:
            point_cloud, feature = point_cloud.split(3, dim=2)
        point_cloud = torch.bmm(point_cloud, trans)
        # 再坐标和特征信息合起来  不大于3就忽略
        if D > 3:
            point_cloud = torch.cat([point_cloud, feature], dim=2)

        point_cloud = point_cloud.transpose(2, 1)

        out1 = F.relu(self.bn1(self.conv1(point_cloud)))
        out2 = F.relu(self.bn2(self.conv2(out1)))
        out3 = F.relu(self.bn3(self.conv3(out2)))

        # 空间变换
        trans_feat = self.fstn(out3)
        x = out3.transpose(2, 1)
        net_transformed = torch.bmm(x, trans_feat)
        net_transformed = net_transformed.transpose(2, 1)

        out4 = F.relu(self.bn4(self.conv4(net_transformed)))
        out5 = self.bn5(self.conv5(out4))
        out_max = torch.max(out5, 2, keepdim=True)[0]
        out_max = out_max.view(-1, 2048)

        out_max = torch.cat([out_max,label.squeeze(1)],1)
        expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, N)
        concat = torch.cat([expand, out1, out2, out3, out4, out5], 1)
        net = F.relu(self.bns1(self.convs1(concat)))
        net = F.relu(self.bns2(self.convs2(net)))
        net = F.relu(self.bns3(self.convs3(net)))
        net = self.convs4(net)
        net = net.transpose(2, 1).contiguous()
        net = F.log_softmax(net.view(-1, self.part_num), dim=-1)
        net = net.view(B, N, self.part_num) # [B, N, 50]

        return net, trans_feat

你可能感兴趣的:(深度学习,人工智能)