运行torch.sum(torch.mul(users, pos_items), axis=1)时报错:
TypeError: sum() received an invalid combination of arguments - got (Tensor, axis=int), but expected one of:
* (Tensor input)
* (Tensor input, torch.dtype dtype)
didn't match because some of the keywords were incorrect: axis
* (Tensor input, tuple of ints dim, torch.dtype dtype, Tensor out)
* (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)
* (Tensor input, tuple of ints dim, bool keepdim, Tensor out)
其中,torch.mul函数的功能是两个维度相等的矩阵的对应位相乘,其中users和pos_items的大小都是:torch.Size([1024, 256])。
另外,torch.matmul是tensor的乘法,当输入是二维时和tensor.mm函数用法相同做普通的矩阵乘法,也能用作高维矩阵乘法。
按照提示,axis关键字错误,经查,torch中用dim,或者直接把axis关键字去掉,即改成:
torch.sum(torch.mul(users, pos_items), dim=1)
或者
torch.sum(torch.mul(users, pos_items), 1)