哭了,复现TensorFlow版本MAE的shuffle和reshuffle

在encoder的输入需要非masked token,然后decoder的输入需要把对应位置的token用0代替进去,只想解决这个接口,所以解决目标就是按指定位置先取出对应的token,省略中间处理步骤,在按照index位置把非masked token塞回原大小矩阵。

废了2个小时,菜狗终于解决了这个问题(丢)

写了个小测试代码

import tensorflow as tf
import numpy as np
#值矩阵
target_tensor = tf.constant([[7, 2], [9, 6], [1, 3]])
#假装mask
mask_tensor2 = tf.constant([[1, 0], [0, 0], [1, 0]])
#把mask矩阵转为bool矩阵
mask = mask_tensor2>0
#为了方便直接flatten了
index = tf.where(tf.reshape(~mask,[-1]))
#为了对比结果用的
masked_tensor2 = tf.boolean_mask(target_tensor, ~mask)
#塞回去,为了省事没有reshape,直接在flatten上复原了
mm = tf.sparse_tensor_to_dense(tf.SparseTensor(indices=index, values=masked_tensor2, dense_shape=[6]))

输出结果如图示

哭了,复现TensorFlow版本MAE的shuffle和reshuffle_第1张图片

参考链接

https://www.jianshu.com/p/831cc6f5d810

你可能感兴趣的:(tensorflow,深度学习,机器学习)