API文档
https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html
文档说的过于晦涩,下面以实际例子来研究一下
m1=np.array([0.1,0.2,0.3,
0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
0.1, 0.1,
0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))
print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=0)
print("============================")
print(m3.shape)
print(m3)
实际计算过程是
更简单的理解就是,把任意维度的两个张量m1(m1s1,m1s2,…m1sn),m2(m2s1,m2s2,…m2sn)压“扁”为一个(N,1)和(1,N)矩阵,然后两个矩阵相乘,再把这个结果矩阵(N,N)做reshape,reshape的结果为两个原始张量shape的串联,即为(m1s1,m1s2,…m1sn,m2s1,m2s2,…m2sn)
m1=np.array([0.1,0.2,0.3,
0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
0.1, 0.1,
0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))
print(m1.shape)
print(m2.shape)
# 等同于np.dot(m1,m2)
m3 = np.tensordot(m1,m2, axes=1)
print("============================")
print(m3.shape)
print(m3)
与axes=0的区别在于,它要求m1的最后一维与m2的第一位必须“一致”,这类似于二维矩阵的点乘,相乘后的两个维度“消失”
m1=np.array([0.1,0.2,0.3,
0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
0.1, 0.1,
0.2,0.2])
m1=m1.reshape((3,2))
m2=m2.reshape((3,2))
print(m1.shape)
print(m2.shape)
# 注意这里的shape是一样的
m3 = np.tensordot(m1,m2, axes=2)
print("============================")
print(m3.shape)
print(m3)
类似于axes=1,这里要求m1最后二个维度必须和m1的开始二个维度“一致”,其他维度任意。然后类似与做矩阵乘法,就是对axes=1做了+1扩展。
那么如果维度更高怎么控制呢? 这个就是后面采用tuple和list的用法了
m1=np.array([0.1,0.2,0.3,
0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
0.1, 0.1,
0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))
print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=(0,1))
print("============================")
print(m3.shape)
print(m3)
这里的tuple含义是,tuple的第一个值是指定m1的第几个维度,第二个值指定m2的第几个维度。
如这里的(0,1),0表示m1的shape(2,3)中的第0个维度,2;1表示m2的shape(3,2)中的第1个维度,2;
注意,这两个指定的维度值必须相等,否则报错。
另一方面可见,这里tuple只能两个元素.
如
m1.shape=(4,1,3,2)
m2.shape=(2,3,1,2)
举例:
如果使用np.tensordot(m1,m2, axes=(-1,0))
使用的就是m1最后一维2,与m2第0维2
输出的结果shape就是除掉这两位的串联,(4,1,3,3,1,2)
结果计算与之前的算法类似,首先是将m1转换为 (413,2)的矩阵,m2转换为(2,312)的矩阵,然后两者矩阵相乘,结果再reshape为(4,1,3,3,1,2)
如果使用np.tensordot(m1,m2, axes=(2,1))
使用的就是m1第3维3,与m2第1维3
输出的结果shape就是除掉这两位的串联,(4,1,2,2,1,2)
这里的计算稍微多一步:
axes的重新拼接
去掉拿掉的那一维,其他维度“凑紧”后在最后再补上这一维,见下图示例
所以重新拼接后的axes为 (0,1,3,2)
拼接完成后,分别做transpose
m1.transpose(0,1,3,2)
m2.transpose(1,0,2,3)
然后m1 reshape为(4*1*2, 3), m2 reshape为(3,2*1*2)在做矩阵相乘。
相乘后的结果做reshape(4,1,2,2,1,2)
m1=np.array([0.1,0.2,0.3,
0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
0.1, 0.1,
0.2,0.2])
m1=m1.reshape((2,3,1))
m2=m2.reshape((3,2,1))
print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=[(1,0),(0,1)])
print("============================")
print(m3.shape)
print(m3)
计算方式同tuple,只不过每个张量维度的选择变成了多个,list中有两个tuple。第一个tuple中指定了m1的维度,第二个tuple中指定了m2的维度