【numpy】tensordot的用法研究

API文档
https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html

文档说的过于晦涩,下面以实际例子来研究一下

样例1 axes=0 (二维为叉乘运算)

m1=np.array([0.1,0.2,0.3,
            0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
            0.1, 0.1,
            0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))

print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=0)

print("============================")
print(m3.shape)
print(m3)

运行结果
【numpy】tensordot的用法研究_第1张图片

实际计算过程是

  1. m1 reshape为(6,1) m2 reshape为(1,6)
  2. m1 dot m2 得到m3(6,6)
  3. 再把m3reshape为两个输入矩阵shape的串联(2,3,3,2)

更简单的理解就是,把任意维度的两个张量m1(m1s1,m1s2,…m1sn),m2(m2s1,m2s2,…m2sn)压“扁”为一个(N,1)和(1,N)矩阵,然后两个矩阵相乘,再把这个结果矩阵(N,N)做reshape,reshape的结果为两个原始张量shape的串联,即为(m1s1,m1s2,…m1sn,m2s1,m2s2,…m2sn)


样例2 axes=1 (二维为点乘运算)

m1=np.array([0.1,0.2,0.3,
            0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
            0.1, 0.1,
            0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))

print(m1.shape)
print(m2.shape)
# 等同于np.dot(m1,m2)
m3 = np.tensordot(m1,m2, axes=1)

print("============================")
print(m3.shape)
print(m3)

运行结果,它实际等同于np.dot(m1,m2)
【numpy】tensordot的用法研究_第2张图片

与axes=0的区别在于,它要求m1的最后一维与m2的第一位必须“一致”,这类似于二维矩阵的点乘,相乘后的两个维度“消失”


样例3 axes=2 (default)

m1=np.array([0.1,0.2,0.3,
            0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
            0.1, 0.1,
            0.2,0.2])
m1=m1.reshape((3,2))
m2=m2.reshape((3,2))

print(m1.shape)
print(m2.shape)
# 注意这里的shape是一样的
m3 = np.tensordot(m1,m2, axes=2)

print("============================")
print(m3.shape)
print(m3)

样例3结果

类似于axes=1,这里要求m1最后二个维度必须和m1的开始二个维度“一致”,其他维度任意。然后类似与做矩阵乘法,就是对axes=1做了+1扩展。
那么如果维度更高怎么控制呢? 这个就是后面采用tuple和list的用法了


样例4 axis为tuple

m1=np.array([0.1,0.2,0.3,
            0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
            0.1, 0.1,
            0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))

print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=(0,1))

print("============================")
print(m3.shape)
print(m3)

【numpy】tensordot的用法研究_第3张图片
这里的tuple含义是,tuple的第一个值是指定m1的第几个维度,第二个值指定m2的第几个维度。
如这里的(0,1),0表示m1的shape(2,3)中的第0个维度,2;1表示m2的shape(3,2)中的第1个维度,2;
注意,这两个指定的维度值必须相等,否则报错。
另一方面可见,这里tuple只能两个元素.

计算维度大于2的情况

m1.shape=(4,1,3,2)
m2.shape=(2,3,1,2)

举例:

  • 如果使用np.tensordot(m1,m2, axes=(-1,0))使用的就是m1最后一维2,与m2第0维2
    输出的结果shape就是除掉这两位的串联,(4,1,3,3,1,2)
    结果计算与之前的算法类似,首先是将m1转换为 (413,2)的矩阵,m2转换为(2,312)的矩阵,然后两者矩阵相乘,结果再reshape为(4,1,3,3,1,2)

  • 如果使用np.tensordot(m1,m2, axes=(2,1))使用的就是m1第3维3,与m2第1维3
    输出的结果shape就是除掉这两位的串联,(4,1,2,2,1,2)
    这里的计算稍微多一步:
    axes的重新拼接

    对于m1

    去掉拿掉的那一维,其他维度“凑紧”后在最后再补上这一维,见下图示例
    【numpy】tensordot的用法研究_第4张图片
    所以重新拼接后的axes为 (0,1,3,2)

    对于m2

【numpy】tensordot的用法研究_第5张图片
重新拼接后的axes为 (1,0,2,3)

拼接完成后,分别做transpose

m1.transpose(0,1,3,2)
m2.transpose(1,0,2,3)

然后m1 reshape为(4*1*2, 3), m2 reshape为(3,2*1*2)在做矩阵相乘。
相乘后的结果做reshape(4,1,2,2,1,2)


样例5 axis为list

m1=np.array([0.1,0.2,0.3,
            0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,
            0.1, 0.1,
            0.2,0.2])
m1=m1.reshape((2,3,1))
m2=m2.reshape((3,2,1))

print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=[(1,0),(0,1)])

print("============================")
print(m3.shape)
print(m3)

【numpy】tensordot的用法研究_第6张图片

计算方式同tuple,只不过每个张量维度的选择变成了多个,list中有两个tuple。第一个tuple中指定了m1的维度,第二个tuple中指定了m2的维度

你可能感兴趣的:(深度学习)