关于Transformer训练的问题

关于Transformer训练的问题_第1张图片
新建 Microsoft PowerPoint 演示文稿 (2).jpg

按照《The The Annotated Transformer》教程写下来,卡在了Greedy Decoding部分,回溯发现问题出在两个部分

  1. 在Embeddings部分,教程中如下
class Embeddings(nn.Module):
  def __init__(self, d_model, vocab):
   super(Embeddings, self).__init__()
   self.lut = nn.Embedding(vocab, d_model)
   self.d_model = d_model
  def forward(self, x):
   x = x.long()
   print(type(x))
   print(x)
 return self.lut(x) * math.sqrt(self.d_model)

会报错,提示要求输入的x为LongTensor,在这里我增加了

x = x.long()

解决

  1. 在LabelSmoothing部分,使用scatter_函数的时候提示。没搞定,这样模型训练应该是有问题的……

RuntimeError: Expected object of type torch.LongTensor but found type torch.IntTensor for argument #3 'index'

查了半天,应该是Pytorch新老版本的问题,现搁置起来,继续跑完项目所有的训练步骤。

如果有大神遇到这样的问题解决了,麻烦留言啊,我快疯了……

你可能感兴趣的:(关于Transformer训练的问题)