TensorFlow的 matmul 已经支持了batch,对于高维向量的相乘(比如两个三维矩阵),tensorflow把前面的维度当成是batch,对最后两维进行普通的矩阵乘法。也就是说,最后两维之前的维度,都需要相同。比如
A.shape=(a, b, n, m)
,B.shape=(a, b, m, k)
,tf.matmul(A,B) 的结果shape=(a,b,n,k)
有时候需要一个矩阵与多个矩阵相乘,也就是一个 2D Tensor 与一个 3D Tensor 相乘,比如
A.shape=(m, n)
,B.shape=(k, n, p)
,希望计算 A*B 得到一个C.shape=(k, m, p)
的 Tensor,可以采取的思路为:
可以看下面一个例子(c为标准答案,g为最后的正确结果,e是错误的):
import tensorflow as tf
a = tf.reshape(tf.linspace(1.,6.,6),[2,3])
b = tf.reshape(tf.linspace(1.,24.,24),[2,3,4])
c = tf.matmul(tf.tile(tf.expand_dims(a,0),multiples=[2,1,1]),b)
d = tf.matmul(a,tf.reshape(b,[3,2*4]))
e = tf.reshape(d,[2,2,4])
f = tf.transpose(b,[1,0,2])
g = tf.matmul(a,tf.reshape(f,[3,-1]))
g = tf.reshape(g,[2,2,4])
g = tf.transpose(g,[1,0,2])
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(b))
print('-------------')
print(sess.run(c))
print('-------------')
print(sess.run(d))
print(sess.run(e))
print('-------------')
print(sess.run(f))
print(sess.run(g))
结果:
[[1. 2. 3.]
[4. 5. 6.]]
[[[ 1. 2. 3. 4.]
[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]]
[[13. 14. 15. 16.]
[17. 18. 19. 20.]
[21. 22. 23. 24.]]]
-------------
[[[ 38. 44. 50. 56.]
[ 83. 98. 113. 128.]]
[[110. 116. 122. 128.]
[263. 278. 293. 308.]]]
-------------
[[ 70. 76. 82. 88. 94. 100. 106. 112.]
[151. 166. 181. 196. 211. 226. 241. 256.]]
[[[ 70. 76. 82. 88.]
[ 94. 100. 106. 112.]]
[[151. 166. 181. 196.]
[211. 226. 241. 256.]]]
-------------
[[[ 1. 2. 3. 4.]
[13. 14. 15. 16.]]
[[ 5. 6. 7. 8.]
[17. 18. 19. 20.]]
[[ 9. 10. 11. 12.]
[21. 22. 23. 24.]]]
[[[ 38. 44. 50. 56.]
[ 83. 98. 113. 128.]]
[[110. 116. 122. 128.]
[263. 278. 293. 308.]]]