【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)

文章目录

  • 官方文档中torch.gather的用法
  • torch.gather应用
  • 总结


官方文档中torch.gather的用法

torch.gather(input, dim, index, out=None) → Tensor

    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters:	

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

gather的作用:根据维度(dim)和索引(index)获取输入(input)的某些特殊值。

根据官方文档中的例子(Example),t是2×2张量:
[ 1 2 3 4 ] \left[ \begin{matrix} 1 & 2 \\ 3 & 4 \\ \end{matrix} \right] [1324]
索引是2×2张张量:
[ 0 0 1 0 ] \left[ \begin{matrix} 0 & 0 \\ 1 & 0 \\ \end{matrix} \right] [0100]
按dim=1,也就是列维度取索引值(2×2索引就是0或1代表第一列和第二列),所以根据索引与t的对应关系得到gather后的结果:
[ 1 1 4 3 ] \left[ \begin{matrix} 1 & 1 \\ 4 & 3 \\ \end{matrix} \right] [1413]
解释:t的第0列元素是1,第一行索引都是0,所以gather后的结果第一行就都是1;t的第1列元素是4,第0列元素是3,按照索引的对应位置得到gather后的结果。


torch.gather应用

torch.gather的应用:用于分类问题中取出某个对应类别的预测值。具体应用于softmax或log_softmax中。
例:

#假设三张图片,总共有五个类别
preds = torch.randn((3,5))
print(preds)
#三张图片所属类别真实标签为2,3,4
labels = torch.tensor([2,3,4])

得到的3×5张量如下图:
【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)_第1张图片
然后log_softmax和softmax

preds_logsoft = F.log_softmax(preds, dim=1) 
print(preds_logsoft)
preds_softmax = torch.exp(preds_logsoft)
print(preds_softmax)

得到结果为:
【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)_第2张图片
用gather函数获取labels中2,3,4类别的预测值

preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) 
print(preds_softmax)   
preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1))
print(preds_logsoft)    

其中,labels.view(-1,1)是只保留1列,自动填充行数,于是1×3张量就变成了3×1张量
【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)_第3张图片
再应用gather函数就可以得到最终结果:
【Pytorch小知识】torch.gather()函数的用法及在Focal Loss中的应用(详细易懂)_第4张图片
之后就可以根据得到的指定类别预测值计算loss了。


总结

  1. gather函数作用是获得张量中某些特定的值。
  2. gather函数常用于softmax和log_softmax中。
  3. gather函数常见于各种损失函数之中,如NLL_Loss、CrossEntropy、Focal Loss等。

你可能感兴趣的:(Pytorch小知识,python,pytorch,深度学习,机器学习,神经网络)