This article gives a brief introduction, from both the principles and the code, to one of the newer papers in low-light image enhancement: Retinexformer. The method performs well and set new state-of-the-art results on thirteen low-light enhancement benchmarks.
❗Paper title: Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement
Paper info: published at ICCV 2023 (August 2023), a joint work by Tsinghua University, the University of Würzburg, and ETH Zürich.
Paper link: https://arxiv.org/abs/2303.06705
Code link: https://github.com/caiyuanhao1998/Retinexformer
Partially referenced from: https://zhuanlan.zhihu.com/p/657927878
The paper's main contributions can be summarized as follows:
1. It proposes the first Transformer method built on Retinex theory, named Retinexformer.
2. It derives a one-stage Retinex-based framework, ORF (One-stage Retinex-based Framework), which needs only a single stage of end-to-end training, keeping the pipeline simple (see the sketch after this list).
3. It designs a novel illumination-guided multi-head self-attention mechanism, IG-MSA (Illumination-Guided Multi-head Self-Attention), which uses illumination information as a key cue to guide the capture of long-range dependencies.
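To make the one-stage idea concrete, here is a minimal sketch of how ORF chains the two sub-networks in a single forward pass: an illumination estimator first "lights up" the input, and a corruption restorer then cleans up the noise and artifacts this exposes. The wrapper class name OneStageRetinexFormer and the hyperparameter values are hypothetical illustration choices, not necessarily the repo's exact top-level module; Illumination_Estimator and Denoiser are the classes walked through below.

import torch
import torch.nn as nn

class OneStageRetinexFormer(nn.Module):  # hypothetical wrapper, for illustration only
    def __init__(self, n_feat=40, level=2, num_blocks=[2, 4, 4]):
        super().__init__()
        self.estimator = Illumination_Estimator(n_feat)
        self.denoiser = Denoiser(in_dim=3, out_dim=3, dim=n_feat,
                                 level=level, num_blocks=num_blocks)

    def forward(self, img):
        # img: (b, 3, h, w) low-light input
        illu_fea, illu_map = self.estimator(img)
        # light up the input with the predicted illumination map;
        # the residual term keeps the original content
        input_img = img * illu_map + img
        # restore the corruptions exposed by lighting up,
        # guided by the illumination feature
        return self.denoiser(input_img, illu_fea)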
The first component is the illumination estimator. It concatenates the input image with its channel-wise mean (a rough illumination prior) and predicts both an illumination feature illu_fea (consumed later by IG-MSA) and a 3-channel illumination map illu_map (used to light up the input). The imports below are the ones the snippets in this article rely on:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_

class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):
        # __init__ defines the internal layers; forward's arguments are the external inputs
        super(Illumination_Estimator, self).__init__()
        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)
        # note: groups=n_fea_in (4), so n_fea_middle must be divisible by 4
        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)
        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:      b, c=3, h, w   low-light input image
        # mean_c:   b, c=1, h, w   channel-wise mean, a rough illumination prior
        # illu_fea: b, c, h, w     illumination feature, fed to IG-MSA
        # illu_map: b, c=3, h, w   illumination map used to light up the input
        mean_c = img.mean(dim=1).unsqueeze(1)
        input = torch.cat([img, mean_c], dim=1)
        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map
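A quick shape check of the estimator (a hedged example; n_fea_middle=40 is an arbitrary value, chosen here only because it is divisible by the depth-wise conv's group count of 4):

estimator = Illumination_Estimator(n_fea_middle=40)
img = torch.rand(1, 3, 128, 128)  # dummy low-light image
illu_fea, illu_map = estimator(img)
print(illu_fea.shape)  # torch.Size([1, 40, 128, 128])
print(illu_map.shape)  # torch.Size([1, 3, 128, 128])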
(A side note: the self.encoder part of the code seems not to match the structure in the paper's figure exactly; I am not sure whether I have simply misread it.)

Next comes the corruption restorer, a U-shaped network of IGAB blocks that is guided at every scale by a correspondingly downsampled copy of the illumination feature:

class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder: each level holds an IGAB, a feature downsampler and an
        # illumination-feature downsampler (two separate strided convs)
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),  # downsample fea
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)   # downsample illu_fea
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder: upsample, fuse with the encoder skip connection, then IGAB
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # Activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:        [b,c,h,w]  the lit-up image (a feature-like input, not the raw image)
        illu_fea: [b,c,h,w]  illumination feature from the estimator
        return out: [b,c,h,w]
        """
        # Embedding
        fea = self.embedding(x)

        # Encoder: cache the per-level features and illumination features
        # for the skip connections
        fea_encoder = []
        illu_fea_list = []
        for (Attention, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = Attention(fea, illu_fea)  # b,c,h,w
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea, illu_fea)

        # Decoder
        for i, (FeaUpSample, Fusion, Block) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fusion(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level - 1 - i]
            fea = Block(fea, illu_fea)

        # Mapping: global residual connection back to the input
        out = self.mapping(fea) + x
        return out
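As a sanity check on the U-shape, the following hedged snippet runs the denoiser on dummy tensors. It assumes the IGAB block from the repo is in scope, since this article does not reproduce it, and the spatial size must be divisible by 2**level for the down/upsampling to line up:

denoiser = Denoiser(in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4])
x = torch.rand(1, 3, 128, 128)          # lit-up image
illu_fea = torch.rand(1, 31, 128, 128)  # full-resolution illumination feature
out = denoiser(x, illu_fea)
print(out.shape)  # torch.Size([1, 3, 128, 128])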
The core of each IGAB is the illumination-guided multi-head self-attention, IG_MSA. Note that attention is computed along the channel dimension rather than the spatial one, so the attention matrix is only d x d per head and the cost grows linearly with the number of pixels:

class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        # positional embedding: two depth-wise convs on v
        # (GELU is the repo's thin wrapper around F.gelu)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in:           [b,h,w,c]  input feature
        illu_fea_trans: [b,h,w,c]  illumination feature; the layout is channels-last
                        because IGAB permutes both tensors before calling IG_MSA
        return out:     [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        # the illumination feature modulates value, letting well-lit regions
        # guide the dark ones
        v = v * illu_attn
        # q, k, v: b,heads,hw,d -> b,heads,d,hw
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))  # A = K^T Q, shape (d, d): channel-wise attention
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v  # b,heads,d,hw
        x = x.permute(0, 3, 1, 2)  # b,hw,heads,d
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        # positional embedding branch, computed from v_inp
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p
        return out
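A final shape check on IG_MSA alone (a hedged example; the channels-last layouts match what IGAB passes in, and the dim/dim_head/heads values are arbitrary illustration choices):

msa = IG_MSA(dim=31, dim_head=31, heads=1)
x_in = torch.rand(1, 64, 64, 31)            # channels-last feature
illu_fea_trans = torch.rand(1, 64, 64, 31)  # channels-last illumination feature
out = msa(x_in, illu_fea_trans)
print(out.shape)  # torch.Size([1, 64, 64, 31])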