关于运行torch.matmul(x,w)函数报错的问题 RuntimeError: expected scalar type Float but found Long

在李沐大佬**08 线性回归 + 基础优化算法【动手学深度学习v2】**一课中,代码实现环节

def synthetic_data(w,b, num_examples):
    # 生成y = Xw + b + 噪声
    X = torch.normal(0,1, (num_examples, len(w)))
    # x是均值为0,标准差为1的随机数,大小有num_examples个样本,列数为w
    y = torch.matmul(X, w) + b
    y += torch.normal(0,0.01,y.shape)
    # 加入一个均值为0,标准差为0.01 形状与y相同的噪音
    return X,y.reshape(-1,1)
    # -1 相当于去除那个维度,按照列向量进行返回
    
    
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b,3)

这里实现的功能是自己编写了一个噪声数据集函数,我对于表中的参数运算过程很好奇,所以对函数内部运算过程自行实现了一下

x = torch.normal(0,1, (3,2))
print(type(x))
print(x)
w = torch.tensor([2,-3])
w.reshape(-1,1)
print(w.shape)
y = torch.matmul(x, w) 
print(y)

但是此处会报错
关于运行torch.matmul(x,w)函数报错的问题 RuntimeError: expected scalar type Float but found Long_第1张图片
官方文档上定义张量的类型有如下几种:
关于运行torch.matmul(x,w)函数报错的问题 RuntimeError: expected scalar type Float but found Long_第2张图片
意外发现了一处官网的拼写错误 哈哈

之所以会报错是因为torch.normal(0,1,(3,2))所定义出来的3*2的数组,他的张量类型为float
然而torch.tensor([2,-3])在64位计算机上上定义出来的张量类型为long
所以torch.mutmul()函数在运行时会出现报错现象,这是很常见的数据类型不匹配的问题
将前后数据类型统一即可!
至于修改张量的数据类型
(1)在tensor后加long(),int(),double(),float(),byte()等函数就能将tensor进行类型转换。
例如:Torch.LongTensor转换为Torch.FloatTensor,直接使用data.float()即可。
(2)使用type()函数。
当data为tensor数据类型,如果使用data.type(torch.FloatTensor)则强制转换data为torch.FloatTensor类型张量。

修改后的结果:
关于运行torch.matmul(x,w)函数报错的问题 RuntimeError: expected scalar type Float but found Long_第3张图片

你可能感兴趣的:(深度学习,pytorch,深度学习,python)