RuntimeError: expected scalar type Double but found Float

用pytorch训练神经网络时,自定义继承自Dataset的数据加载类,并送入DataLoader中。当从DataLoader获取data后,分别赋值给 inputs 和 labels,送入模型报错“RuntimeError: expected scalar type Double but found Float”,如下报错部分代码。

 for data in train_loader:
	inputs, labels = data
	output = model(inputs)
	loss = loss_fn(output, labels)

这是因为DataLoader中的数据类型是格式是 torch.float64,而训练网络使用的都是torch.float32类型,因此要进行类型转换。送入model前将 inputs 和 label 使用 .type(torch.FloatTensor) 进行转换,如下所示。

 for data in train_loader:
	inputs, labels = data
	inputs = inputs.type(torch.FloatTensor).to(device)
	labels = labels.type(torch.FloatTensor).to(device)
	output = model(inputs)
	loss = loss_fn(output, labels)

你可能感兴趣的:(深度学习,pytorch,人工智能)