Pytorch中, torch.einsum

https://blog.csdn.net/a2806005024/article/details/96462827

3)Torch矩阵乘法。

print(a_tensor)
 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])
 
print(b_tensor)
 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
 
# 'ik, kj -> ij'语义解释如下:
# 输入a_tensor: 2维数组,下标为ik,
# 输入b_tensor: 2维数组,下标为kj,
# 输出output:2维数组,下标为ij。
# 隐含语义:输入a,b下标中相同的k,是矩阵乘法的下标,对应上面的例子2的公式
output = torch.einsum('ik, kj -> ij', a_tensor, b_tensor)
 
print(output)
 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

 

你可能感兴趣的:(Pytorch)