tf.tensordot用法

转载自链接

函数原型tf.tensordot(a, b, axes)
tensordot函数用来进行矩阵相乘,它的一个好处是:当a和b的维度不同时,也可以相乘。

举例:

1.

import tensorflow as tf
a = tf.ones(shape=[2,3,3])
b = tf.ones(shape=[3,2,6])
c = tf.tensordot(a,b, axes=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print(sess.run(c))
    print(sess.run(tf.shape(c)))

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

c的形状为[2,3,2,6],这里axes=1,说明取a的后1维即[3]和b的前1维即[3]进行矩阵相乘,其他维不变,那么根据矩阵乘法,自然得到c的大小为[2,3,2,6]

2.

import tensorflow as tf
a = tf.ones(shape=[2,2,3])
b = tf.ones(shape=[3,2,6])
c = tf.tensordot(a,b, axes=2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print(sess.run(c))
    print(sess.run(tf.shape(c)))

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

这里c的大小为[2,6],axes=2,即取a的后两维相乘后得到的2*3=6,与b的前两维相乘后得到的3*2=6,进行矩阵相乘运算,即大小为[2,6]的矩阵与大小为[6,6]的矩阵进行矩阵相乘运算,得到的c的大小即为[2,6]。如果a的后两维相乘不等于b的前两维相乘,比如a为[2,2,3],b为[3,3,6],则会报错。

3.

import tensorflow as tf
a = tf.ones(shape=[2,2,3])
b = tf.ones(shape=[3,2,6])
c = tf.tensordot(a,b, axes=(1,1))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print(sess.run(c))
    print(sess.run(tf.shape(c)))

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

如果axes参数是一个元组,则元组的第一维指第一个乘数a要做运算的下标,第二维指第二个乘数要做运算的下标。这里axes=(1,1),也就是说a的第1维与b的第一维进行矩阵相乘。相当于[2,3,2]*[2,3,6],结果c即为[2,3,3,6]。

4.

import tensorflow as tf
a = tf.ones(shape=[2,2,3])
b = tf.ones(shape=[3,2,6])
c = tf.tensordot(a,b, axes=((1,2),(0,1)))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #print(sess.run(c))
    print(sess.run(tf.shape(c)))

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

axes同样是元组,这里表明a的第1,2维和b的第0,1维进行矩阵乘法。即[2,2*3] * [3*2,6]= [2,6] * [6,6] = [2,6],c的大小即为[2,6]

你可能感兴趣的:(tf.tensordot用法)