PyTorch中permute的用法

RuntimeError: Given groups=1, weight of size [18, 8, 8], expected input[64, 32, 8] to have 8 channels, but got 32 channels instead

问题分析:最近在自己框架中加入了注意力机制,由于quary的尺寸是(B,L,C),在输入注意力之前,经过conv 后 输出的尺寸是(B,C,L),所以需要使用permute()函数

permute(dims):将tensor的维度换位。

已知修改前 out的size为:(64,8,32) 利用out.permute(0, 2, 1) 得到一个size为 (64, 32, 8) 的 tensor。

out=out.permute(0, 2, 1)

你可能感兴趣的:(pytorch,人工智能,python)