pytorch 基于masking对元素进行替换

描述

pytorch 基于masking对元素进行替换. 代码如下. 先展平再赋值.

代码

# map.shape [64,60,128]
# infill.shape [64,17,128]
# mask_indices.shape [64,60]
   map = map.reshape(
            map.shape[0] * map.shape[1],
            map.shape[2]) [mask_indices.reshape(mask_indices.shape[0]*mask_indices.shape[1])] \
            = fillin.reshape(fillin.shape[0]*fillin.shape[1], fillin.shape[2])
  

你可能感兴趣的:(pytorch,人工智能,python)