einsum用于矩阵乘法
直接上例子吧
比如
'bhqd, bhkd -> bhqk'
虽然是4维,但是前两维是不变的,先不看,只看后2维,qd, kd -> qk
这是两个矩阵相乘,两个矩阵的shape分别为A=qxd, B=kxd, 得到的结果形状是C =qxk
根据矩阵乘法,我们知道(qxd) x (dxk)结果的形状为qxk,
也就是说上面相当于是AxBT=C
验证一下
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy.shape',energy.shape)
queries.shape torch.Size([1, 8, 197, 96])
key.shape torch.Size([1, 8, 197, 96])
energy.shape torch.Size([1, 8, 197, 197])
可以看到相当于queries x keysT, 即形状(197x96) x (197x96)T=(197x197)
再看一个
'bhal, bhlv -> bhav'
前两维一样的,不看,只看后两维,仍然看作是矩阵的形状A=axl, B=lxv
矩阵相乘(axl) x (lxv) = (axv),和结果的av相同
所以上面相当于是A与B相乘
验证一下
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print('out.shape',out.shape)
att.shape torch.Size([1, 8, 197, 197])
values.shape torch.Size([1, 8, 197, 96])
out.shape torch.Size([1, 8, 197, 96])
可以看到相当于att x values,即形状(197x197) x (197x96) = (197x96)