torch.bmm函数讲解

torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法(batch matrix multiplication)操作。

它的输入是三维张量,形状为 (batch_size, n, m) 和 (batch_size, m, p):
其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。

torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch_size, n, p)。

例如,假设我们有两个批次的矩阵 A 和 B,维度分别为 (2, 3, 4) 和 (2, 4, 5)。我们可以使用 torch.bmm 将它们相乘
torch.bmm函数讲解_第1张图片
torch.bmm函数讲解_第2张图片
torch.bmm函数讲解_第3张图片

你可能感兴趣的:(Pytorch,pytorch)