Python torch.linalg.multi_dot

Python torch.linalg.multi_dot用法及代码示例

torch.linalg.multi_dot(tensors, *, out=None)


>>> A = torch.arange(2 * 3).view(2, 3)
>>> B = torch.arange(3 * 2).view(3, 2)
>>> C = torch.arange(2 * 2).view(2, 2)
>>> multi_dot((A, B, C))
tensor([[ 26,  49],
        [ 80, 148]])

参数:
tensors(Sequence[Tensor]) -两个或多个张量相乘。第一个和最后一个张量可以是 1D 或 2D。每个其他张量必须是 2D 的。

关键字参数:
out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None。

通过重新排序乘法来有效地将两个或多个矩阵相乘,以便执行最少的算术运算。

支持 float、double、cfloat 和 cdouble dtypes 的输入。此函数不支持批量输入。

tensors 中的每个张量都必须是 2D,但第一个和最后一个可能是 1D 的除外。如果第一个张量是形状为 (n,) 的一维向量,则将其视为形状为 (1, n) 的行向量,类似地,如果最后一个张量是形状为 (n,) 的一维向量,则将其视为形状为列向量(n, 1) 。

如果第一个和最后一个张量是矩阵,则输出将是一个矩阵。但是,如果其中任何一个是一维向量,则输出将是一维向量。

与 numpy.linalg.multi_dot 的区别:

与 numpy.linalg.multi_dot 不同,第一个和最后一个张量必须是 1D 或 2D,而 NumPy 允许它们是 nD

警告

此函数不广播。

注意

此函数通过在计算最佳矩阵乘法顺序后链接 torch.mm() 调用来实现。

以上就是全部内容

四级标题

五级标题
六级标题

你可能感兴趣的:(python,python,numpy,开发语言)