IndexError: tensors used as indices must be long, byte or bool tensors

运行出现报错。修改数据格式

输出sample_ids的值,可以看到数据类型是 torch.int32

解决
需要将sample_ids类型转为long,修改方式:

idx= idx.type(torch.long)

idx= self.tensor(idx, dtype=torch.long)

参考:

IndexError: tensors used as indices must be long, byte or bool tensors
知乎:https://zhuanlan.zhihu.com/p/565931659

你可能感兴趣的:(深度学习,pytorch,人工智能)