tf.tensordot是tensorflow中tensor矩阵相乘的API,可以进行任意维度的矩阵相乘
(1).tf.tensordot函数详细介绍如下:
tf.tensordot(
a,
b,
axes,
name=None
)
"""
Args:
a:类型为float32或者float64的tensor
b:和a有相同的type,即张量同类型,但不要求同维度
axes:可以为int32,也可以是list,为int32,表示取a的最后几个维度,与b的前面几个维度相乘,再累加求和,消去(收缩)相乘维度
为list,则是指定a的哪几个维度与b的哪几个维度相乘,消去(收缩)这些相乘的维度
name:操作命名
"""
(2).代码演示(举四维Tensor与三维Tensor相乘的例子)
1.获取一个shape=(2,1,3,2)的随机数矩阵a,以及一个shape=(2,3,1)的矩阵b
import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
with tf.Session() as sess:
print("a的shape:",a.shape)
print("b的shape:",b.shape)
print("a的值:",sess.run(a))
print("b的值:",sess.run(b))
显示结果:
a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
a的值: [[[[0 1]
[2 1]
[3 4]]]
[[[5 2]
[3 4]
[5 0]]]]
b的值: [[[1]
[3]
[2]]
[[3]
[1]
[2]]]
2.在axes=1轴上相乘指的是将a的最后一个维度与b的第一个维度矩阵相乘,然后将结果累加求和,消除(收缩)这两个维度,矩阵a,b剩下的维度concat,就是所求矩阵维度
import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
res = tf.tensordot(a,b,axes=1)
with tf.Session() as sess:
print("a的shape:",a.shape)
print("b的shape:",b.shape)
print("res_shape:",res.shape)
print("res_value:",sess.run(res))
显示结果:
a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
res_shape: (2, 1, 3, 3, 1)
res_value: [[[[[ 3]
[ 1]
[ 2]]
[[ 5]
[ 7]
[ 6]]
[[15]
[13]
[14]]]]
[[[[11]
[17]
[14]]
[[15]
[13]
[14]]
[[ 5]
[15]
[10]]]]]
3.在axes=2轴上相乘指的是将a的最后两个维度与b的前两个维度矩阵相乘,然后将结果累加求和,消除(收缩)这四个维度,矩阵a,b剩下的维度concat,就是所求矩阵维度
import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
res = tf.tensordot(a,b,axes=2)
with tf.Session() as sess:
print("a的shape:",a.shape)
print("b的shape:",b.shape)
print("res_shape:",res.shape)
print("res_value:",sess.run(res))
显示结果:
a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
res_shape: (2, 1, 1)
res_value: [[[21]]
[[34]]]
4.在axes=[[1,3],[0,2]]上进行tensor相乘,指的是将a的第一个维度、第三个维度concat的维度与b的第0(维度下标从0开始)个维度、第二个维度concat的维度进行矩阵相乘,然后将结果累加求和,消除(收缩)这四个维度,矩阵a,b剩下的维度concat,就是所求矩阵维度
import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
res = tf.tensordot(a,b,axes=[[1,3],[0,2]])
with tf.Session() as sess:
print("a的shape:",a.shape)
print("b的shape:",b.shape)
print("res_shape:",res.shape)
print("res_value:",sess.run(res))
显示结果:
a的shape: (2, 1, 3, 2)
b的shape: (2, 3, 1)
res_shape: (2, 3, 3)
res_value: [[[ 3 1 2]
[ 5 7 6]
[15 13 14]]
[[11 17 14]
[15 13 14]
[ 5 15 10]]]