PYTORCH torch.einsum 函数

以爱因斯坦求和的形式任意定义想要进行的矩阵乘法的操作,可以内部指定输出转置,功能多样,形式灵活。

#theta_phi: nxtxg
#g: nxcxg
output = torch.einsum('ntg, ncg->nct', theta_phi, g)
#output: nxcxt

你可能感兴趣的:(pytorch)