【torch.einsum】

参考:https://www.cnblogs.com/mengnan/p/10319701.html

爱因斯坦简记法,能简洁表示各种矩阵向量的操作,例如矩阵转置、乘法、求和等等,pytorch中调用API为torch.einsum,第一个参数是字符串,表示对张量的操作,定义不同字符串表示不同操作

例子

矩阵转置

字符串为ij->ji,表示将张量中的元素ij(第i行第j列)变成元素ji

a = torch.tensor([
    [1, 2],
    [2, 3],
    [4, 5],
    [2, 6]
])
b = torch.einsum('ij->ji', a)
print(b)
>>> tensor([[1, 2, 4, 2],
            [2, 3, 5, 6]])
矩阵点乘(哈达玛积)

将矩阵对应位置的元素相乘,字符串ij,ij->ij表示将两个矩阵的ij元素相乘得到元素ij,貌似只能是相乘不能是相加?

a = torch.tensor([
    [1, 2],
    [2, 3],
    [4, 5],
    [2, 6]
])
b = torch.tensor([
    [2, 6],
    [2, 4],
    [3, 1],
    [2, 2]
])
c = torch.einsum('ij,ij->ij', a, b)
print(c)
>>> tensor([[ 2, 12],
	        [ 4, 12],
	        [12,  5],
	        [ 4, 12]])
矩阵乘法

将矩阵 a ∈ R M × N a\in R^{M\times N} aRM×N和矩阵 b ∈ R N × S b\in R^{N\times S} bRN×S相乘得到矩阵 c ∈ R M × S c\in R^{M\times S} cRM×S,字符串ik,kj->ij表示将矩阵a的ik元素和矩阵b的kj元素相乘,k是遍历变量,遍历所有k,累加ik*kj累加得到ij,即矩阵乘法

a = torch.tensor([
    [1, 2, 4],
    [2, 3, 5],
])
b = torch.tensor([
    [2, 6],
    [2, 4],
    [3, 1],
])
c = torch.einsum('ik,kj->ij', a, b)
print(c)
>>> tensor([[18, 18],
        	[25, 29]])

你可能感兴趣的:(pytorch)