【论文精读】Learning Texture Transformer Network for Image Super-Resolution

目录

  • 出处
  • 贡献
  • 纹理转换器
    • 1.可学习纹理提取器(LTE)
    • 2.相关嵌入(RE)
    • 3.硬注意(HA)
    • 4.软注意(SA)
  • 跨尺度特征融合
  • 损失函数
    • 重建损失
    • 对抗性损失
    • 感知损失

出处

2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)

贡献

1.提出了一种包含四个密切相关模块的纹理转换器,用于图像超分辨率重建。
2.提出了一种新的跨尺度特征集成模块,用于图像生成任务,使我们的方法能够通过堆叠多个纹理转换器来学习更强大的特征表示。

纹理转换器

提出了用于图像超分辨率的纹理变换网络(TTSR)。在纹理转换器的基础上,提出了跨尺度特征集成模块(CSFI)来进一步提高模型的性能。
【论文精读】Learning Texture Transformer Network for Image Super-Resolution_第1张图片
纹理转换器的结构如上图所示, L R LR LR L R ↑ LR\uparrow LR R e f Ref Ref分别表示输入图像、4次双三次插值法对输入图像进行上采样和参考图像。并对 R e f Ref Ref进行4次双三次插值下采样和上采样,得到 R e f ↓ ↑ Ref\downarrow\uparrow Ref L R ↑ LR\uparrow LR结构保持一致。纹理转换器将 L R LR LR R e f Ref Ref R e f ↓ ↑ Ref\downarrow\uparrow Ref L R ↑ LR\uparrow LR作为输入,输出合成的特征图,进一步用于生成 H R HR HR预测。纹理转换器包括4部分:可学习纹理提取器(LTE)、相关性嵌入模块(RE)、特征转移的硬注意模块(HA)和特征合成的软注意模块(SA)。

1.可学习纹理提取器(LTE)

该部分设计了一个可学习的纹理提取器,它的参数会在端到端的训练过程中进行更新。这样的设计对 L R LR LR R e f Ref Ref图像进行联合特征学习,其中可以捕获更精确的纹理特征。纹理提取过程可以表示为:
Q = L T E ( L R ↑ ) (1) Q=LTE(LR\uparrow)\tag{1} Q=LTE(LR)(1)
K = L T E ( R e f ↓ ↑ ) (2) K=LTE(Ref\downarrow\uparrow)\tag{2} K=LTE(Ref)(2)
V = L T E ( R e f ) (3) V=LTE(Ref)\tag{3} V=LTE(Ref)(3)
其中 L T E ( ⋅ ) LTE(\cdot) LTE()表示我们的可学习纹理提取器的输出。提取的纹理特征Q、K和V表示转换器内部注意力机制的三个基本元素,并将进一步用于相关性嵌入模块。

2.相关嵌入(RE)

相关性嵌入的目的使通过估计Q和K之间的相似度来嵌入 L R LR LR R e f Ref Ref图像之间的相关性。将Q和K都展开成块, q i ( i ∈ [ 1 , H L R × W L R ] ) q_i(i\in[1,H_{LR}\times W_{LR}]) qi(i[1,HLR×WLR]) k j ( j ∈ [ 1 , H R e f × W R e f ] ) k_j(j\in[1,H_Ref\times W_{Ref}]) kj(j[1,HRef×WRef])。然后对每一个块 q i q_i qi k j k_j kj,通过归一化内积来计算两个块之间的相关性 r i , j r_{i,j} ri,j
r i , j = < q i ∣ ∣ q i ∣ ∣ , k j ∣ ∣ k j ∣ ∣ > (4) r_{i,j}=<\frac{q_i}{||q_i||},\frac{k_j}{||k_j||}>\tag{4} ri,j=<qiqi,kjkj>(4)
利用相关性得到硬注意图和软注意图。

3.硬注意(HA)

这里提出了一个硬注意模块来转移参考图像中的HR纹理特征V,在硬注意力模块中对于每个查询 q i q_i qi只从V中最相关的位置转移特征。
具体来说,首先计算硬关注图H,其中第i个元素 h i ( i ∈ [ 1 , H L R × W L R ] ) h_i(i\in[1,H_{LR}\times W_{LR}]) hi(i[1,HLR×WLR])是由相关性 r i , j r_{i,j} ri,j得来。
h i = a r g m a x j   r i , j (5) h_i=\underset{j}{argmax}\ r_{i,j}\tag{5} hi=jargmax ri,j(5)
h i h_i hi的值为硬索引,表示参考图像和LR图像中第i个最相关位置。为了从参考图像中获取纹理特征T,使用硬注意图作为索引,对V的展开块进行索引选择。
t i = v h i (6) t_i=v_{h_i}\tag{6} ti=vhi(6)
根据该输出结果,我们得到了LR图像和HR图像的特征表示T,它将进一步用于软注意模块。

4.软注意(SA)

提出的软注意模块将从DNN主干的LR图像和HR纹理特征T和LR特征F中合成特征。在合成过程中,为加强相关的纹理传递,从 r i , j r_{i,j} ri,j中计算软注意图S,以表示T中每个位置的所传送纹理特征的置信度。
s i = m a x j   r i , j (7) s_i=\underset{j}{max}\ r_{i,j}\tag{7} si=jmax ri,j(7)
其中侧重软注意图S的第i个位置,而不是直接将关注图S应用到T。首先融合HR纹理特征T和LR特征F,以利用来自LR图像的更多信息。融合后的特征被进一步以元素方式乘以软注意图S,并加回F,以获得纹理转换器的最终输出。
F o u t = F + C o n v ( C o n c a t ( F , T ) ) ⊙ S (8) F_{out}=F+Conv(Concat(F,T))\odot S\tag{8} Fout=F+Conv(Concat(F,T))S(8)
其中, F o u t F_{out} Fout表示合成的输出特征, C o n v Conv Conv C o n c a t Concat Concat分别表示卷积层和拼接操作, ⊙ \odot 表示元素相乘。
综上所述,纹理转换器可以有效地将参考图像中的相关HR纹理特征转换为LR特征,从而提高纹理生成的精确度。

跨尺度特征融合

【论文精读】Learning Texture Transformer Network for Image Super-Resolution_第2张图片
纹理转换器可以通过跨尺度融合模块以跨尺度的方式进一步堆叠。该体系结构如上图所示,3个纹理转换器输出3个分辨率不同的纹理特征融合LR图像中。
提出的跨尺度的特征融合模块(CSFI)在不同尺度的特征之间交换信息,每次将LR特征向上采样到下一个比例时,都会应用CSFI模块。对于CSFI模块中的每个尺度,通过上/下采样接受来自其他尺度的交换特征,然后在通道维度中进行串联操作。接着,卷积层将特征映射到原始数量的通道中。在这样的设计中,从堆叠的纹理转换器传输的纹理特征在每个尺度上交换,这实现了更强大的特征表示。这种跨尺度的特征集成模块进一步提高了方法的性能。

损失函数

总体的损失表示为:
L o v e r a l l = λ r e c L r e c + λ a d v L a d v + λ p e r L p e r (9) \mathcal{L}_{overall}=\lambda_{rec}\mathcal{L}_{rec}+\lambda_{adv}\mathcal{L}_{adv}+\lambda_{per}\mathcal{L}_{per}\tag{9} Loverall=λrecLrec+λadvLadv+λperLper(9)

重建损失

L r e c = 1 C H W ∣ ∣ I H R − I S R ∣ ∣ 1 (10) \mathcal{L}_{rec}=\frac{1}{CHW}||I^{HR}-I^{SR}||_1\tag{10} Lrec=CHW1IHRISR1(10)

对抗性损失

采用了WGAN-GP,它提出的梯度范数的惩罚来代替权重裁剪,使得训练更稳定,性能更好。
L D = E x ~ ∽ P g [ D ( x ~ ) ] − E x ∽ P r [ D ( x ) ] + λ E x ^ ∽ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] (11) \mathcal{L}_D=\underset{\tilde x\backsim \mathbb{P}_g}{\mathbb{E}}[D(\tilde x)]-\underset{x\backsim \mathbb{P}_r}{\mathbb{E}}[D(x)]+\lambda \underset{\hat x\backsim\mathbb{P}_{\hat x}}{\mathbb{E}}[(||\nabla_{\hat x}D(\hat x)||_2-1)^2]\tag{11} LD=x~PgE[D(x~)]xPrE[D(x)]+λx^Px^E[(x^D(x^)21)2](11)
L G = − E x ~ ∽ P g [ D ( x ~ ) ] (12) \mathcal{L}_G=-\underset{\tilde x\backsim\mathbb{P}_g}{\mathbb{E}}[D(\tilde x)]\tag{12} LG=x~PgE[D(x~)](12)

感知损失

L p e r = 1 C i H i W i ∣ ∣ ϕ i v g g ( I S R ) − ϕ i v g g ( I H R ) ∣ ∣ 2 2 + 1 C j H j W j ∣ ∣ ϕ j l t e ( I S R ) − T ∣ ∣ 2 2 (13) \mathcal{L}_{per}=\frac{1}{C_iH_iW_i}||\phi_i^{vgg}(I^{SR})-\phi_i^{vgg}(I^{HR})||_2^2+\frac{1}{C_jH_jW_j}||\phi_j^{lte}(I^{SR})-T||_2^2\tag{13} Lper=CiHiWi1ϕivgg(ISR)ϕivgg(IHR)22+CjHjWj1ϕjlte(ISR)T22(13)
第一部分是传统的感知损失,其中 ϕ ( ⋅ ) i v g g \phi(\cdot)_i^{vgg} ϕ()ivgg表示第i层vgg19的特征图。
第二部分是传递性感知损失,其中 L T E ( ⋅ ) LTE(\cdot) LTE()表示 ϕ \phi ϕ第j层提取的纹理特征图。

你可能感兴趣的:(论文精读,深度学习,神经网络,pytorch,机器学习)