之前一直比较好奇 Conditional DETR 中是如何可视化各个注意力头的空间注意力热图的，于是我尝试在 DETR 的基础上实现了一个 demo，可以直接运行。先上最终的效果图：
代码中已添加详细注释，文末附有 GitHub 链接。
# #------------------------------------------------------------#
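# 可视化Detr方法:
# spatial attention weight : softmax((cq + oq)·pk)
# combined attention weight: softmax((cq + oq)·(memory + pk))
# (实际计算时 q、k 会先经过 decoder 最后一层 cross-attn 的 in_proj 线性映射,再按头计算点积并做 softmax)
# 其中:
# pk:原始特征图的位置编码;
# oq:训练好的object queries(代码中变量名为 pq);
# cq:decoder最后一层self-attn中的输出query;
# memory:encoder的输出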
# #------------------------------------------------------------#
# 在此基础上只要稍微修改便可可视化ConditionalDetr的Fig1特征图
# #------------------------------------------------------------#
# 代码参考自:https://github.com/facebookresearch/detr/tree/colab
# #------------------------------------------------------------#
import math
import numpy as np
from PIL import Image
import requests
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
from torch.nn.functional import dropout,linear,softmax
# This script only runs inference, so switch autograd off globally:
# no computation graphs are recorded for any of the forward passes below.
torch.set_grad_enabled(mode=False)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1).

    Args:
        x: tensor whose dim 1 holds the four box coordinates, e.g. shape [N, 4].

    Returns:
        Tensor of the same shape with corners instead of center/size.
    """
    centers = x[:, :2]          # (cx, cy)
    half_sizes = 0.5 * x[:, 2:]  # (w/2, h/2)
    top_left = centers - half_sizes
    bottom_right = centers + half_sizes
    return torch.cat([top_left, bottom_right], dim=1)
def rescale_bboxes(out_bbox, size):
    """Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixel boxes.

    Args:
        out_bbox: tensor of boxes in normalized cxcywh format, shape [N, 4].
        size: (width, height) of the target image, e.g. ``PIL.Image.size``.

    Returns:
        Tensor of shape [N, 4] in pixel coordinates, on the same device as
        ``out_bbox``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale on out_bbox's device so this also works for GPU tensors
    # (a plain torch.tensor(...) always lands on the CPU and the multiply fails).
    scale = out_bbox.new_tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b * scale
# COCO category names, indexed by DETR's predicted class id (91 entries).
# 'N/A' entries are placeholders, presumably for COCO ids that have no category.
CLASSES = [
'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
'toothbrush'
]
# Palette for visualization: six RGB triples with components in [0, 1].
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
# Input preprocessing for the detector:
#   1. resize so the shorter image side becomes 800 px,
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalize every channel with the ImageNet mean / std.
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Load the pretrained DETR-R50 from torch.hub and switch it to inference mode.
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
# Read the trained tensors directly from their modules. A misspelled name now
# raises AttributeError immediately, instead of leaving the variables undefined
# as a silent name-matching scan over named_parameters() would.
# Trained object queries (oq / pq): [100, 256]
pq = model.query_embed.weight
# Cross-attention of the LAST decoder layer. This is the same layer whose
# attention weights are hooked later in this script, so take it via [-1]
# rather than hard-coding index 5.
last_cross_attn = model.transformer.decoder.layers[-1].multihead_attn
# Packed q/k/v input projection: weight [3*256, 256], bias [3*256]
in_proj_weight = last_cross_attn.in_proj_weight
in_proj_bias = last_cross_attn.in_proj_bias
# Download the demo image (COCO val2017).
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on an HTTP error instead of handing an error page to PIL.
response.raise_for_status()
im = Image.open(response.raw)
# Alternative: load a local copy instead.
# img_path = '/home/wujian/000000039769.jpg'
# im = Image.open(img_path)
# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)

# Lists that the forward hooks below append to.
conv_features, enc_attn_weights, dec_attn_weights = [], [], []
cq = []      # cq: output query of the last decoder layer's self-attention block
pk = []      # pk: positional encoding of the backbone feature map
memory = []  # memory: output of the transformer encoder

hooks = [
    # Last backbone feature map
    model.backbone[-2].register_forward_hook(
        lambda module, inputs, output: conv_features.append(output)
    ),
    # Encoder output (memory)
    model.transformer.encoder.register_forward_hook(
        lambda module, inputs, output: memory.append(output)
    ),
    # Self-attention weights of the last encoder layer
    model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
        lambda module, inputs, output: enc_attn_weights.append(output[1])
    ),
    # Cross-attention weights of the last decoder layer
    model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
        lambda module, inputs, output: dec_attn_weights.append(output[1])
    ),
    # Output of norm1 in the last decoder layer, i.e. the self-attn output query cq
    model.transformer.decoder.layers[-1].norm1.register_forward_hook(
        lambda module, inputs, output: cq.append(output)
    ),
    # Positional encoding of the image feature map (pk)
    model.backbone[-1].register_forward_hook(
        lambda module, inputs, output: pk.append(output)
    ),
]

# A single forward pass provides both the detections and the hooked tensors,
# so the boxes and the attention maps are guaranteed to come from the same run
# (and the expensive model is not evaluated twice).
try:
    outputs = model(img)
finally:
    # Always detach the hooks, even if the forward pass raises.
    for hook in hooks:
        hook.remove()

# Per-query class probabilities, dropping the last ("no object") column.
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
# keep only predictions with 0.9+ confidence
keep = probas.max(-1).values > 0.9
# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

# Take the first value captured by each hook (shapes noted for the demo image).
conv_features = conv_features[0]        # dict; conv_features['0'].tensors: [1, 2048, 25, 34]
enc_attn_weights = enc_attn_weights[0]  # [1, 850, 850] : [N, L, S]
dec_attn_weights = dec_attn_weights[0]  # [1, 100, 850] : [batch, tgt_len, src_len]
memory = memory[0]                      # [850, 1, 256]
cq = cq[0]                              # [100, 1, 256]
pk = pk[0]                              # [1, 256, 25, 34]
# ----------------------------------------------------------------------------
# Recompute the per-head cross-attention map of the last decoder layer by hand,
# following nn.MultiheadAttention:
#   in-projection -> scale by 1/sqrt(head_dim) -> q @ k^T -> softmax
# ----------------------------------------------------------------------------
cross_attn = model.transformer.decoder.layers[-1].multihead_attn
num_heads = cross_attn.num_heads              # 8 for DETR-R50
embed_dim = in_proj_weight.shape[1]           # 256 for DETR-R50
head_dim = embed_dim // num_heads             # 32 for DETR-R50

# Positional encoding of the feature map: [1, C, H, W] -> [H*W, 1, C]
pk = pk.flatten(-2).permute(2, 0, 1)
# Trained object queries: [num_queries, C] -> [num_queries, 1, C]
pq = pq.unsqueeze(1)
q = pq + cq
#------------------------------------------------------#
# 1) k = pk          -> visualizes (cq + oq) * pk            (spatial attention)
# 2) k = pk + memory -> visualizes (cq + oq) * (memory + pk) (combined attention).
#    This is presumably the key DETR's own cross-attention uses. If so, the
#    head-averaged result should closely match dec_attn_weights, which makes a
#    handy sanity check.
#------------------------------------------------------#
k = pk
# k = pk + memory
#------------------------------------------------------#
# The packed in_proj weight/bias stack [W_q; W_k; W_v] along dim 0.
w_q, w_k, _ = in_proj_weight.chunk(3, dim=0)
if in_proj_bias is None:
    b_q = b_k = None
else:
    b_q, b_k, _ = in_proj_bias.chunk(3, dim=0)
q = linear(q, w_q, b_q)
k = linear(k, w_k, b_k)

tgt_len, bsz, _ = q.shape   # [num_queries, batch, C]
src_len = k.shape[0]        # H*W of the feature map (850 for the demo image)
# nn.MultiheadAttention scales by 1/sqrt(head_dim), NOT 1/sqrt(embed_dim).
# Using embed_dim would make every softmax far too flat.
q = q * float(head_dim) ** -0.5
# Split into heads: [L, N, C] -> [N*heads, L, head_dim]
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz * num_heads, head_dim).transpose(0, 1)
attn_output_weights = softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1)
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
# Per-head maps, used by the visualization: [1, 8, 100, 850] for the demo image
attn_every_heads = attn_output_weights
# Head-averaged map: [1, 100, 850]
attn_output_weights = attn_output_weights.mean(dim=1)
#-----------#
# Visualization
#-----------#
# Spatial size of the backbone feature map; used to fold each flat row of
# attention weights back into a 2-D heat map.
h, w = conv_features['0'].tensors.shape[-2:]
num_heads_to_plot = attn_every_heads.shape[1]
num_dets = len(bboxes_scaled)
if num_dets == 0:
    # plt.subplots(ncols=0) would raise, so report instead of crashing.
    print('No detection above the confidence threshold; nothing to plot.')
else:
    # One column per detection. Rows: decoder attention, box + label, then one
    # row per head. squeeze=False keeps `axs` 2-D even for a single detection.
    fig, axs = plt.subplots(ncols=num_dets, nrows=2 + num_heads_to_plot,
                            figsize=(22, 28), squeeze=False)
    query_ids = keep.nonzero().flatten().tolist()
    for query_id, ax_col, (xmin, ymin, xmax, ymax) in zip(
            query_ids, axs.T, bboxes_scaled.tolist()):
        # Row 0: the decoder's own cross-attention weights for this query
        ax = ax_col[0]
        ax.imshow(dec_attn_weights[0, query_id].view(h, w))
        ax.axis('off')
        ax.set_title(f'query id: {query_id}', fontsize=30)
        # Row 1: predicted box and class label
        ax = ax_col[1]
        ax.imshow(im)
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color='blue', linewidth=3))
        ax.axis('off')
        ax.set_title(CLASSES[probas[query_id].argmax().item()], fontsize=30)
        # Remaining rows: the hand-computed attention map of every head
        for head in range(num_heads_to_plot):
            ax = ax_col[2 + head]
            ax.imshow(attn_every_heads[0, head, query_id].view(h, w))
            ax.axis('off')
            ax.set_title(f'head:{head}', fontsize=30)
    fig.tight_layout()  # shrink subplot padding so the grid fills the canvas
    plt.show()
GitHub 仓库：https://github.com/wulele2/Detr-heat-map-visualization ，写这个 demo 不太容易，欢迎给个 star。若有问题，欢迎添加微信（wulele2541612007），拉你进群探讨交流。
'''
代码来源于facebook_detr
'''
#导入包
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models import resnet50
import torchvision.transforms as T
# Inference only: disable autograd globally.
torch.set_grad_enabled(False)
# Fetch a sample image from COCO val2017.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
# Fail loudly on an HTTP error instead of handing an error page to PIL.
response.raise_for_status()
im = Image.open(response.raw)
plt.imshow(im)
plt.show()
# Preprocessing pipeline:
#   Resize(800)  -> the shorter image side becomes 800 px
#   ToTensor()   -> PIL image with values in [0, 255] -> float tensor in [0, 1]
#   Normalize()  -> per-channel normalization with the ImageNet mean / std
transform = T.Compose(
    [
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# PIL image -> batch of one: [1, 3, H, W]
img = transform(im)[None]
# Build an ImageNet-pretrained ResNet-50.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of `weights=...`.
model = resnet50(pretrained=True)
# Inference mode: BatchNorm must use its stored running statistics rather than
# the statistics of this single image, and must not overwrite them.
model.eval()
# Feature maps captured by the forward hook
fms = []


def hook(module, inputs, output):
    """Forward hook: append the hooked module's output to ``fms``."""
    fms.append(output)


# Capture the output of the last residual stage (layer4).
handle = model.layer4.register_forward_hook(hook)
try:
    model(img)
finally:
    # Remove the hook once it is no longer needed, even if the forward fails.
    handle.remove()
# Show the first channel of the captured layer4 feature map.
feature_map = fms[0].squeeze(0)  # drop the batch dimension
fig, ax = plt.subplots(figsize=(16, 10))
ax.imshow(feature_map[0])
ax.set_axis_off()
plt.show()