从DETR到Mask2Former(3):masked attention的attention map可视化

Mask2Former的论文中有这样一张图,表示masked attenion比cross attention效果要好

从DETR到Mask2Former(3):masked attention的attention map可视化_第1张图片

那么这个attention map是怎么画出来的?

在mask2attention的源代码中 CrossAttentionLayer这个类中,在forward_post函数中做如下修改:

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2, atten_weight = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask, average_attn_weights=False)
        
        atten_weight = atten_weight.squeeze().detach().cpu().numpy()
        head_num = 0
        selected_query_num = 0
        if atten_weight.shape[-1] == 21888:
            import matplotlib.pyplot as plt
            # 创建2行4列的图形
            fig, axs = plt.subplots(2, 4, figsize=(12, 6))
            
            # 使用8次for循环在每个子图中进行绘制
            for i in range(2):
                for j in range(4):
                    atten_map = atten_weight[head_num, selected_query_num, :]
                    atten_map = atten_map.reshape((128, 171))
                    
                    head_num += 0
                    
                    axs[i, j].imshow(atten_map)
            plt.show()
        
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        
        return tgt

在 nn.MultiheadAttention 类实例的forward方法中,加入

average_attn_weights=False

得到每个注意力头的attention map,将attention_weight可视化,就得到了论文中的图片。

你可能感兴趣的:(人工智能)