Detr空间注意力热图及语义注意力热图可视化

文章目录

  • 前言
  • 完整代码
  • GitHub链接
  • 附录:可视化resnet50最后一层特征图


前言

 之前一直比较好奇Conditional Detr中如何可视化各个头部的空间注意力热图的,于是,本人尝试在Detr基础上实现了一个demo,可以无脑化运行,先上最终的效果图:
 代码中我已经加了详细注释,文末有GitHub链接。
Detr空间注意力热图及语义注意力热图可视化_第1张图片

完整代码

# #------------------------------------------------------------#
# 可视化Detr方法:
# spatial attention weight : (cq + oq)*pk
# combined attention weight: (cq + oq)*(memory + pk)
# 其中:
#     pk:原始特征图的位置编码;
#     oq:训练好的object queries
#     cq:decoder最后一层self-attn中的输出query
#     memory:encoder的输出
# #------------------------------------------------------------#
# 在此基础上只要稍微修改便可可视化ConditionalDetr的Fig1特征图
# #------------------------------------------------------------#
# 代码参考自:https://github.com/facebookresearch/detr/tree/colab
# #------------------------------------------------------------#

import math
import numpy as np

from PIL import Image
import requests
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display, clear_output

import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
from torch.nn.functional import dropout,linear,softmax
torch.set_grad_enabled(False)


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# 加载线上的模型
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
# 获取训练好的参数
for name, parameters in model.named_parameters():
    # 获取训练好的object queries,即pq:[100,256]
    if name == 'query_embed.weight':
        pq = parameters
    # 获取解码器的最后一层的交叉注意力模块中q和k的线性权重和偏置:[256*3,256],[768]
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_weight':
        in_proj_weight = parameters
    if name == 'transformer.decoder.layers.5.multihead_attn.in_proj_bias':
        in_proj_bias = parameters
# 线上下载图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
# img_path = '/home/wujian/000000039769.jpg'
# im = Image.open(img_path)

# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)

# propagate through the model
outputs = model(img)

# keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

# use lists to store the outputs via up-values
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = []     # 存储detr中的 cq
pk =  []    # 存储detr中的 encoder pos
memory = [] # 存储encoder的输出特征图memory

# 注册hook
hooks = [
    # 获取resnet最后一层特征图
    model.backbone[-2].register_forward_hook(
        lambda self, input, output: conv_features.append(output)
    ),
    # 获取encoder的图像特征图memory
    model.transformer.encoder.register_forward_hook(
        lambda self, input, output: memory.append(output)
    ),
    # 获取encoder的最后一层layer的self-attn weights
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda self, input, output: enc_attn_weights.append(output[1])
    ),
    # 获取decoder的最后一层layer中交叉注意力的 weights
    model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
        lambda self, input, output: dec_attn_weights.append(output[1])
    ),
    # 获取decoder最后一层self-attn的输出cq
    model.transformer.decoder.layers[-1].norm1.register_forward_hook(
        lambda self, input, output: cq.append(output)
    ),
    # 获取图像特征图的位置编码pk
    model.backbone[-1].register_forward_hook(
        lambda self, input, output: pk.append(output)
    ),
]

# propagate through the model
outputs = model(img)

# 用完的hook后删除
for hook in hooks:
    hook.remove()

# don't need the list anymore
conv_features = conv_features[0]       # [1,2048,25,34]
enc_attn_weights = enc_attn_weights[0] # [1,850,850]   : [N,L,S]
dec_attn_weights = dec_attn_weights[0] # [1,100,850]   : [N,L,S] --> [batch, tgt_len, src_len]
memory = memory[0] # [850,1,256]

cq = cq[0]    # decoder的self_attn:最后一层输出[100,1,256]
pk = pk[0]    # [1,256,25,34]

# 绘制postion embedding
pk = pk.flatten(-2).permute(2,0,1)           # [1,256,850] --> [850,1,256]
pq = pq.unsqueeze(1).repeat(1,1,1)           # [100,1,256]
q = pq + cq
#------------------------------------------------------#
#   1) k = pk,则可视化: (cq + oq)*pk
#   2_ k = pk + memory,则可视化 (cq + oq)*(memory + pk)
#   读者可自行尝试
#------------------------------------------------------#
k = pk
# k = pk + memory
#------------------------------------------------------#

# 将q和k完成线性层的映射,代码参考自nn.MultiHeadAttn()
_b = in_proj_bias
_start = 0
_end = 256
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
q = linear(q, _w, _b)

_b = in_proj_bias
_start = 256
_end = 256 * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
    _b = _b[_start:_end]
k = linear(k, _w, _b)

scaling = float(256) ** -0.5
q = q * scaling
q = q.contiguous().view(100, 8, 32).transpose(0, 1)
k = k.contiguous().view(-1, 8, 32).transpose(0, 1)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))

attn_output_weights = attn_output_weights.view(1, 8, 100, 850)
attn_output_weights = attn_output_weights.view(1 * 8, 100, 850)
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = attn_output_weights.view(1, 8, 100, 850)

# 后续可视化各个头
attn_every_heads = attn_output_weights # [1,8,100,850]
attn_output_weights = attn_output_weights.sum(dim=1) / 8 # [1,100,850]

#-----------#
#   可视化
#-----------#
# get the feature map shape
h, w = conv_features['0'].tensors.shape[-2:]

fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=10, figsize=(22, 28))  # [11,2]
colors = COLORS * 100

# 可视化
for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):
    # 可视化decoder的注意力权重
    ax = ax_i[0]
    ax.imshow(dec_attn_weights[0, idx].view(h, w))
    ax.axis('off')
    ax.set_title(f'query id: {idx.item()}',fontsize = 30)
    # 可视化框和类别
    ax = ax_i[1]
    ax.imshow(im)
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color='blue', linewidth=3))
    ax.axis('off')
    ax.set_title(CLASSES[probas[idx].argmax()],fontsize = 30)
    # 分别可视化8个头部的位置特征图
    for head in range(2, 2 + 8):
        ax = ax_i[head]
        ax.imshow(attn_every_heads[0, head-2, idx].view(h,w))
        ax.axis('off')
        ax.set_title(f'head:{head-2}',fontsize = 30)
fig.tight_layout()        # 自动调整子图来使其填充整个画布
plt.show()

GitHub链接

https://github.com/wulele2/Detr-heat-map-visualization,给个star吧,太不容易了。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

附录:可视化resnet50最后一层特征图

'''
代码来源于facebook_detr
'''
#导入包
import requests
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)

# 获取一张图像
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)

plt.imshow(im)
plt.show()

# 构造图像变换
transform = T.Compose([
    T.Resize(800),              # 将图像进行Resize,符合短边变换原则
    T.ToTensor(),               # 将[0,255] --> [0,1]之间的张量
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 给每个通道进行归一化
])

# PLI --> tensor
img = transform(im).unsqueeze(0) # [1,3,h,w]

# 构造模型
model = resnet50(pretrained=True)

# 创建一个list存储特征图特征
fms = []
# 定制hook
def hook(module, input, output):
    fms.append(output)
# 注册hook
handle = model.layer4.register_forward_hook(hook)

#forward
model(img)

# 用完后删除hook
handle.remove()

# 可视化
plt.figure(figsize=(16, 10))   # 画布大小
ax = plt.gca()                 # 获取坐标轴
ax.imshow(fms[0].squeeze(0)[0])# 可视化第一个channel
ax.axis('off')                 # 关闭坐标轴
plt.show()                     # 展示

Detr空间注意力热图及语义注意力热图可视化_第2张图片
Detr空间注意力热图及语义注意力热图可视化_第3张图片

你可能感兴趣的:(mmcv和mmdet源码注释版,pytorch)