解决RNN中的RuntimeError: expected scalar type Long but found Float报错

项目场景:

今天看了个RNN实例的代码,想用自己的数据试试传入RNN,奈何报错。

问题描述:

报错是这样的RuntimeError: expected scalar type Long but found Float


原因分析:

我报错的输入是这样的:

input=torch.tensor([ 0,  0,  0,  0,  0,  0,  0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0])

例子里的输入是这样的:

input=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0.]])

原因就是我生成输入的时候设置了dtype=torch.long


解决方案:

input=torch.tensor(input,dtype=torch.float)
前面生成输入张量的时候指定dtype=torch.float,这样得到的输入会是如下类型的

input=torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,0.]])

符合要求,不再报错。
没想到解决这点问题我用了大半个晚上,真的是还没入门,功底不够唉。

你可能感兴趣的:(rnn,python,pytorch)