今天运行下面的代码时候,出现了这个问题,百度半天没有找到合适的解决办法,思考半天,终于用两行代码解决了,记一下!!
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2019/12/28 20:52
# @Author : LZQ
# @Software: PyCharm
import torch
x = torch.randn(5,4)
y=x.new_ones(5,4,dtype=float)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
x=x.to(device).cuda()
y=y.to(device).cuda()
z=x+y
print(z)
报错原因如下:
E:\lzqData\3d\anaconda3\Anaconda3\python.exe G:/pychram/workspace/python/study_pytorch/pytorch01/pytorch002.py
cuda:0
Traceback (most recent call last):
File "G:/pychram/workspace/python/study_pytorch/pytorch01/pytorch002.py", line 19, in
z=x+y
RuntimeError: expected device cuda:0 and dtype Double but got device cuda:0 and dtype Float
Process finished with exit code 1
原因是,我的x,y的类型是float不对,但是gpu想要的是double,把float改为double并不行,解决办法如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2019/12/28 20:52
# @Author : LZQ
# @Software: PyCharm
import torch
x = torch.randn(5,4)
y=x.new_ones(5,4,dtype=float)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
x=x.double() #为x添加double()函数
y = y.double()#为y添加double()函数
x=x.to(device).cuda()
y=y.to(device).cuda()
z=x+y
print(z)
结果:
E:\lzqData\3d\anaconda3\Anaconda3\python.exe G:/pychram/workspace/python/study_pytorch/pytorch01/pytorch002.py
cuda:0
tensor([[ 1.1086, 2.0426, 2.0449, -0.2139],
[ 0.1211, 0.6647, 1.4036, 3.2326],
[ 1.5524, 1.1472, 1.5518, 2.5580],
[ 2.8068, 2.0145, 2.6142, 1.5480],
[ 1.1827, 1.7224, 2.8095, -0.1689]], device='cuda:0',
dtype=torch.float64)
Process finished with exit code 0
问题分析:因为类型不匹配,我们要把x和y转换成我们想要的double,但是直接在dtype中改,是不支持的,所以我们强制转换。就搞定啦。爽歪歪!!!