tensorflow之tf.tensordot详解

tf.tensordot是tensorflow中tensor矩阵相乘的API,可以进行任意维度的矩阵相乘

(1).tf.tensordot函数详细介绍如下:

tf.tensordot(
    a,
    b,
    axes,
    name=None
)
"""
Args:
    a:类型为float32或者float64的tensor
    b:和a有相同的type,即张量同类型,但不要求同维度
    axes:可以为int32,也可以是list,为int32,表示取a的最后几个维度,与b的前面几个维度相乘,再累加求和,消去(收缩)相乘维度
        为list,则是指定a的哪几个维度与b的哪几个维度相乘,消去(收缩)这些相乘的维度
    name:操作命名
"""

(2).代码演示(举四维Tensor与三维Tensor相乘的例子)

1.获取一个shape=(2,1,3,2)的随机数矩阵a,以及一个shape=(2,3,1)的矩阵b

import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
with tf.Session() as sess:
    print("a的shape:",a.shape)
    print("b的shape:",b.shape)
    print("a的值:",sess.run(a))
    print("b的值:",sess.run(b))

显示结果:

a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
a的值: [[[[0 1]
       [2 1]
       [3 4]]]
     [[[5 2]
       [3 4]
       [5 0]]]]
b的值: [[[1]
        [3]
        [2]]
       [[3]
        [1]
       [2]]]

2.在axes=1轴上相乘指的是将a的最后一个维度与b的第一个维度矩阵相乘,然后将结果累加求和,消除(收缩)这两个维度,矩阵a,b剩下的维度concat,就是所求矩阵维度

import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
res = tf.tensordot(a,b,axes=1)
with tf.Session() as sess:
        print("a的shape:",a.shape)
  	    print("b的shape:",b.shape)
        print("res_shape:",res.shape)
        print("res_value:",sess.run(res))

显示结果:

a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
res_shape: (2, 1, 3, 3, 1)
res_value: [[[[[ 3]
    [ 1]
    [ 2]]
   [[ 5]
    [ 7]
    [ 6]]
   [[15]
    [13]
    [14]]]]
 [[[[11]
    [17]
    [14]]
   [[15]
    [13]
    [14]]
   [[ 5]
    [15]
    [10]]]]]

3.在axes=2轴上相乘指的是将a的最后两个维度与b的前两个维度矩阵相乘,然后将结果累加求和,消除(收缩)这四个维度,矩阵a,b剩下的维度concat,就是所求矩阵维度

import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
res = tf.tensordot(a,b,axes=2)
with tf.Session() as sess:
        print("a的shape:",a.shape)
  	    print("b的shape:",b.shape)
        print("res_shape:",res.shape)
        print("res_value:",sess.run(res))

显示结果:

a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
res_shape: (2, 1, 1)
res_value: [[[21]]
	    [[34]]]

4.在axes=[[1,3],[0,2]]上进行tensor相乘,指的是将a的第一个维度、第三个维度concat的维度与b的第0(维度下标从0开始)个维度、第二个维度concat的维度进行矩阵相乘,然后将结果累加求和,消除(收缩)这四个维度,矩阵a,b剩下的维度concat,就是所求矩阵维度

import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
res = tf.tensordot(a,b,axes=[[1,3],[0,2]])
with tf.Session() as sess:
        print("a的shape:",a.shape)
  	    print("b的shape:",b.shape)
        print("res_shape:",res.shape)
        print("res_value:",sess.run(res))

显示结果:

a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
res_shape: (2, 3, 3)
res_value: [[[ 3  1  2]
	     [ 5  7  6]
	      [15 13 14]]
	     [[11 17 14]
	      [15 13 14]
	      [ 5 15 10]]]

你可能感兴趣的:(tensorflow,tf.tensordot,tensor相乘)