Multiplying high-dimensional tensors in TensorFlow

Using tf.reshape() to reduce dimensions and verify high-dimensional tensor multiplication

I recently needed to multiply tensors with more than two dimensions. Online sources suggest the trick of first reducing them to 2-D with tf.reshape(), multiplying with tf.matmul(), and reshaping the result back. This works because reshape preserves the row-major memory layout: flattening the leading axes of one tensor and the trailing axes of the other leaves the contracted axis (embedded_size here) intact. Below I verify the trick numerically.

# Test multi-dimensional matrix multiplication; the problem comes from
# the matrix operations in a mul-attention model.
# Edited 2019-07-14
import tensorflow as tf
import numpy as np

# Define two 3-D tensors
# k.shape = [batch_size, seq_length, embedded_size]
# w.shape = [embedded_size, d_k, h_num]
k = tf.Variable(tf.random_uniform([3, 4, 5]))
w = tf.Variable(tf.random_uniform([5, 6, 7]))

# Multiply k by w; the target shape is [batch_size, seq_length, d_k, h_num].
# Reshape both tensors to 2-D, do the matmul, then reshape the result back.
k_2d = tf.reshape(k, [-1, 5])
w_2d = tf.reshape(w, [5, -1])
r_2d = tf.matmul(k_2d, w_2d)
r_4d = tf.reshape(r_2d, [-1, 4, 6, 7])

# Run the graph and fetch the results
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    r_untested = sess.run(r_4d)
    k_3d, w_3d = sess.run([k, w])

print(np.dot(k_3d[0,:,:], w_3d[:,:,0]))
# array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 , 0.65192497],
#        [1.2962847 , 0.63438064, 1.7439795 , 1.2534602 , 0.8585079 , 0.9535629 ],
#        [1.0780972 , 1.466816  , 1.623834  , 1.4493611 , 0.9913111 , 1.1141219 ],
#        [0.6155605 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317205 , 0.8374078 ]], dtype=float32)
print(r_untested[0,:,:,0])
# array([[0.68616796, 1.147416  , 1.2250627 , 1.0267124 , 0.5699807 , 0.65192497],
#        [1.2962846 , 0.6343807 , 1.7439795 , 1.2534602 , 0.8585079 , 0.9535629 ],
#        [1.0780972 , 1.4668161 , 1.623834  , 1.449361  , 0.9913111 , 1.1141219 ],
#        [0.6155604 , 1.0016347 , 0.95043844, 0.8071648 , 0.6317204 , 0.8374078 ]], dtype=float32)
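Comparing one slice only spot-checks a single batch/head combination. A fuller check compares every slice at once; the sketch below reproduces the reshape trick in NumPy (with random stand-in arrays in place of the k_3d and w_3d fetched from the session) and compares it against np.tensordot, which contracts the shared embedded_size axis directly:

```python
import numpy as np

# Stand-in arrays with the same shapes as k_3d and w_3d above.
rng = np.random.default_rng(0)
k_3d = rng.random((3, 4, 5)).astype(np.float32)
w_3d = rng.random((5, 6, 7)).astype(np.float32)

# The reshape trick, reproduced in NumPy.
r_4d = (k_3d.reshape(-1, 5) @ w_3d.reshape(5, -1)).reshape(-1, 4, 6, 7)

# Reference: contract the shared axis directly.
# axes=([2], [0]) sums over embedded_size, giving shape (3, 4, 6, 7).
reference = np.tensordot(k_3d, w_3d, axes=([2], [0]))

print(np.allclose(r_4d, reference, atol=1e-6))  # True
```

Because both routes sum over the same axis in the same row-major order, the comparison holds for every batch, position, and head index, not just the one slice printed above.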

The results agree to within floating-point rounding (the last digit differs in a few entries), so the trick is safe to use.
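As an aside, the same contraction can be written in one call with einsum, avoiding the manual reshape round-trip; tf.einsum accepts the identical subscript string. A sketch using np.einsum, with random arrays standing in for k and w:

```python
import numpy as np

# 'bse,edh->bsdh' sums over the shared embedded_size axis e,
# mapping (batch, seq, emb) x (emb, d_k, h) -> (batch, seq, d_k, h).
rng = np.random.default_rng(1)
k = rng.random((3, 4, 5)).astype(np.float32)
w = rng.random((5, 6, 7)).astype(np.float32)

direct = np.einsum('bse,edh->bsdh', k, w)
via_reshape = (k.reshape(-1, 5) @ w.reshape(5, -1)).reshape(-1, 4, 6, 7)

print(direct.shape)                                 # (3, 4, 6, 7)
print(np.allclose(direct, via_reshape, atol=1e-6))  # True
```

The subscript string also documents the intended shapes, which makes attention code easier to read than a chain of reshapes.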
