torch.einsum 爱因斯坦求和约定

torch.einsum是一个强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention)。它可以简洁地表达复杂的张量运算。

  1. 对于 l_pos = torch.einsum('nc,nc->n', [q, k])

    • ‘nc,nc->n’ 是一个表示运算规则的字符串。
    • ‘nc’ 表示一个形状为 (N, C) 的张量,N 是批次大小,C 是特征维度。
    • 这个操作等同于矩阵乘法后的对角线元素,或者说是每对向量的点积。

    示例:

    q = torch.tensor([[1, 2], [3, 4]])
    k = torch.tensor([[5, 6], [7, 8]])
    result = torch.einsum('nc,nc->n', [q, k])
    # 等价于 
    # result = torch.sum(q * k, dim=1)
    # 结果: tensor([17, 53])
    
  2. 对于 l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

    • ‘nc,ck->nk’ 表示两个矩阵的乘法。
    • ‘nc’ 是形状为 (N, C) 的查询张量。
    • ‘ck’ 是形状为 (C, K) 的队列张量,K 是队列长度。
    • 结果是一个形状为 (N, K) 的张量。

    示例:

    q = torch.tensor([[1, 2], [3, 4]])
    queue = torch.tensor([[5, 6, 7], [8, 9, 10]])
    result = torch.einsum('nc,ck->nk', [q, queue])
    # 等价于
    # result = torch.matmul(q, queue)
    # 结果: tensor([[21, 24, 27],
    #               [47, 54, 61]])
    

einsum的优势:

  1. 灵活性:可以用简洁的符号表示复杂的张量运算。
  2. 效率:在某些情况下比显式循环更高效。
  3. 可读性:一旦熟悉了符号,代码变得更易读。

你可能感兴趣的:(深度学习,人工智能,pytorch)