面对不同维度大小矩阵乘法操作的处理(Tensorflow)

遇到的问题:
面对矩阵的大小不同的两个矩阵,其中一个矩阵如何根据另一个矩阵的要求实现相应的行或列缩放。目标效果如下所示:

x:(2,2,3)
[[[ 1.  2.  3.],
  [ 4.  5.  6.]],
 [[ 7.  8.  9.],
  [10. 11. 12.]]]

w:(2,2)
[[0.5, 0.4],
   [0.1, 0.2]]

x*w:(2,2,3)
[[[0.5 1.  1.5]
  [1.6 2.  2.4]]
 [[0.7 0.8 0.9]
  [2.  2.2 2.4]]]

上面的效果,如果只利用点乘(w * x)乘法(tf.matmul(w, x))操作是无法完成的,需要利用到矩阵的维度变换。具体处理流程为:

  1. 对w进行维度扩张
    w(2,2) --> w(2,1,2)
  2. 将x的第二维和第三维变换
    x(2,2,3) -->x(2,3,2)
  3. 这时候再进行矩阵点乘操作,才能得到上面的效果。
    具体代码为:
# tensorflow的点乘
def test2():
    # a = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
    # [2, 2, 3] [2, 2]
    a1 = np.array([[[1.0,2.0,3.0],[4.0,5.0,6.0]],
                       [[7.0,8.0,9.0], [10.0,11.0,12.0]]])
    w = np.array([[0.5, 0.4],
                  [0.1, 0.2]])
    a1 = tf.convert_to_tensor(a1)
    w = tf.convert_to_tensor(w)
    # 
    # y = w * a1
    a_trans = tf.transpose(a1, [0, 2, 1])
    w = tf.expand_dims(w, 1)
    y = tf.multiply(a_trans, w)
    y = tf.transpose(y, [0, 2, 1])
    # y = a_trans*w
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print("x:")
        print(sess.run(a1))
        print("w")
        print(sess.run(w))
        print("x*w:")
        print(sess.run(y))
test2()

注意事项:
(1) 点乘,只有在w的列为1或与x的列相等时,才能进行点乘运算;
(2) 乘法, 只有前一个矩阵的最后一维和后面一个矩阵的第一维相等时,才能进行乘法操作;

打个小广告: 欢迎关注本人github: https://github.com/wuxiaoxiaoer
随时会有新想法,或技术更新,尤其是假新闻方面的研究。

你可能感兴趣的:(读博之路)