本菜鸡本科毕设在FPGA上搞过图像滤波等算法,研究生期间虽然搞的是基于深度学习的图形学,但是主干网络用的还是卷积… 感觉自己代码能力还可以,基础还行,参赛之前还是比较自信的:
觉着看几篇顶会去噪的文章,复现借鉴一下应该能取得一个不错的结果,但是-------大概1000+人参赛,一多半没有提交的或者只提交个baseline,本菜最终100+ 额还没结束 明天结束了估计排名快接近200了 实在卷不动了
主要有以下三点问题:
虽然知道自己菜,但还是希望尝试一下。 吐槽结束,进入正题::::
谈一下收获,虽然困难挺多,但是收获也很多
图片是要切片的,一整张图太大了,网络稍大点,32G的显卡也会爆显存
把一张图分块为多个图,伪代码如下:
# 外层是一个循环 根据图像大小进行切片
tmp['imgs'] = data['imgs'][:, :, a:b, c:d] # batch 通道数 图片的长和宽
tmp['gts'] = data['gts'][:, :, a:b, c:d] # 标签
model.set_input(tmp) # 网络输入
我的网络主要借鉴的思想:
1.不直接学习端到端的像素值,而是学习噪声(网络更容易拟合?)
2.使用通道可分离的卷积,适当增加通道数(显存太小,跑起来速度很慢)
3.尝试增加卷积核大小(显存太小,跑起来速度很慢)
(比赛有模型大小限制)–增大通道和卷积核都会增加显存的使用,设备不行,故只增了通道数。具体的实现细节如下:
纯纯的Unet baseline修改而来
class Unet2(nn.Module):
def __init__(self, dim=4):
super(Unet2, self).__init__()
self.dims = [32, 64, 128, 256, 512]
self.ks = [3, 3, 3, 3, 3]
self.dims_up = self.dims[::-1]
self.ks_up = self.ks[-2::-1]
self.first_block = Block2(dim, self.dims[0], self.ks[0])
self.first_pool = nn.MaxPool2d(kernel_size=2) # AvgPool2d pnsr: 37.683, ssim: 0.902, score: 30.679, time: 52.650
for i, dim_in in enumerate(self.dims[:-2]):
dim_out = self.dims[i+1]
setattr(self, 'Block{}'.format(i), Block2(dim_in, dim_out, k=self.ks[i+1]))
setattr(self, 'pool{}'.format(i), nn.MaxPool2d(kernel_size=2))
self.conv_mid = Block2(self.dims[-2], self.dims[-1], self.ks[-1])
for i, dim_in in enumerate(self.dims_up[:-1]):
dim_out = self.dims_up[i+1]
setattr(self, 'ConvTrans{}'.format(i), nn.ConvTranspose2d(dim_in, dim_out, 2, stride=2, bias=True))
setattr(self, 'up_Block{}'.format(i), Block2(dim_in, dim_out, k=self.ks_up[i]))
self.last_conv = nn.Conv2d(self.dims[0], dim, 1, bias=True)
def forward(self, x):
n, c, h, w = x.shape
h_pad = 32 - h % 32 if not h % 32 == 0 else 0
w_pad = 32 - w % 32 if not w % 32 == 0 else 0
padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
list_pools = []
x_bk = x
# 1.first Block
x = self.first_block(padded_image)
list_pools.append(x)
x = self.first_pool(x)
# 2.Blocks
for i, dim_in in enumerate(self.dims[:-2]):
x = getattr(self, 'Block{}'.format(i))(x)
list_pools.append(x)
x = getattr(self, 'pool{}'.format(i))(x)
x = self.conv_mid(x)
for i, dim_in in enumerate(self.dims_up[:-1]):
x = getattr(self, 'ConvTrans{}'.format(i))(x)
# tmp = list_pools.pop()
x = torch.cat([x, list_pools.pop()], 1)
x = getattr(self, 'up_Block{}'.format(i))(x)
# 3.last
x = self.last_conv(x)
out = x[:, :, :h, :w] + x_bk
return out
class Block2(nn.Module):
def __init__(self, dim_in, dim_out, k=3):
super(Block2, self).__init__()
self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=k, padding=k // 2, padding_mode='zeros', bias=True)
self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=k, padding=k // 2, padding_mode='zeros', bias=True)
def forward(self, x):
x = self.conv1(x)
x = self.leaky_relu(x)
x = self.conv2(x)
x = self.leaky_relu(x)
return x
def leaky_relu(self, x, a=0.2):
out = torch.max(a * x, x)
return out
我使用的网络 魔改ConvNet
class Our(nn.Module):
def __init__(self, dim=4):
super(Our, self).__init__()
self.dims = [128, 256, 512, 1024]
self.ks = [3, 3, 3, 3]
# 内存不够啊
# self.dims = [16, 32, 64, 128, 256]
# self.ks = [23, 23, 23, 17, 3]
######################################
self.dims_up = self.dims[::-1]
self.ks_up = self.ks[-2::-1]
self.first_block = Block(dim, self.dims[0], self.ks[0])
self.first_pool = nn.MaxPool2d(kernel_size=2)
for i, dim_in in enumerate(self.dims[:-2]):
dim_out = self.dims[i+1]
setattr(self, 'Block{}'.format(i), Block(dim_in, dim_out, k=self.ks[i+1]))
setattr(self, 'pool{}'.format(i), nn.MaxPool2d(kernel_size=2))
self.conv_mid = Block(self.dims[-2], self.dims[-1], self.ks[-1])
for i, dim_in in enumerate(self.dims_up[:-1]):
dim_out = self.dims_up[i+1]
setattr(self, 'ConvTrans{}'.format(i), nn.ConvTranspose2d(dim_in, dim_out, 2, stride=2))
setattr(self, 'up_Block{}'.format(i), Block(dim_in, dim_out, k=self.ks_up[i]))
self.last_ln = nn.LayerNorm(self.dims[0], eps=1e-6)
self.last_conv = nn.Linear(self.dims[0], dim)
def forward(self, x):
n, c, h, w = x.shape
h_pad = 32 - h % 32 if not h % 32 == 0 else 0
w_pad = 32 - w % 32 if not w % 32 == 0 else 0
padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
list_pools = []
x_bk = x
# 1.first Block
x = self.first_block(padded_image)
list_pools.append(x)
x = self.first_pool(x)
# 2.Blocks
for i, dim_in in enumerate(self.dims[:-2]):
x = getattr(self, 'Block{}'.format(i))(x)
list_pools.append(x)
x = getattr(self, 'pool{}'.format(i))(x)
x = self.conv_mid(x)
for i, dim_in in enumerate(self.dims_up[:-1]):
x = getattr(self, 'ConvTrans{}'.format(i))(x)
# tmp = list_pools.pop()
x = torch.cat([x, list_pools.pop()], 1)
x = getattr(self, 'up_Block{}'.format(i))(x)
# 3.last
x = x.permute(0, 2, 3, 1).contiguous()
x = self.last_ln(x)
x = self.last_conv(x)
x = x.permute(0, 3, 1, 2).contiguous()
out = x[:, :, :h, :w] + x_bk
return out
class Block(nn.Module):
def __init__(self, dim_in, dim_out, k=9):
super(Block, self).__init__()
self.conv = nn.Conv2d(dim_in, dim_in, groups=dim_in, kernel_size=k, padding=k // 2)
self.ln = nn.LayerNorm(dim_in,eps=1e-6)
self.conv1x1up = nn.Linear(dim_in, dim_in * 2) #nn.Conv2d(dim, dim * 2, 1)
self.act = nn.GELU()
self.conv1x1dn = nn.Linear(dim_in * 2, dim_out) #nn.Conv2d(dim * 2, dim, 1)
self.w = nn.Parameter(torch.zeros(1))
# res
self.res_conv = nn.Conv2d(dim_in, dim_out, 1)
def forward(self, x):
identity = x
x = self.conv(x)
x = x.permute(0, 2, 3, 1).contiguous()
x = self.ln(x)
x = self.conv1x1up(x)
x = self.act(x)
x = self.conv1x1dn(x)
x = x.permute(0, 3, 1, 2).contiguous()
x = x * self.w
x = x + self.res_conv(identity)
return x
loss = torch.nn.L1Loss()
实测了一下,还是L1效果好啊
其它L2、SSIM之类的花里胡哨的效果并不理想 (毕竟是炼丹,可能只是不适合我的网络)
哈、我还试了一下传统的去噪,顺便使用纯python写了一个双边滤波(参考我以前matlab的代码),不得不说,还是深度学习yyds!
def bilateral_filter(img):
# 参考自己博客 matlab的实现 https://blog.csdn.net/qq_38204686/article/details/106929922
r = 20 # 窗口半径 核大小为 2*r + 1
sigma_space = 15.0 # 空间标准差
sigma_color = 10.0 # 相似标准差
w_space = np.zeros((2*r + 1, 2*r + 1))
for i in range(-r-1, r):
for j in range(-r-1, r):
tmp = i * i + j * j
w_space[i + r+1, j + r+1] = np.exp(-float(tmp) / (2 * sigma_space * sigma_space))
w_color = np.zeros((1, 256))
for i in range(256):
w_color[0, i] = np.exp(-float(i * i) / (2 * sigma_color * sigma_color))
# 开始滤波
height, width, channel = img.shape
dst_img = img.copy()
for h in range(r, height - r):
# s = time.time() 0.3s
for w in range(r, height - r):
for c in range(channel): # 通道遍历
p_c = img[h, w, c] # 像素值
p_win = img[h-r:h+r+1, w-r:w+r+1, c] # 窗口内所有像素
c_w = np.abs(p_win - p_c).astype(int)
c_w = w_color[0, c_w]
w_tmp = w_space * c_w
p_sum = p_win * w_tmp
p_sum = np.sum(p_sum) / np.sum(w_tmp)
dst_img[h, w, c] = p_sum
return dst_img