RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the

1.问题描述

在训练网络模型时遇到了下面的问题:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

2.问题分析

出现这样的问题是由于我们输入的tensor在cuda(GPU)上,但是网络在CPU上。

3.解决办法

将待训练的网络加载到GPU上即可

device=torch.device('cuda')
model.to(device)

你可能感兴趣的:(小问题记录,深度学习,人工智能,python)