TensorFlow patch块划分(transpose and reshape)

使用transformer处理图像数据,需要照特定格式对矩阵分块,并拉伸flatten,在完成最后的卷积后,需要重新将token的channel重新reshape成图像格式。类似下图,将输入首先分块,然后拉伸为NxC的vector,然后重新reshape为图像格式,这里使用一通道简要说明。

TensorFlow patch块划分(transpose and reshape)_第1张图片

 

小代码

def reshape():
    h = 6
    a = tf.random_uniform([h,h],maxval=40,dtype=tf.int32)
    b = tf.reshape(a,[2,3,2,3])
    c = tf.transpose(b,[0,2,1,3])
    d = tf.reshape(tf.reshape(c,[-1,3,3]),[4,-1])
    return a,d
def rereshape(x):
    a = tf.reshape(x,[2,2,3,3])
    b = tf.transpose(a,[0,2,1,3])
    c = tf.reshape(tf.reshape(b,[6,2,3]),[6,-1])
    return c
with tf.Session() as se:
    a,b = se.run(reshape())
    print('a:',a)
    print('b:',b)
    c = se.run(rereshape(b))
    print('c:',c)

输出

TensorFlow patch块划分(transpose and reshape)_第2张图片

你可能感兴趣的:(tensorflow,深度学习,机器学习)