Conditional DETR spatial attention & content attention visualization (Part 2)

The idea is simple: blend the attention map onto the original image as a weighted overlay.

cv2.applyColorMap() is used to turn the single-channel attention map into a three-channel image.

Small attention values are zeroed out first, otherwise they bleed into the original image after blending and cause color casts.

As for the blue color: cv2.applyColorMap offers cv2.COLORMAP_DEEPGREEN, which produces a green map, but for some reason there is only deepgreen, with no deepblue or deepred, so the trick is to move the G channel's values into the B channel. The sketch below puts these pieces together.
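Here is a minimal sketch of that overlay pipeline, to make the recipe concrete before the full script (the function name overlay_attention and its arguments are illustrative, not from the original code; it assumes a uint8 BGR image and a float attention map):

import cv2
import numpy as np

def overlay_attention(image_bgr, attn):
    # zero out weak responses so they don't tint the original image
    attn = np.where(attn < attn.mean(), 0.0, attn).astype(np.float32)
    # resize the map to the image size
    attn = cv2.resize(attn, (image_bgr.shape[1], image_bgr.shape[0]))
    # rescale to 8-bit; clipping avoids uint8 wrap-around
    attn = np.uint8(np.clip(attn / (attn.max() + 1e-8) * 255, 0, 255))
    # single channel -> 3-channel colormap (OpenCV returns BGR)
    heat = cv2.applyColorMap(attn, cv2.COLORMAP_DEEPGREEN)
    # one way to turn the green map blue: copy G into B, then drop G
    heat[:, :, 0] = heat[:, :, 1]
    heat[:, :, 1] = 0
    # weighted overlay onto the original image
    return cv2.addWeighted(image_bgr, 1.0, heat, 0.8, 0)

The full script below pulls real attention weights out of DETR and applies the same recipe.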

#------------------------------------------------------------#
# DETR visualization approach:
# spatial attention weight : (cq + oq)*pk
# combined attention weight: (cq + oq)*(memory + pk)
# where:
#     pk: positional encoding of the backbone feature map
#     oq: the trained object queries
#     cq: the output query of the self-attn in the last decoder layer
#     memory: the encoder output
#------------------------------------------------------------#
# With only minor changes this also visualizes the Fig. 1
# feature maps of Conditional DETR.
#------------------------------------------------------------#
# Code adapted from: https://github.com/facebookresearch/detr/tree/colab
#------------------------------------------------------------#
import numpy as np
import cv2
from PIL import Image
import requests
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as T
from torch.nn.functional import softmax
import torch.nn.functional as F

torch.set_grad_enabled(False)

def box_cxcywh_to_xyxy(x):
    # convert boxes from (center_x, center_y, w, h) to (x0, y0, x1, y1)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # scale normalized [0, 1] boxes to absolute pixel coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
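# quick sanity check of the helpers (illustrative values, not from the original post):
#   rescale_bboxes(torch.tensor([[0.5, 0.5, 0.5, 0.5]]), (640, 480))
#   -> tensor([[160., 120., 480., 360.]])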

# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# load the pretrained DETR from torch hub
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()

# pull out the trained parameters we need
for name, parameters in model.named_parameters():
    # the trained object queries (oq in the formulas above): [100, 256]
    if name == 'query_embed.weight':
        pq = parameters  # q_position
    # q/k/v projection weight and bias of the cross-attention in the last
    # decoder layer (layers.5 of 6): [256*3, 256] and [256*3]
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_weight':
        in_proj_weight = parameters
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_bias':
        in_proj_bias = parameters

# download a test image from COCO
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)

# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)

# propagate through the model
outputs = model(img)

# keep only predictions with 0.9+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

# use lists to store the hook outputs via up-values
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = []     # stores cq (q_content) from the decoder
pk = []     # stores the encoder positional embedding (k_position)
memory = [] # stores the encoder output feature map (memory)

# register forward hooks
hooks = [
    # feature map from the last ResNet stage
    model.backbone[-2].register_forward_hook(
        lambda self, input, output: conv_features.append(output)
    ),
    # encoder output feature map (memory)
    model.transformer.encoder.register_forward_hook(
        lambda self, input, output: memory.append(output)
    ),
    # self-attn weights of the last encoder layer
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda self, input, output: enc_attn_weights.append(output[1])
    ),
    # cross-attention weights of the last decoder layer
    model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
        lambda self, input, output: dec_attn_weights.append(output[1])
    ),
    # cq: norm1 sits right after self-attn in DETR's decoder layer, so its
    # output is the content query fed into the cross-attention
    model.transformer.decoder.layers[-1].norm1.register_forward_hook(
        lambda self, input, output: cq.append(output)
    ),
    # pk: positional embedding of the image feature map
    model.backbone[-1].register_forward_hook(
        lambda self, input, output: pk.append(output)
    ),
]

# propagate through the model
outputs = model(img)

# remove the hooks once we are done with them
for hook in hooks:
    hook.remove()

# don't need the lists anymore; unpack the single element of each
conv_features = conv_features[0]       # [1, 2048, 25, 34]
enc_attn_weights = enc_attn_weights[0] # [1, 850, 850] : [N, L, S]
dec_attn_weights = dec_attn_weights[0] # [1, 100, 850] : [N, L, S] --> [batch, tgt_len, src_len]
memory = memory[0]                     # [850, 1, 256]

cq = cq[0]    # output of the last decoder self-attn (after norm1): [100, 1, 256]
pk = pk[0]    # [1, 256, 25, 34]

# reshape the position embeddings to match the q/k layout
pk = pk.flatten(-2).permute(2, 0, 1)   # [1, 256, 850] --> [850, 1, 256]
pq = pq.unsqueeze(1)                   # [100, 256]    --> [100, 1, 256]
q = pq + cq                            # the full query: cq + oq
#------------------------------------------------------#
#   1) k = pk          visualizes (cq + oq)*pk
#   2) k = pk + memory visualizes (cq + oq)*(memory + pk)
#   readers can try both
#------------------------------------------------------#
k = pk
# k = memory
# k = pk + memory   # this would be the normal, complete k
#------------------------------------------------------#

# project q and k through the linear layers, following nn.MultiheadAttention's internals
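# in_proj_weight packs the q, k, v projections row-wise:
# rows [0:256] -> q, [256:512] -> k, [512:768] -> v (v is not needed here)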
_b = in_proj_bias
_start = 0
_end = 256
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
q = F.linear(q, _w, _b)

_b = in_proj_bias
_start = 256
_end = 256 * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
k = F.linear(k, _w, _b)

# per-head scaling as in nn.MultiheadAttention: head_dim ** -0.5 (head_dim = 256 // 8 = 32)
scaling = float(256 // 8) ** -0.5
q = q * scaling
# split into 8 heads of dim 32: [len, 1, 256] -> [len, 8, 32] -> [8, len, 32]
q = q.contiguous().view(100, 8, 32).transpose(0, 1)
k = k.contiguous().view(-1, 8, 32).transpose(0, 1)
# per-head attention logits: [8, 100, 850]
attn_output_weights = torch.bmm(q, k.transpose(1, 2))

attn_output_weights = attn_output_weights.view(1, 8, 100, 850)
# softmax over the source (spatial) dimension
attn_output_weights = softmax(attn_output_weights, dim=-1)

# keep the per-head maps for later, and average across the 8 heads
attn_every_heads = attn_output_weights                # [1, 8, 100, 850]
attn_output_weights = attn_output_weights.mean(dim=1) # [1, 100, 850]

#-----------------#
#   Visualization
#-----------------#
# get the feature map shape
h, w = conv_features['0'].tensors.shape[-2:]  # 25 x 34; h*w = 850 matches the attention source length

# one column per kept detection; rows: dec-attn map, boxed image, then the 8 heads
fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=10, figsize=(22, 28))
colors = COLORS * 100

# main visualization loop
for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):
    # decoder cross-attention weights for this query
    ax = ax_i[0]
    ax.imshow(dec_attn_weights[0, idx].view(h, w))
    ax.axis('off')
    ax.set_title(f'query id: {idx.item()}', fontsize=30)
    # the predicted box and class
    ax = ax_i[1]
    ax.imshow(im)
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color='blue', linewidth=3))
    ax.axis('off')
    ax.set_title(CLASSES[probas[idx].argmax()], fontsize=30)
    # per-head spatial attention maps
    for head in range(2, 2 + 8):
        ax = ax_i[head]
        img = cv2.cvtColor(np.asarray(im), cv2.COLOR_RGB2BGR)
        heatmap = attn_every_heads[0, head-2, idx].view(h,w)
        # zero out weights below the mean so they don't disturb the blended image
        heatmap.masked_fill_((heatmap < torch.mean(heatmap)), 0.)
        heatmap = heatmap.numpy()
        
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
        # amplify the tiny attention weights; clip to avoid uint8 wrap-around
        heatmap = np.uint8(np.clip(255 * 100 * heatmap, 0, 255))
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_DEEPGREEN)  # [h,w,3]

        # applyColorMap returns BGR; shuffle the channels so the green map reads as blue
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_RGB2BGR)
        heatmap = heatmap[:, :, (1, 0, 2)]

        # weighted overlay of the heatmap onto the original image
        overlapping = cv2.addWeighted(img, 1.0, heatmap, 0.8, 0)
        overlapping = cv2.cvtColor(overlapping, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(overlapping)
        
        ax.imshow(img)
        
        ax.axis('off')
        ax.set_title(f'head:{head-2}', fontsize=30)
        
fig.tight_layout()               # auto-adjust subplot spacing to fill the figure
fig.savefig('spatial_attn.jpg')  # save before plt.show(), otherwise the saved file is blank
plt.show()

(Figure: the resulting grid of per-query decoder attention maps, detections, and per-head spatial attention overlays)

This gets roughly 80% of the way to the figures in the paper.

Some gap remains, and I don't yet know how to close it; hopefully it can be improved in the future.
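One tweak that might close part of the gap (my own suggestion, not from the original post or the paper): min-max normalize each head's map to [0, 1] before colorizing, instead of the fixed 255 * 100 gain, so strong and weak heads are drawn on a comparable scale:

import numpy as np

def normalize_heatmap(heatmap, eps=1e-8):
    # min-max normalize one attention map to [0, 1] (illustrative sketch)
    hmin, hmax = heatmap.min(), heatmap.max()
    return (heatmap - hmin) / (hmax - hmin + eps)

# inside the plotting loop, the fixed gain would become:
# heatmap = np.uint8(255 * normalize_heatmap(heatmap))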
