For example:
import numpy as np
a1 = np.array([[1, 2], [3, 5]])
a2 = np.array([[2, 5], [3, 1]])
A = np.stack([a1, a2], axis=0)
A.shape, a1.shape
>> ((2, 2, 2), (2, 2))
b1 = np.array([[1, 2, 3], [1, 3, 5]])
b2 = np.array([[2, 2, 4], [5, 3, 2]])
# Stack b1 and b2 along a new leading (batch) axis -> shape (2, 2, 3)
B = np.stack([b1, b2], axis=0)
B.shape, b1.shape
>> ((2, 2, 3), (2, 3))
print("a1 matmul b1 \n", np.matmul(a1, b1), '\n')
print("a1 matmul b2 \n", np.matmul(a1, b2), '\n')
print("a2 matmul b1 \n", np.matmul(a2, b1), '\n')
print("a2 matmul b2 \n", np.matmul(a2, b2), '\n')
print("A matmul B \n", np.matmul(A, B))
print(f"np.matmul(a1, b1) == np.matmul(A, B)[0] {(np.matmul(a1, b1) == np.matmul(A, B)[0]).all()}")
print(f"np.matmul(a2, b2) == np.matmul(A, B)[1] {(np.matmul(a2, b2) == np.matmul(A, B)[1]).all()}")
>> a1 matmul b1
[[ 3  8 13]
 [ 8 21 34]]
>> a1 matmul b2
[[12  8  8]
 [31 21 22]]
>> a2 matmul b1
[[ 7 19 31]
 [ 4  9 14]]
>> a2 matmul b2
[[29 19 18]
 [11  9 14]]
>> A matmul B
[[[ 3  8 13]
  [ 8 21 34]]

 [[29 19 18]
  [11  9 14]]]
>> np.matmul(a1, b1) == np.matmul(A, B)[0] True
>> np.matmul(a2, b2) == np.matmul(A, B)[1] True
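As shown, multiplying higher-dimensional arrays amounts to multiplying the 2-D matrices they contain: np.matmul treats the last two axes as matrices and every leading axis as a batch axis, so np.matmul(A, B)[i] equals np.matmul(A[i], B[i]). Note that this is not a Cartesian product over all pairs (a1 matmul b2 and a2 matmul b1 do not appear in the result); only matrices at corresponding positions are multiplied, so every dimension except the last two is left unchanged (unless broadcasting comes into play).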
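A minimal sketch of the behaviour described above, rebuilding A and B from the same a1, a2, b1, b2. It checks that batched np.matmul matches an explicit per-index loop and the equivalent np.einsum subscript "bij,bjk->bik", then shows the broadcasting exception: a 2-D operand is reused for every matrix in the stack, and inserting size-1 axes (the `[:, None]` / `[None, :]` trick, my own illustration rather than part of the original example) makes the batch axes broadcast against each other, which does produce the all-pairs "Cartesian" result.

```python
import numpy as np

a1 = np.array([[1, 2], [3, 5]])
a2 = np.array([[2, 5], [3, 1]])
b1 = np.array([[1, 2, 3], [1, 3, 5]])
b2 = np.array([[2, 2, 4], [5, 3, 2]])
A = np.stack([a1, a2], axis=0)  # shape (2, 2, 2)
B = np.stack([b1, b2], axis=0)  # shape (2, 2, 3)

# Batched matmul pairs matrices at the same batch index: C[i] = A[i] @ B[i]
C = np.matmul(A, B)
C_loop = np.stack([A[i] @ B[i] for i in range(A.shape[0])])
C_einsum = np.einsum("bij,bjk->bik", A, B)
print(C.shape)                                                  # (2, 2, 3)
print(np.array_equal(C, C_loop), np.array_equal(C, C_einsum))  # True True

# Broadcasting: the 2-D operand b1 is multiplied with every matrix in A
print(np.matmul(A, b1).shape)                                   # (2, 2, 3)

# Batch shapes (2, 1) and (1, 2) broadcast to (2, 2), giving all pairs:
# P[i, j] = A[i] @ B[j]
P = np.matmul(A[:, None], B[None, :])
print(P.shape)                                                  # (2, 2, 2, 3)
print(np.array_equal(P[0, 1], a1 @ b2))                         # True
```

Without the extra size-1 axes the leading dimension stays 2 (one product per position); with them it grows to (2, 2), one product per pair.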