最近遇到了需要将维度高于2的张量相乘的需求。在网上查到一种技巧：先用tf.reshape()把张量降到2维，用tf.matmul()相乘后，再reshape回目标维度。下面验证这种做法是否可靠。
#测试多维矩阵乘法。问题来自于mul-attention模型的矩阵运算
#2019-7-14编辑
import tensorflow as tf
import numpy as np
# Define two 3-D tensors. All sizes are named once here so that the shapes of
# k and w and every reshape below stay consistent if a size is changed.
batch_size, seq_length, embedded_size = 3, 4, 5
d_k, h_num = 6, 7

# k.shape = [batch_size, seq_length, embedded_size]
# w.shape = [embedded_size, d_k, h_num]
k = tf.Variable(tf.random_uniform([batch_size, seq_length, embedded_size]))
w = tf.Variable(tf.random_uniform([embedded_size, d_k, h_num]))

# Goal: multiply k by w, contracting over embedded_size:
#   r[b, s, i, j] = sum_e k[b, s, e] * w[e, i, j]
# with target shape [batch_size, seq_length, d_k, h_num].
# Approach: flatten both operands to 2-D, do one tf.matmul, then reshape back.

# k_2d: [batch_size * seq_length, embedded_size]
k_2d = tf.reshape(k, [-1, embedded_size])
# w_2d: [embedded_size, d_k * h_num]. tf.reshape is row-major, so w[e, i, j]
# lands in column i * h_num + j, which is the layout the final reshape expects.
w_2d = tf.reshape(w, [embedded_size, -1])
# r_2d: [batch_size * seq_length, d_k * h_num]
r_2d = tf.matmul(k_2d, w_2d)
# Restore the 4-D layout [batch_size, seq_length, d_k, h_num].
r_4d = tf.reshape(r_2d, [-1, seq_length, d_k, h_num])
# Run the graph, then check the reshape + matmul result against a NumPy
# reference computed from the same k and w values.
# NOTE: this is TF1-style graph code (tf.Session, tf.global_variables_initializer);
# under TensorFlow 2.x these live in tf.compat.v1 and need eager execution disabled.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetch the result and both operands in one run, so the reference below is
    # computed from exactly the inputs r_4d was built from.
    r_untested, k_3d, w_3d = sess.run([r_4d, k, w])

# Spot check for batch 0 and last-axis index 0: a plain 2-D product of k[0]
# and w[:, :, 0] should equal the corresponding slice of the 4-D result.
print(np.dot(k_3d[0, :, :], w_3d[:, :, 0]))
# Sample output from one run (k and w are random, so values differ per run):
#array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 ,
#        0.65192497],
#       [1.2962847 , 0.63438064, 1.7439795 , 1.2534602 , 0.8585079 ,
#        0.9535629 ],
#       [1.0780972 , 1.466816  , 1.623834  , 1.4493611 , 0.9913111 ,
#        1.1141219 ],
#       [0.6155605 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317205 ,
#        0.8374078 ]], dtype=float32)
print(r_untested[0, :, :, 0])
# Sample output from the same run:
#array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 ,
#        0.65192497],
#       [1.2962846 , 0.6343807 , 1.7439795 , 1.2534602 , 0.8585079 ,
#        0.9535629 ],
#       [1.0780972 , 1.4668161 , 1.623834  , 1.449361  , 0.9913111 ,
#        1.1141219 ],
#       [0.6155604 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317204 ,
#        0.8374078 ]], dtype=float32)

# Full check over every batch element and every last-axis index (a single
# slice cannot catch a reshape that mixes up batch or trailing axes).
# np.tensordot(k, w, axes=1) contracts the last axis of k with the first axis
# of w: expected[b, s, i, j] = sum_e k[b, s, e] * w[e, i, j].
expected = np.tensordot(k_3d, w_3d, axes=1)
# The two implementations can accumulate float32 sums in a different order, so
# they may differ in the last digit (see the samples above: 1.2962847 vs
# 1.2962846). Compare with a tolerance instead of exact equality;
# assert_allclose also raises if the shapes differ.
np.testing.assert_allclose(r_untested, expected, rtol=1e-5, atol=1e-6)
print("reshape + matmul matches np.tensordot for all", expected.size, "elements")
最终发现两种方式的结果一致（仅在float32的最后一位存在舍入误差），可以大胆地用。