Pytorch报错:RuntimeError: self must be a matrix

报错代码

Wh = torch.mm(h, self.w)

报错RuntimeError: self must be a matrix

原因:torch.mm()是两个矩阵相乘,即两个二维的张量相乘,维度超过二维,则会报错。
这两个tensor的维度是[16, 16, 29][29, 70]

>>> h.shape
torch.Size([16, 16, 29])
>>> self.w.shape
torch.Size([29, 70])

修改:使用torch.matmul()

Wh = torch.matmul(h, self.w)

>>>Wh.shape
torch.Size([16, 16, 70])

你可能感兴趣的:(机器学习,ML,pytorch,python,深度学习)