基于torch函数TransformerEncoder出现AssertionError问题的解决

在使用transformer model时,由于存在encoder-decoder,encoder-only,decoder-only三种结构以应对不同的task。当我们使用encoder-only时,必然会涉及到TransformerEncoder和TransformerEncoderLayer函数的调用。
那么如下代码出现了AssertionError问题,应当如何解决?
基于torch函数TransformerEncoder出现AssertionError问题的解决_第1张图片
为什么会出现AssertionError(声明/断言)问题呢?可以看到,输入模型的第三维应该对应d_model这个参数,那么此处,这两个值应该一致。
基于torch函数TransformerEncoder出现AssertionError问题的解决_第2张图片
修改以后:
基于torch函数TransformerEncoder出现AssertionError问题的解决_第3张图片
运行得到:
在这里插入图片描述
其实我们发现,就transformer的编码器而言,输入输出的尺寸是一样的。

作于:
20215-9
21:50

你可能感兴趣的:(#,python遇到的坑)