torch.cdist高效计算大矩阵相似度

问题定义

现有矩阵 A ∈ R N × C , B ∈ R M × C A\in R^{N\times C}, B\in R^{M\times C} ARN×C,BRM×C,需要计算矩阵 A A A B B B的相似度(欧式距离)矩阵 S ∈ R N × M S\in R^{N\times M} SRN×M N N N M M M很大。可以使用pytorch提供的torch.cdist方法,记得使用GPU计算。

import torch

N, M, C = 20000, 50000, 128
A = torch.rand((N, C)).cuda()
B = torch.rand((M, C)).cuda()

S = torch.cdist(A, B, p=2)
print(S.shape)

你可能感兴趣的:(Python,矩阵,pytorch,深度学习)