The first error was KeyError: "TopDown: HRT is not in the models registry".
That kicked off my bug-hunting journey.
1. Rebuild the environment
pytorch==1.7.1, cuda==10.1
1) Install mmpose
pip install mmpose==0.25.1
2) Install mmcv-full
pip install openmim
mim install mmcv-full==1.3.8
3) The KeyError: "TopDown: HRT is not in the models registry" still appears
Some of the modules this codebase uses (here, the HRT backbone) are never registered in the pip-installed package, so the registry lookup fails when the model is built from the config; a minimal sketch of the registration mechanism follows this list.
The fix is to reinstall from the project root. First install the dependencies listed in
requirements.txt
(or whichever dependency txt file the repo ships):
pip install -r requirements.txt
Then install the project itself in editable mode:
pip install -v -e .
If you later move the code to an environment of your own, the editable install still has to be repeated; the requirements.txt step can be skipped when the dependencies are already present.
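For context, mmcv-style registries only know about classes that were decorated with register_module() and actually imported. Below is a minimal, self-contained sketch of that mechanism; the toy HRT class, its num_stages argument and the registry name "models" are my own illustrative choices, not code from the HRFormer repo.

from mmcv.utils import Registry, build_from_cfg

BACKBONES = Registry("models")            # the repo's real registry lives inside mmpose

@BACKBONES.register_module()              # this decorator is what puts the key "HRT" into the registry
class HRT:
    def __init__(self, num_stages=4):
        self.num_stages = num_stages

# Building from a config dict is exactly the lookup that raises the KeyError
# when the decorated class was never imported:
model = build_from_cfg(dict(type="HRT", num_stages=4), BACKBONES)
print(type(model).__name__)               # HRT

In the HRFormer repo the real HRT backbone presumably gets the same treatment through mmpose's BACKBONES registry plus an import in mmpose/models/backbones/__init__.py. Running pip install -v -e . makes that local, modified copy of mmpose the one Python imports instead of the plain pip-installed package (which has no HRT), which is why reinstalling fixes the KeyError.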
Key code walkthrough:
# --------------------------------------------------------
# High Resolution Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Rao Fu, RainbowSecret
# --------------------------------------------------------
import os
import math
import logging
import torch
import torch.nn as nn
from functools import partial
from mmcv.cnn import build_conv_layer, build_norm_layer
BN_MOMENTUM = 0.1
# --------------------------------------------------------
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Lang Huang, RainbowSecret from:
# https://github.com/openseg-group/openseg.pytorch/blob/master/lib/models/modules/isa_block.py
# --------------------------------------------------------
import os
import pdb
import math
import torch
import torch.nn as nn
# --------------------------------------------------------
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Lang Huang, RainbowSecret from:
# https://github.com/openseg-group/openseg.pytorch/blob/master/lib/models/modules/isa_block.py
# --------------------------------------------------------
import copy
import math
import warnings
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch._jit_internal import Optional, Tuple
from torch.overrides import has_torch_function, handle_torch_function
from torch.nn.functional import linear, pad, softmax, dropout
from einops import rearrange
from timm.models.layers import to_2tuple, trunc_normal_
# --------------------------------------------------------
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by RainbowSecret from:
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L852
# --------------------------------------------------------
import copy
import math
import warnings
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.modules.module import Module
from torch._jit_internal import Optional, Tuple
from torch.overrides import has_torch_function, handle_torch_function
from torch.nn.functional import linear, pad, softmax, dropout
class MultiheadAttention(Module):
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.in_proj_bias = None
self.in_proj_weight = None
self.bias_k = self.bias_v = None
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None
self.add_zero_attn = add_zero_attn
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super(MultiheadAttention, self).__setstate__(state)
def forward(
self,
query,
key,
value,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
residual_attn=None,
):
if not self._qkv_same_embed_dim:
return self.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
out_dim=self.vdim,
residual_attn=residual_attn,
)
else:
return self.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
out_dim=self.vdim,
residual_attn=residual_attn,
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
out_dim: Optional[int] = None,
residual_attn: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
if not torch.jit.is_scripting():
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(
tens_ops
):
return handle_torch_function(
multi_head_attention_forward,
tens_ops,
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
)
tgt_len, bsz, embed_dim = query.size()
key = query if key is None else key
value = query if value is None else value
assert embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
v_head_dim = out_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
q = self.q_proj(query) * scaling
k = self.k_proj(key)
v = self.v_proj(value)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat(
[
k,
torch.zeros(
(k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
),
],
dim=1,
)
v = torch.cat(
[
v,
torch.zeros(
(v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
),
],
dim=1,
)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
"""
Attention weight for the invalid region is -inf
"""
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
if residual_attn is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights += residual_attn.unsqueeze(0)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
"""
Reweight the attention map before softmax().
attn_output_weights: (b*n_head, n, hw)
"""
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output
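A tiny usage sketch of the class above (the shapes are my own illustrative choices, matching the windowed layout used further down): inputs follow the (seq_len, batch, embed_dim) convention and, since need_weights defaults to False, only the attention output is returned.
mha = MultiheadAttention(embed_dim=78, num_heads=3)
q = torch.randn(49, 140, 78)          # (seq_len, batch, embed_dim)
out = mha(q, q, q)
print(out.shape)                      # torch.Size([49, 140, 78])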
class MHA_(MultiheadAttention):
""" "Multihead Attention with extra flags on the q/k/v and out projections."""
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(self, *args, rpe=False, window_size=7, **kwargs):
super(MHA_, self).__init__(*args, **kwargs)
self.rpe = rpe
if rpe:
self.window_size = [window_size] * 2
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
self.num_heads,
)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(
1, 2, 0
).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(
self,
query,
key,
value,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
do_qkv_proj=True,
do_out_proj=True,
rpe=True,
):
if not self._qkv_same_embed_dim:
return self.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
out_dim=self.vdim,
do_qkv_proj=do_qkv_proj,
do_out_proj=do_out_proj,
rpe=rpe,
)
else:
return self.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
out_dim=self.vdim,
do_qkv_proj=do_qkv_proj,
do_out_proj=do_out_proj,
rpe=rpe,
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
out_dim: Optional[int] = None,
do_qkv_proj: bool = True,
do_out_proj: bool = True,
rpe=True,
) -> Tuple[Tensor, Optional[Tensor]]:
if not torch.jit.is_scripting():
tens_ops = (
query,
key,
value,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
out_proj_weight,
out_proj_bias,
)
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(
tens_ops
):
return handle_torch_function(
multi_head_attention_forward,
tens_ops,
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training=training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight,
static_k=static_k,
static_v=static_v,
)
tgt_len, bsz, embed_dim = query.size()
key = query if key is None else key
value = query if value is None else value
assert embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
v_head_dim = out_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
# whether or not use the original query/key/value
q = self.q_proj(query) * scaling if do_qkv_proj else query
k = self.k_proj(key) if do_qkv_proj else key
v = self.v_proj(value) if do_qkv_proj else value
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if add_zero_attn:
src_len += 1
k = torch.cat(
[
k,
torch.zeros(
(k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
),
],
dim=1,
)
v = torch.cat(
[
v,
torch.zeros(
(v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
),
],
dim=1,
)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
"""
Add relative position embedding
"""
if self.rpe and rpe:
# NOTE: for simplicity, we assume the src_len == tgt_len == window_size**2 here
# print('src, tar, window', src_len, tgt_len, self.window_size[0], self.window_size[1])
# assert src_len == self.window_size[0] * self.window_size[1] \
# and tgt_len == self.window_size[0] * self.window_size[1], \
# f"src{src_len}, tgt{tgt_len}, window{self.window_size[0]}"
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
# NOTE: in this copy the addition of the bias is commented out; the upstream
# HRFormer code adds relative_position_bias.unsqueeze(0) to the reshaped weights here.
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)  # + relative_position_bias.unsqueeze(0)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
"""
Attention weight for the invalid region is -inf
"""
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
"""
Reweight the attention map before softmax().
attn_output_weights: (b*n_head, n, hw)
"""
attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
)
if do_out_proj:
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, q, k, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, q, k  # additionally return the query and key
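The relative-position-bias bookkeeping in MHA_ can be checked with a small sketch (shapes are my own illustrative choices): a 7x7 window has 49 positions and (2*7-1)^2 = 169 possible relative offsets, and every pair of positions indexes into that table.
mha_rpe = MHA_(78, 3, rpe=True, window_size=7)
print(mha_rpe.relative_position_bias_table.shape)  # torch.Size([169, 3]): one bias per offset and head
print(mha_rpe.relative_position_index.shape)       # torch.Size([49, 49]): offset id for every position pair
x = torch.randn(49, 140, 78)                       # (window_len, batch*num_windows, C)
out, _, _ = mha_rpe(x, x, x)                       # also returns the projected q and k
print(out.shape)                                   # torch.Size([49, 140, 78])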
# PadBlock handles the case where the feature-map size is not divisible by the window size:
# it center-pads H and W up to the next multiple (a small usage sketch follows the class).
class PadBlock(object):
""" "Make the size of feature map divisible by local group size."""
def __init__(self, local_group_size=7):
self.lgs = local_group_size  # e.g. 7
if not isinstance(self.lgs, (tuple, list)):
self.lgs = to_2tuple(self.lgs)  # becomes the tuple (7, 7)
assert len(self.lgs) == 2
def pad_if_needed(self, x, size):
n, h, w, c = size
pad_h = math.ceil(h / self.lgs[0]) * self.lgs[0] - h
pad_w = math.ceil(w / self.lgs[1]) * self.lgs[1] - w
if pad_h > 0 or pad_w > 0: # center-pad the feature on H and W axes
return F.pad(
x,
(0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
)  # F.pad fills from the last dimension outward, i.e. C, then W, then H
return x
def depad_if_needed(self, x, size):
n, h, w, c = size
pad_h = math.ceil(h / self.lgs[0]) * self.lgs[0] - h  # for the example input (H=48, window 7): pad_h = 1
pad_w = math.ceil(w / self.lgs[1]) * self.lgs[1] - w  # for W=64: pad_w = 6
if pad_h > 0 or pad_w > 0:  # remove the center-padding on the feature map
return x[:, pad_h // 2 : pad_h // 2 + h, pad_w // 2 : pad_w // 2 + w, :]  # x still carries the padding on H and W, so crop back to the original size
return x
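A small usage sketch of PadBlock (my own toy shapes, chosen to match the H=48, W=64 case printed later): 48 and 64 are not multiples of 7, so the helper center-pads them to 49 and 70, and the de-pad call crops back.
pad_helper = PadBlock(local_group_size=7)
x = torch.randn(2, 48, 64, 78)                        # (N, H, W, C)
x_pad = pad_helper.pad_if_needed(x, x.size())
print(x_pad.shape)                                    # torch.Size([2, 49, 70, 78])
x_back = pad_helper.depad_if_needed(x_pad, x.size())
print(x_back.shape)                                   # torch.Size([2, 48, 64, 78])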
class LocalPermuteModule(object):
""" "Permute the feature map to gather pixels in local groups, and the reverse permutation"""
def __init__(self, local_group_size=7):
self.lgs = local_group_size
if not isinstance(self.lgs, (tuple, list)):
self.lgs = to_2tuple(self.lgs)
assert len(self.lgs) == 2
# Rearranges (N, H, W, C) into the (window_len, batch*num_windows, C) layout that the attention expects (see the sketch after this class)
def permute(self, x, size):
n, h, w, c = size
return rearrange(
x,
"n (qh ph) (qw pw) c -> (ph pw) (n qh qw) c",
n=n,
qh=h // self.lgs[0],
ph=self.lgs[0],
qw=w // self.lgs[0],
pw=self.lgs[0],
c=c,
)
# Maps the attention output back to the (N, H, W, C) feature-map layout
def rev_permute(self, x, size):
n, h, w, c = size
return rearrange(
x,
"(ph pw) (n qh qw) c -> n (qh ph) (qw pw) c",
n=n,
qh=h // self.lgs[0],
ph=self.lgs[0],
qw=w // self.lgs[0],
pw=self.lgs[0],
c=c,
)
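Continuing the same toy shapes: once H and W are multiples of 7, permute() splits the map into 7x7 windows and stacks the windows along the batch dimension, and rev_permute() undoes it.
perm_helper = LocalPermuteModule(local_group_size=7)
x_pad = torch.randn(2, 49, 70, 78)                    # (N, H, W, C), both divisible by 7
x_perm = perm_helper.permute(x_pad, x_pad.size())
print(x_perm.shape)                                   # torch.Size([49, 140, 78]): 49-pixel windows, 2*7*10 of them
x_rev = perm_helper.rev_permute(x_perm, x_pad.size())
print(x_rev.shape)                                    # torch.Size([2, 49, 70, 78])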
class InterlacedPoolAttention(nn.Module):
r"""interlaced sparse multi-head self attention (ISA) module with relative position bias.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): Window size.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, embed_dim, num_heads, window_size=7, rpe=True, **kwargs):
super(InterlacedPoolAttention, self).__init__()
self.dim = embed_dim
self.num_heads = num_heads
self.window_size = window_size
self.with_rpe = rpe
self.attn = MHA_(
embed_dim, num_heads, rpe=rpe, window_size=window_size, **kwargs
)
self.pad_helper = PadBlock(window_size)
self.permute_helper = LocalPermuteModule(window_size)
def forward(self, x, H, W, **kwargs):
B, N, C = x.shape
x = x.view(B, H, W, C)
print('x', x.shape)#x torch.Size([78, 48, 64, 78])
# attention
# pad
x_pad = self.pad_helper.pad_if_needed(x, x.size())
print('x_pad', x_pad.shape)#x_pad torch.Size([78, 49, 70, 78])
# permute
x_permute = self.permute_helper.permute(x_pad, x_pad.size())
print('x_permute', x_permute.shape) # x_permute torch.Size([49, 5460, 78])
# attention
out, _, _ = self.attn(
x_permute, x_permute, x_permute, rpe=self.with_rpe, **kwargs
)
print('out', out.shape)#out torch.Size([49, 5460, 78])
# reverse permutation
out = self.permute_helper.rev_permute(out, x_pad.size())
print('out1', out.shape)#out1 torch.Size([78, 49, 70, 78])
# de-pad, pooling with `ceil_mode=True` will do implicit padding, so we need to remove it, too
out = self.pad_helper.depad_if_needed(out, x.size())
# print('out.reshape(B, N, C)',out.reshape(B, N, C).shape)#out.reshape(B, N, C) torch.Size([78, 3072, 78])
return out.reshape(B, N, C)
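Putting the three helpers together, InterlacedPoolAttention consumes (B, N, C) tokens plus the spatial size and returns tokens of the same shape. A small sketch with my own toy batch (the debug prints inside forward will also fire):
isa = InterlacedPoolAttention(embed_dim=78, num_heads=3, window_size=7)
tokens = torch.randn(2, 48 * 64, 78)                  # (B, N, C) with N = H * W
print(isa(tokens, 48, 64).shape)                      # torch.Size([2, 3072, 78])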
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self):
# (Optional)Set the extra information about this module. You can test
# it by printing an object of this class.
return "drop_prob={}".format(self.drop_prob)
class MlpDWBN(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
dw_act_layer=nn.GELU,
drop=0.0,
conv_cfg=None,
norm_cfg=dict(type="BN", requires_grad=True),
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = build_conv_layer(
conv_cfg,
in_features,
hidden_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.act1 = act_layer()
self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
self.dw3x3 = build_conv_layer(
conv_cfg,
hidden_features,
hidden_features,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_features,
)
self.act2 = dw_act_layer()
self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1]
self.fc2 = build_conv_layer(
conv_cfg,
hidden_features,
out_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
self.act3 = act_layer()
self.norm3 = build_norm_layer(norm_cfg, out_features)[1]
self.drop = nn.Dropout(drop, inplace=True)  # needed: the (B, C, H, W) branch of forward() calls self.drop
def forward(self, x, H, W):
if len(x.shape) == 3:
B, N, C = x.shape
if N == (H * W + 1):
cls_tokens = x[:, 0, :]
x_ = x[:, 1:, :].permute(0, 2, 1).contiguous().reshape(B, C, H, W)
else:
x_ = x.permute(0, 2, 1).contiguous().reshape(B, C, H, W)
x_ = self.fc1(x_)
x_ = self.norm1(x_)
x_ = self.act1(x_)
x_ = self.dw3x3(x_)
x_ = self.norm2(x_)
x_ = self.act2(x_)
# x_ = self.drop(x_)
x_ = self.fc2(x_)
x_ = self.norm3(x_)
x_ = self.act3(x_)
# x_ = self.drop(x_)
x_ = x_.reshape(B, C, -1).permute(0, 2, 1).contiguous()
if N == (H * W + 1):
x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1)
else:
x = x_
return x
elif len(x.shape) == 4:
x = self.fc1(x)
x = self.norm1(x)
x = self.act1(x)
x = self.dw3x3(x)
x = self.norm2(x)
x = self.act2(x)
x = self.drop(x)
x = self.fc2(x)
x = self.norm3(x)
x = self.act3(x)
x = self.drop(x)
return x
else:
raise RuntimeError("Unsupported input shape: {}".format(x.shape))
class GeneralTransformerBlock(nn.Module):
expansion = 1
def __init__(
self,
inplanes,
planes,
num_heads,
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
conv_cfg=None,
norm_cfg=dict(type="BN", requires_grad=True),
):
super().__init__()
self.dim = inplanes
self.out_dim = planes
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.attn = InterlacedPoolAttention(
self.dim, num_heads=num_heads, window_size=window_size, dropout=attn_drop
)
self.norm1 = norm_layer(self.dim)
self.norm2 = norm_layer(self.out_dim)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
mlp_hidden_dim = int(self.dim * mlp_ratio)
self.mlp = MlpDWBN(
in_features=self.dim,
hidden_features=mlp_hidden_dim,
out_features=self.out_dim,
act_layer=act_layer,
dw_act_layer=act_layer,
drop=drop,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
)
def forward(self, x):
B, C, H, W = x.size()
# reshape
x = x.view(B, C, -1).permute(0, 2, 1).contiguous()
# Attention
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
# reshape
x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
return x
a = torch.randn(78, 78, 48, 64)
b = GeneralTransformerBlock(78, 78, 3)
c = b(a)
print('c', c.shape)
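With this toy input the block returns a tensor of the original (B, C, H, W) shape, i.e. torch.Size([78, 78, 48, 64]), and the shapes printed inside InterlacedPoolAttention.forward match the values recorded in the comments next to the print calls above.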