[pytorch] 矩阵基本操作-从实际需求出发

从最近的一个实际需求出发,对pytorch的矩阵基本方法进行一个汇总

[背景设定] 半监督学习,十分类场景,针对每个类 c 会根据标签数据计算出一个阈值 τ \tau τc ,用于过滤无标签数据

[需求] 模型获取到一个batch的数据,计算出 logit (并做了softmax()处理),随即需要根据其预测的最大概率的类别与预测值进行过滤,此时涉及到两次选择
[pytorch] 矩阵基本操作-从实际需求出发_第1张图片

  • logit 先选出最大预测概率值和最大预测idx,作为其伪标签p-label=idx

  • 再从 τ \tau τ 列表中选出对应的 τ \tau τp-label

  • 单个batch中每张图片所得到的 p-label 和对应的 τ \tau τp-label 均不一样

[基础版解决方案]

遍历batch中的每个logit,对每个logit单独做处理和筛选,问题:

  • 效率低,这里的效率涉及到python本身的低效,也考虑到batchsize过大可能会因此拖累模型整体的训练速度,产生瓶颈

[进阶解决方案]

  • 使用scatter生成one-hot矩阵,从而实现对tau的筛选

  • tau快速复制为bs份,使用pytorch的广播机制加速
    [pytorch] 矩阵基本操作-从实际需求出发_第2张图片

    bs = 8
    classes = 10

    logit = torch.rand((bs,classes)) # 假设这是由模型针对一个batch计算出来的logit并已经softmax()
    tau = torch.tensor([random.random() for _ in range(classes)]) # 10 分类场景下当前使用的tau, 维度 (10)
    bc_zeros = torch.zeros((bs, classes))
    tau = tau + bc_zeros # 广播机制快速生成复制bs份tau构成 (bs, classes) 的tau矩阵

    max_pred_v, max_pred_idx = torch.max(logit, dim=1) # 获取logit中最大预测概率和最大预测值
    mask = torch.zeros(bs, classes).scatter_(1, max_pred_idx.unsqueeze(1), 1) # 处理为 one-hot 矩阵,用于筛选出需要的 tau-idx

    class_threshold = mask * tau
    class_threshold1 = torch.max(class_threshold, dim=1)[0]

    class_threshold2 = torch.mm(mask, tau.t())[:, 0]
    # ! class_threshold1 和 class_threshold2 能得到相同的结果

[仍然存在的问题]

  • 实际上使用one-hot矩阵筛选出来的tau矩阵有着bs份重复信息or无用信息,最后取用所需数据的时候仍然要处理一次。

你可能感兴趣的:(pytorch,矩阵,深度学习)