第一次见到 rel_h = torch.einsum(“bhwc,hkc->bhwk”, r_q, Rh)这行代码时,属实是懵了,网上找了很多博主的介绍,但都没有详细的说明函数内部的计算过程,看得我是一头雾水,只知道计算结果的维度是如何变化的,却不明白函数内部是如何计算的。话不多说,直接上示例代码
import torch
r_q = torch.tensor([[[[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20]],
[[21, 22, 23, 24, 25],
[26, 27, 28, 29, 30],
[31, 32, 33, 34, 35],
[36, 37, 38, 39, 40]],
[[41, 42, 43, 44, 45],
[46, 47, 48, 49, 50],
[51, 52, 53, 54, 55],
[56, 57, 58, 59, 60]]]])
Rh = torch.tensor([[[1, 2, 3, 4, 5,],
[7, 8, 9, 10, 11, ],
[13, 14, 15, 16, 17, ],
[19, 20, 21, 22, 23, ],
[1, 2, 3, 4, 5,],
[1, 2, 3, 4, 5,],],
[[25, 26, 27, 28, 29, ],
[31, 32, 33, 34, 35, ],
[37, 38, 39, 40, 41, ],
[43, 44, 45, 46, 47, ],
[1, 2, 3, 4, 5,],
[1, 2, 3, 4, 5,],],
[[49, 50, 51, 52, 53, ],
[55, 56, 57, 58, 59, ],
[61, 62, 63, 64, 65, ],
[67, 68, 69, 70, 71, ],
[1, 2, 3, 4, 5,],
[1, 2, 3, 4, 5,],]])
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
print(rel_h)
文字很难解释清楚,直接上图。r_q的维度为(1, 3, 4, 5), Rh的维度为(3, 6, 5),函数torch.einsum(“bhwc,hkc->bhwk”, r_q, Rh)中b=1, h=3, w=4, c=5。所以最终结果Rel_h的维度为bhwk,即(1, 3, 4, 5)。具体计算过程如下图。
这回看懂了吧。还不理解的或者讲的不对的地方,欢迎在评论区留言。创作不易,喜欢的话点个关注吧