本节学习 Transformer 嵌入 Transformer 的融合网络 TNT,思想自然,源于华为,值得一看。
Transformer in Transformer(TNT) 是华为团队 2021 的文章,论文为 Transformer in Transformer,并且表达的思想非常简单和自然。在原始的 Transformer 中,例如 ViT 中,对图片进行切片处理,切成了 16 × 16 16 \times 16 16×16 的 patch,Multi-Head Self-Attention 编码的是 patch 与 patch 之间的信息,也就是例如一个人脸图片,编码眼睛和嘴巴之间的相应权重。但是对于 patch 内部却没有进行很精细的编码,也就是说图片输入本身的“局部性”特性没有很好的利用上。一个很自然的想法就是,那我把 patch 直接变小好了,例如变为 4 × 4 4 \times 4 4×4 的大小,但是这样的话大大增加了 token 的数量,也增加了 Transformer 模型的计算量。此外之前微软团队的 Swin 工作也提到说,考虑把 key 值共享在一个小的 location 里可以对应于图像的局部性。那么,我们可以将大的 patch 再进行切片,在 patch 内部做一个 Multi-Head Self-Attention,将 key 值共享在 patch 内部,这样就能够兼顾“局部性”以及计算量。
事实上,Transformer in Transformer 和 U2Net (Unet in Unet) 有异曲同工之妙!
patch 内部的信息交流也非常重要,特别是对于精细结构而言。作者将 16 × 16 16 \times 16 16×16 的 patch 称为 visual sentences,将 patch 再进行切片出来的例如 4 × 4 4 \times 4 4×4 的更小的 patch 称为 visual word。这样 patch 内部进行自注意力机制研究哪里比较重要,patch 之间通过自注意力机制研究彼此之间的关系,能帮助我们增强特征的表示能力。此外,这样的小范围的 transformer 并不会增加太大的计算量。
让我们一步一步看看它是怎么实现的。
首先给定一个 2D 图像,我们通常将其进行切片形成 X = [ X 1 , X 2 , ⋯ , X n ] ∈ R n × p × p × 3 \mathcal{X}=\left[X^{1}, X^{2}, \cdots, X^{n}\right] \in \mathbb{R}^{n \times p \times p \times 3} X=[X1,X2,⋯,Xn]∈Rn×p×p×3, 即切成 n n n 个 patch,每个 patch 是 p × p × 3 p \times p \times 3 p×p×3,3 为 RGB 通道, p p p 为 patch 的分辨率。此后我们对单个 patch 再进行切分,切成 m m m 个 sub-patch,得到 visual sentence 对应的一系列的 visual words:
X i → [ x i , 1 , x i , 2 , ⋯ , x i , m ] X^{i} \rightarrow\left[x^{i, 1}, x^{i, 2}, \cdots, x^{i, m}\right] Xi→[xi,1,xi,2,⋯,xi,m]
其中 x i , j ∈ R s × s × 3 x^{i,j} \in \mathbb{R}^{s \times s \times 3} xi,j∈Rs×s×3,3 为 RGB 通道, s s s 为 sub-patch 的分辨率, x i , j x^{i,j} xi,j 表示第 i i i 个 patch 的第 j j j 个 sub-patch。例如一个 16 × 16 16 \times 16 16×16 的 patch 可以再被切为 16 个 4 × 4 4 \times 4 4×4 的 patch,则 m = 16 , s = 4 m=16, s=4 m=16,s=4。对于每个 sub-patch,首先将它们进行展平(向量化 vectorization),然后经过一个全连接层进行编码:
Y i = [ y i , 1 , y i , 2 , ⋯ , y i , m ] , y i , j = F C ( Vec ( x i , j ) ) Y^{i}=\left[y^{i, 1}, y^{i, 2}, \cdots, y^{i, m}\right], \quad y^{i, j}=F C\left(\operatorname{Vec}\left(x^{i, j}\right)\right) Yi=[yi,1,yi,2,⋯,yi,m],yi,j=FC(Vec(xi,j))
然后将它们经过一个非常普通的由 Multi-Head Self-Attention 和 MLP 组成的 Transformer block 进行 patch 内部的 Transformer 操作,LN 则是 Layer Normalization 层,其中下标 l l l 表示第几个 inner transform block:
Y l ′ i = Y l − 1 i + MSA ( L N ( Y l − 1 i ) ) Y l i = Y l ′ i + MLP ( L N ( Y l ′ i ) ) \begin{aligned} Y_{l}^{\prime i} &=Y_{l-1}^{i}+\operatorname{MSA}\left(L N\left(Y_{l-1}^{i}\right)\right) \\ Y_{l}^{i} &=Y_{l}^{\prime i}+ \operatorname{MLP}\left(L N\left(Y_{l}^{\prime i}\right)\right) \end{aligned} Yl′iYli=Yl−1i+MSA(LN(Yl−1i))=Yl′i+MLP(LN(Yl′i))
得到 patch 内部经过注意力机制变换后的新的 sub-patch 表示,对其进行编码得到和对应 patch 相同大小的编码,与原始 patch 进行加和。这里的 W l − 1 , b l − 1 W_{l-1},b_{l-1} Wl−1,bl−1 对应的就是一个全连接层的权重和偏置,但是在源码中这里并没有使用偏置 self.proj = nn.Linear(num_words * inner_dim, outer_dim, bias=False)
。
Z l − 1 i = Z l − 1 i + Vec ( Y l − 1 i ) W l − 1 + b l − 1 Z_{l-1}^{i}=Z_{l-1}^{i}+\operatorname{Vec}\left(Y_{l-1}^{i}\right) W_{l-1}+b_{l-1} Zl−1i=Zl−1i+Vec(Yl−1i)Wl−1+bl−1
patch token 结合了自身切片后进行变换后的信息在馈入 outer transformer block 就得到了下一层的新 token。
Z l ′ i = Z l − 1 i + MSA ( L N ( Z l − 1 i ) ) Z l i = Z ′ i + MLP ( L N ( Z l ′ i ) ) \begin{aligned} \mathcal{Z}_{l}^{\prime i} &=\mathcal{Z}_{l-1}^{i}+\operatorname{MSA}\left(L N\left(\mathcal{Z}_{l-1}^{i}\right)\right) \\ \mathcal{Z}_{l}^{i} &=\mathcal{Z}^{\prime i}+\operatorname{MLP}\left(L N\left(\mathcal{Z}_{l}^{\prime i}\right)\right) \end{aligned} Zl′iZli=Zl−1i+MSA(LN(Zl−1i))=Z′i+MLP(LN(Zl′i))
这些总结出来就是:
Y l , Z l = TNT ( Y l − 1 , Z l − 1 ) \mathcal{Y}_{l}, \mathcal{Z}_{l}=\operatorname{TNT}\left(\mathcal{Y}_{l-1}, \mathcal{Z}_{l-1}\right) Yl,Zl=TNT(Yl−1,Zl−1)
在 TNT block 中,inner transformer block 用于捕捉局部区域之间的特征关系,而 outer transformer block 则用于捕捉区域之间的联系。将 inner transformer block 的输出加到 outer transformer block 的输入上,这其实就像类似上级向下级传达任务(分类猫狗),下级首先形成一个初步意见(初始token),然后每个小组内部先进行讨论沟通交流,整合意见之后交给小组长(token 相加),小组长共同开会进行交流沟通得到一个意见之间的融合,再传递给上级。这种组内部的交流沟通在生活中非常常见。
In our TNT block, the inner transformer block is used to model the relationship between visual words for local feature extraction, and the outer transformer block captures the intrinsic information from the sequence of sentences. By stacking the TNT blocks for L times, we build the transformerin-transformer network. Finally, the classification token serves as the image representation and a fully-connected layer is applied for classification.
TNT 网络结构如下所示,Depth 为 block 重复的次数。
很多工作表示了,位置编码非常重要,那么每个 sub-patch 对应的 token 怎么进行位置编码呢?在 TNT 中,不同 sentence 中的 words 共享位置编码,即 E w o r d E_{word} Eword 表示的是 sub-patch 在 patch 中的位置,这个在各个大 patch 中是一致的。
Z 0 ← Z 0 + E sentence Y 0 i ← Y 0 i + E w o r d , i = 1 , 2 , ⋯ , n \mathcal{Z}_{0} \leftarrow \mathcal{Z}_{0}+E_{\text {sentence }}\\ Y_{0}^{i} \leftarrow Y_{0}^{i}+E_{w o r d}, i=1,2, \cdots, n Z0←Z0+Esentence Y0i←Y0i+Eword,i=1,2,⋯,n
实验也发现,两个都使用位置编码可以得到更好的效果,但是在这篇 paper 中就没有去讨论使用绝对还是相对位置编码,1D 还是 2D 位置编码,以及把位置编码和 query 点乘构建 Attention 的一部分了。
对于原论文的描述,我一直存在疑惑,希望得到大家的解答。
一个标准的 Transformer block 包含两部分,Multi-head Self-Attention 和 multi-layer perceptron. 其中 Multi-head Self-Attention 的计算 FLOPs 为 2 n d ( d k + d v ) + n 2 ( d k + d v ) 2 n d\left(d_{k}+d_{v}\right)+n^{2}\left(d_{k}+d_{v}\right) 2nd(dk+dv)+n2(dk+dv),其中 n n n 为 token 的数量, d d d 为 token 的维度,query 和 key 的维度都是 d k d_k dk,value 的维度为 d v d_v dv。MLP 对应的 FLOPs 则为 2 n d v r d v 2nd_vrd_v 2ndvrdv,其中 r r r 通常取 4,为第一个 全连接层维度扩展的倍数。通常呢这个 d k = d v = d d_k = d_v = d dk=dv=d 。当不考虑 bias 以及 LN 时,一个标准的 transformer block 的 FLOPs 为 :
FLOPs T = 2 n d ( d k + d v ) + n 2 ( d k + d v ) + 2 n d d r = 2 n d ( 6 d + n ) \operatorname{FLOPs}_{T}=2 n d\left(d_{k}+d_{v}\right)+n^{2}\left(d_{k}+d_{v}\right)+2 n d d r = 2nd(6d+n) FLOPsT=2nd(dk+dv)+n2(dk+dv)+2nddr=2nd(6d+n)
MLP 对应的 FLOPs 很好理解,一个 n d d r nddr nddr 是第一个全连接层从 d d d 扩展维度到 r d rd rd;一个 n d d r nddr nddr 是第一个全连接层从 r d rd rd 还原到 d d d。但是 MSA 的 FLOPs 就不是那么好理解了。 2 n d d k 2ndd_k 2nddk 是计算 n n n 个 token 对应的 key 和 query, n d d v ndd_v nddv 是计算 n n n 个 token 对应的 value,那前面的 2 怎么理解呢?此外, n 2 ( d k + d v ) n^{2}\left(d_{k}+d_{v}\right) n2(dk+dv) 应该对应的是计算 Attention 矩阵以及对 d v d_v dv 进行加权求和。计算 Attention 是会得到 n 2 n^2 n2 对 key 和 query 点乘,也应该是 n 2 d k n^2 d_k n2dk? n n n 个新 token,每个源自 n n n 个 value 的加权求和,那么很自然是 n 2 d v n^2d_v n2dv。其中 n d v n d_v ndv 表示一个新 token 是 n n n 个 value 的加权求和。
我认为的一个标准的 transformer block 的 FLOPs 为 :
一个标准的 transformer block 的参数量为 :
Params T = 12 d d . \text { Params }_{T}=12 d d \text {. } Params T=12dd.
而在 TNT 中,新增的一个 inner transformer block 的计算量和一个全连接层,所以 TNT 的 FLOPs 为 :
F L O P S T N T = 2 n m c ( 6 c + m ) + n m c d + 2 n d ( 6 d + n ) \mathrm{FLOPS}_{T N T}=2 n m c(6 c+m)+n m c d+2 n d(6 d+n) FLOPSTNT=2nmc(6c+m)+nmcd+2nd(6d+n)
其中 2 n m c ( 6 c + m ) 2nmc(6c+m) 2nmc(6c+m) 为 inner transformer block, 2 n d ( 6 d + n ) 2nd(6d+n) 2nd(6d+n) 为 outer transformer block, n m c d nmcd nmcd 为那个全连接层,需要从 m × c m \times c m×c 维转为 n × d n \times d n×d 维。相应的 TNT 的参数量为:
Params T N T = 12 c c + m c d + 12 d d . \text { Params }_{TNT}=12cc+mcd+12 d d \text {. } Params TNT=12cc+mcd+12dd.
当 c ≪ d c \ll d c≪d 且 O ( m ) ≈ O ( n ) O(m) \approx O(n) O(m)≈O(n) 时候,例如 d = 384 , n = 196 , c = 24 , m = 16 d=384,n=196,c=24,m=16 d=384,n=196,c=24,m=16 时,TNT 的 FLOPs 为 1.14x,Params 为 1.08x。
但是 FLOPs 原论文算正确了吗?
TNT 做分类等实验都表现出了一定的性能,其中和 DeiT 对比的可视化结果比较可观的反映出来了提升。可见局部信息在浅层得到了很好的保留,而随着网络的深入,表征逐渐变得更加抽象。
inner transformer 中不同的 query 位置(sub-patch)对应的激活部位可见,局部交流还是很有用的!
代码出处见 此处。
# 2021.06.15-Changed for implementation of TNT model
# Huawei Technologies Co., Ltd.
""" Vision Transformer (ViT) in PyTorch
A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from functools import partial
import math
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.resnet import resnet26d, resnet50d
from timm.models.registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'tnt_s_patch16_224': _cfg(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
'tnt_b_patch16_224': _cfg(
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
}
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SE(nn.Module):
def __init__(self, dim, hidden_ratio=None):
super().__init__()
hidden_ratio = hidden_ratio or 1
self.dim = dim
hidden_dim = int(dim * hidden_ratio)
self.fc = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, dim),
nn.Tanh()
)
def forward(self, x):
a = x.mean(dim=1, keepdim=True) # B, 1, C
a = self.fc(a)
x = a * x
return x
class Attention(nn.Module):
def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
head_dim = hidden_dim // num_heads
self.head_dim = head_dim
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop, inplace=True)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop, inplace=True)
def forward(self, x):
B, N, C = x.shape
qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
""" TNT Block
"""
def __init__(self, outer_dim, inner_dim, outer_num_heads, inner_num_heads, num_words, mlp_ratio=4.,
qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
norm_layer=nn.LayerNorm, se=0):
super().__init__()
self.has_inner = inner_dim > 0
if self.has_inner:
# Inner
self.inner_norm1 = norm_layer(inner_dim)
self.inner_attn = Attention(
inner_dim, inner_dim, num_heads=inner_num_heads, qkv_bias=qkv_bias,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.inner_norm2 = norm_layer(inner_dim)
self.inner_mlp = Mlp(in_features=inner_dim, hidden_features=int(inner_dim * mlp_ratio),
out_features=inner_dim, act_layer=act_layer, drop=drop)
self.proj_norm1 = norm_layer(num_words * inner_dim)
self.proj = nn.Linear(num_words * inner_dim, outer_dim, bias=False)
self.proj_norm2 = norm_layer(outer_dim)
# Outer
self.outer_norm1 = norm_layer(outer_dim)
self.outer_attn = Attention(
outer_dim, outer_dim, num_heads=outer_num_heads, qkv_bias=qkv_bias,
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.outer_norm2 = norm_layer(outer_dim)
self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),
out_features=outer_dim, act_layer=act_layer, drop=drop)
# SE
self.se = se
self.se_layer = None
if self.se > 0:
self.se_layer = SE(outer_dim, 0.25)
def forward(self, inner_tokens, outer_tokens):
if self.has_inner:
inner_tokens = inner_tokens + self.drop_path(self.inner_attn(self.inner_norm1(inner_tokens))) # B*N, k*k, c
inner_tokens = inner_tokens + self.drop_path(self.inner_mlp(self.inner_norm2(inner_tokens))) # B*N, k*k, c
B, N, C = outer_tokens.size()
outer_tokens[:,1:] = outer_tokens[:,1:] + self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, N-1, -1)))) # B, N, C
if self.se > 0:
outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))
outer_tokens = outer_tokens + self.drop_path(tmp_ + self.se_layer(tmp_))
else:
outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens)))
outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
return inner_tokens, outer_tokens
class PatchEmbed(nn.Module):
""" Image to Visual Word Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, outer_dim=768, inner_dim=24, inner_stride=4):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.inner_dim = inner_dim
self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
self.proj = nn.Conv2d(in_chans, inner_dim, kernel_size=7, padding=3, stride=inner_stride)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.unfold(x) # B, Ck2, N
x = x.transpose(1, 2).reshape(B * self.num_patches, C, *self.patch_size) # B*N, C, 16, 16
x = self.proj(x) # B*N, C, 8, 8
x = x.reshape(B * self.num_patches, self.inner_dim, -1).transpose(1, 2) # B*N, 8*8, C
return x
class TNT(nn.Module):
""" TNT (Transformer in Transformer) for computer vision
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, outer_dim=768, inner_dim=48,
depth=12, outer_num_heads=12, inner_num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, inner_stride=4, se=0):
super().__init__()
self.num_classes = num_classes
self.num_features = self.outer_dim = outer_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, outer_dim=outer_dim,
inner_dim=inner_dim, inner_stride=inner_stride)
self.num_patches = num_patches = self.patch_embed.num_patches
num_words = self.patch_embed.num_words
self.proj_norm1 = norm_layer(num_words * inner_dim)
self.proj = nn.Linear(num_words * inner_dim, outer_dim)
self.proj_norm2 = norm_layer(outer_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, outer_dim))
self.outer_tokens = nn.Parameter(torch.zeros(1, num_patches, outer_dim), requires_grad=False)
self.outer_pos = nn.Parameter(torch.zeros(1, num_patches + 1, outer_dim))
self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
vanilla_idxs = []
blocks = []
for i in range(depth):
if i in vanilla_idxs:
blocks.append(Block(
outer_dim=outer_dim, inner_dim=-1, outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads,
num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, se=se))
else:
blocks.append(Block(
outer_dim=outer_dim, inner_dim=inner_dim, outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads,
num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, se=se))
self.blocks = nn.ModuleList(blocks)
self.norm = norm_layer(outer_dim)
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
#self.repr = nn.Linear(outer_dim, representation_size)
#self.repr_act = nn.Tanh()
# Classifier head
self.head = nn.Linear(outer_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.outer_pos, std=.02)
trunc_normal_(self.inner_pos, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'outer_pos', 'inner_pos', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.outer_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
outer_tokens = self.proj_norm2(self.proj(self.proj_norm1(inner_tokens.reshape(B, self.num_patches, -1))))
outer_tokens = torch.cat((self.cls_token.expand(B, -1, -1), outer_tokens), dim=1)
outer_tokens = outer_tokens + self.outer_pos
outer_tokens = self.pos_drop(outer_tokens)
for blk in self.blocks:
inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens)
outer_tokens = self.norm(outer_tokens)
return outer_tokens[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _conv_filter(state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs):
patch_size = 16
inner_stride = 4
outer_dim = 384
inner_dim = 24
outer_num_heads = 6
inner_num_heads = 4
outer_dim = make_divisible(outer_dim, outer_num_heads)
inner_dim = make_divisible(inner_dim, inner_num_heads)
model = TNT(img_size=224, patch_size=patch_size, outer_dim=outer_dim, inner_dim=inner_dim, depth=12,
outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads, qkv_bias=False,
inner_stride=inner_stride, **kwargs)
model.default_cfg = default_cfgs['tnt_s_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
return model
@register_model
def tnt_b_patch16_224(pretrained=False, **kwargs):
patch_size = 16
inner_stride = 4
outer_dim = 640
inner_dim = 40
outer_num_heads = 10
inner_num_heads = 4
outer_dim = make_divisible(outer_dim, outer_num_heads)
inner_dim = make_divisible(inner_dim, inner_num_heads)
model = TNT(img_size=224, patch_size=patch_size, outer_dim=outer_dim, inner_dim=inner_dim, depth=12,
outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads, qkv_bias=False,
inner_stride=inner_stride, **kwargs)
model.default_cfg = default_cfgs['tnt_b_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
return model