噪声图片 $y$ 可以表示为干净图片 $x$ 与噪声 $n$ 的叠加:
$$y = x + n$$
使用单个输入进行预测的原理是学习从噪声图片到干净图片的映射:
$$F_{\theta}(\cdot):\; y \rightarrow x$$
常规的有监督神经网络训练(成对的噪声图片 $y^{(i)}$ 与干净图片 $x^{(i)}$)求解:
$$\underset{\theta}{\min} \sum_i L\big(F_{\theta}(y^{(i)}),\, x^{(i)}\big)$$
其中 $F_{\theta}$ 是神经网络,$\theta$ 是网络参数。从神经网络训练的角度看,预测误差可以分解为偏差和方差两部分:
$$\mathrm{MSE} = \mathrm{bias}^2 + \mathrm{variance}$$
当训练数据减少时,variance 会急剧增加。blind-spot 技术可以用来阻止这种过拟合现象,但单个样本训练带来的大 variance 仍然无法解决。这也是基于 blind-spot 的网络 N2V 和 N2S 在单张图片上效果不好的原因。
Dropout 是一种广泛应用的正则化技术,同时可以提供一定程度的不确定性估计,并避免网络学到恒等映射。盲点策略通过对噪声数据随机采样,合成多个不同的噪声数据版本,并在这些替换样本上计算损失。因此,本文提出的策略是:在输入图像的伯努利采样实例上定义自预测损失函数:
$$\hat{y}[k] = \begin{cases} y[k], & \text{with probability } p; \\ 0, & \text{with probability } 1-p. \end{cases}$$
(注意:下面代码中的 `p` 表示丢弃概率,即 `mask = (p_mtx > p)`,每个像素以 $1-p$ 的概率被保留。)
分别采样两组 Bernoulli 采样实例 $\{\hat{y}_m\}$ 和 $\{\hat{y}_n\}$,前者用于训练,后者用于测试。
训练过程:最小化下面的损失函数
$$\underset{\theta}{\min} \sum_m L\big(F_{\theta}(\hat{y}_m),\, y - \hat{y}_m\big)$$
其中 $y - \hat{y}_m$ 对应被丢弃(未被采样)的像素,即损失只在这些像素上计算。
测试过程:在另一组采样实例上,得到每一个 $\hat{y}_n$ 对应的预测结果,再对所有预测取平均,得到最终的去噪图像。
Encoder结构
Decoder 结构:
部分细节:
结构和 Noise2Noise结构基本相似,不同点在于:
注意:编码器使用的是部分卷积(partial convolution),因此这里直接采用了 NVIDIA 的实现:
import torch
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable
class PartialConv2d(nn.Conv2d):
    """Partial convolution layer (adapted from NVIDIA's PartialConv2d).

    Convolves the masked input, rescales each output position by
    ``slide_winsize / (mask sum under the window)`` and zeroes positions whose
    window contains no valid (non-zero) mask entry. The clamped mask is
    returned together with the output.

    Accepts the usual ``nn.Conv2d`` arguments plus the optional keyword
    arguments ``multi_channel`` and ``return_mask``.
    NOTE(review): both options are unconditionally overwritten with True right
    after being parsed ("Yize's fixes"), so explicitly passed values are
    ignored and ``forward`` always returns ``(output, update_mask)``.
    """
    def __init__(self, *args, **kwargs):
        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False
        # whether forward() returns (output, mask) or only output
        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False
        #####Yize's fixes
        # These overrides discard the values parsed above (see class docstring).
        self.multi_channel = True
        self.return_mask = True
        super(PartialConv2d, self).__init__(*args, **kwargs)
        # All-ones kernel: convolving the mask with it sums the mask under each window.
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
        # Number of mask entries covered by one window (mask channels * kh * kw).
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]
        # Cached mask statistics; recomputed in forward() only when a mask is
        # passed or the input shape changes.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None
    def forward(self, input, mask_in=None):
        """Apply the partial convolution.

        Args:
            input: 4-D tensor (asserted below), presumably (N, C, H, W).
            mask_in: optional mask multiplied element-wise with ``input``;
                presumably 0/1 valued. When None, an all-ones mask is used.

        Returns:
            ``(output, update_mask)`` if ``self.return_mask`` (always True
            here), otherwise ``output``.
        """
        assert len(input.shape) == 4
        # Recompute the mask statistics only when a new mask is supplied or the
        # input shape changed; otherwise reuse the cached update_mask/mask_ratio.
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                # weight_maskUpdater is a plain tensor attribute (not a buffer),
                # so it is moved to the input's device/dtype lazily here.
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                # Sum of the mask under each output window.
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)
                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                # Clamp to [0, 1]: for a 0/1 mask this is 1 where the window saw
                # any valid entry and 0 where it saw none.
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the rescaling factor at positions with no valid input.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
        # Ordinary convolution of the masked input (input used as-is when no mask is given).
        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
        if self.bias is not None:
            # Rescale only the bias-free part, add the bias back, then zero invalid positions.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)
        if self.return_mask:
            return output, self.update_mask
        else:
            return output
class EncodeBlock(nn.Module):
    """One encoder stage: 3x3 partial conv -> LeakyReLU(0.1) -> optional 2x max-pool.

    When ``flag`` is truthy, the features and the mask returned by the partial
    convolution are both max-pooled, so they stay spatially aligned.

    ``forward(x, mask_in)`` returns ``(features, mask)``.
    """

    def __init__(self, in_channel, out_channel, flag):
        super().__init__()
        self.conv = PartialConv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.nonlinear = nn.LeakyReLU(0.1)
        self.MaxPool = nn.MaxPool2d(2)
        self.flag = flag

    def forward(self, x, mask_in):
        features, new_mask = self.conv(x, mask_in=mask_in)
        features = self.nonlinear(features)
        if not self.flag:
            return features, new_mask
        # Downsample features and mask together.
        return self.MaxPool(features), self.MaxPool(new_mask)
class DecodeBlock(nn.Module):
    """One decoder stage built from dropout + 3x3 convolutions.

    Computes ``[dropout -> conv -> LeakyReLU(0.1)]`` twice
    (in_channel -> mid_channel -> out_channel). When ``flag`` is truthy, it then
    applies ``dropout -> conv -> sigmoid`` (out_channel -> final_channel).
    A single ``nn.Dropout(p)`` module is reused before every convolution.
    ``conv3`` is always constructed, but it is only used when ``flag`` is set.
    """

    def __init__(self, in_channel, mid_channel, out_channel, final_channel = 3, p = 0.7, flag = False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, mid_channel, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(mid_channel, out_channel, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channel, final_channel, kernel_size=3, padding=1)
        self.nonlinear1 = nn.LeakyReLU(0.1)
        self.nonlinear2 = nn.LeakyReLU(0.1)
        self.sigmoid = nn.Sigmoid()
        self.flag = flag
        self.Dropout = nn.Dropout(p)

    def forward(self, x):
        hidden = self.nonlinear1(self.conv1(self.Dropout(x)))
        hidden = self.nonlinear2(self.conv2(self.Dropout(hidden)))
        if self.flag:
            # Final projection with sigmoid output.
            hidden = self.sigmoid(self.conv3(self.Dropout(hidden)))
        return hidden
class self2self(nn.Module):
    """Self2Self denoising network (U-Net style).

    Encoder: seven partial-convolution blocks EB0-EB6 with 48 channels each.
    EB1-EB5 halve the spatial size with 2x max-pooling, and they also pool the mask.
    Decoder: five dropout-conv blocks DB1-DB5. Before each one, the previous
    output is upsampled 2x and concatenated with the matching encoder feature
    map; DB5 uses the network input ``x`` itself.

    Args:
        in_channel: number of input channels. DB5 outputs the same number of
            channels through a sigmoid, so the prediction is comparable with
            the input image.
        p: dropout probability used in every decoder block.

    ``forward(x, mask)`` expects a 4-D ``x`` of shape (N, in_channel, H, W) and
    a mask that is presumably 0/1 with the same shape. H and W must be
    divisible by 32 (five 2x poolings followed by five 2x upsamplings);
    otherwise the skip concatenations fail.
    """

    def __init__(self, in_channel, p):
        super(self2self, self).__init__()
        self.EB0 = EncodeBlock(in_channel, out_channel=48, flag=False)
        self.EB1 = EncodeBlock(48, 48, flag=True)
        self.EB2 = EncodeBlock(48, 48, flag=True)
        self.EB3 = EncodeBlock(48, 48, flag=True)
        self.EB4 = EncodeBlock(48, 48, flag=True)
        self.EB5 = EncodeBlock(48, 48, flag=True)
        self.EB6 = EncodeBlock(48, 48, flag=False)
        self.DB1 = DecodeBlock(in_channel=96, mid_channel=96, out_channel=96, p=p)
        self.DB2 = DecodeBlock(in_channel=144, mid_channel=96, out_channel=96, p=p)
        self.DB3 = DecodeBlock(in_channel=144, mid_channel=96, out_channel=96, p=p)
        self.DB4 = DecodeBlock(in_channel=144, mid_channel=96, out_channel=96, p=p)
        # The final block maps back to `in_channel` channels. DecodeBlock's default
        # of 3 only matched RGB input.
        self.DB5 = DecodeBlock(in_channel=96 + in_channel, mid_channel=64, out_channel=32,
                               final_channel=in_channel, p=p, flag=True)
        # align_corners=False is PyTorch's default; it is stated explicitly here.
        self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.concat_dim = 1

    def forward(self, x, mask):
        # Shapes below omit the batch dimension; C = in_channel.
        out_EB0, mask = self.EB0(x, mask)                   # [C,H,W]        -> [48,H,W]
        out_EB1, mask = self.EB1(out_EB0, mask_in=mask)     # [48,H,W]       -> [48,H/2,W/2]
        out_EB2, mask = self.EB2(out_EB1, mask_in=mask)     # [48,H/2,W/2]   -> [48,H/4,W/4]
        out_EB3, mask = self.EB3(out_EB2, mask_in=mask)     # [48,H/4,W/4]   -> [48,H/8,W/8]
        out_EB4, mask = self.EB4(out_EB3, mask_in=mask)     # [48,H/8,W/8]   -> [48,H/16,W/16]
        out_EB5, mask = self.EB5(out_EB4, mask_in=mask)     # [48,H/16,W/16] -> [48,H/32,W/32]
        out_EB6, mask = self.EB6(out_EB5, mask_in=mask)     # [48,H/32,W/32] -> [48,H/32,W/32]

        out_EB6_up = self.Upsample(out_EB6)                                  # -> [48,H/16,W/16]
        in_DB1 = torch.cat((out_EB6_up, out_EB4), self.concat_dim)          # -> [96,H/16,W/16]
        out_DB1 = self.DB1(in_DB1)                                           # -> [96,H/16,W/16]

        out_DB1_up = self.Upsample(out_DB1)                                  # -> [96,H/8,W/8]
        in_DB2 = torch.cat((out_DB1_up, out_EB3), self.concat_dim)          # -> [144,H/8,W/8]
        out_DB2 = self.DB2(in_DB2)                                           # -> [96,H/8,W/8]

        out_DB2_up = self.Upsample(out_DB2)                                  # -> [96,H/4,W/4]
        in_DB3 = torch.cat((out_DB2_up, out_EB2), self.concat_dim)          # -> [144,H/4,W/4]
        # Each decoder stage has its own weights (this stage previously reused DB2 by mistake).
        out_DB3 = self.DB3(in_DB3)                                           # -> [96,H/4,W/4]

        out_DB3_up = self.Upsample(out_DB3)                                  # -> [96,H/2,W/2]
        in_DB4 = torch.cat((out_DB3_up, out_EB1), self.concat_dim)          # -> [144,H/2,W/2]
        out_DB4 = self.DB4(in_DB4)                                           # -> [96,H/2,W/2]

        out_DB4_up = self.Upsample(out_DB4)                                  # -> [96,H,W]
        in_DB5 = torch.cat((out_DB4_up, x), self.concat_dim)                # -> [96+C,H,W]
        out_DB5 = self.DB5(in_DB5)                                           # -> [C,H,W], sigmoid output
        return out_DB5
# Build the Self2Self network for 3-channel (RGB) input with decoder dropout p = 0.3.
model = self2self(3, 0.3)
model  # notebook-style display of the architecture; a no-op when run as a script
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torchvision.transforms as T
import cv2
from PIL import Image
from tqdm import tqdm
# Load the image. It is converted to RGB so the array is always (H, W, 3):
# PNG files can be RGBA, grayscale or palette images, which would not match
# the 3-channel network.
img = np.array(Image.open("5.png").convert("RGB"))
plt.figure()
plt.imshow(img)
plt.show()
img.shape  # notebook-style display
# Parameter settings
# Use the GPU when one is available.
USE_GPU = True
dtype = torch.float32
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)
learning_rate = 1e-4
# Move the model to the device chosen above; `.cuda()` ignored USE_GPU and
# crashed on machines without CUDA.
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
# numpy images are (rows, cols, channels) = (height, width, channels).
h, w, c = img.shape
# Bernoulli sampling: each pixel is dropped with probability p (kept with probability 1 - p).
p = 0.3
# Number of masked predictions averaged for each saved denoised image.
NPred = 100
# NOTE(review): unused in this script. This is a 1-D tensor holding the four numbers
# [1, 3, 512, 512], not a tensor of that shape.
slice_avg = torch.tensor([1,3,512,512]).to(device)
# Training iterations
def image_loader(image, device, p1, p2):
    """Convert an (H, W, C) numpy image into a (1, C, H, W) float tensor on ``device``.

    The array is cast to uint8 and optionally flipped. ``T.ToTensor`` then
    scales it to [0, 1].

    Args:
        image: numpy array with values in [0, 255]. It is cast to uint8, so any
            fractional part is truncated.
        device: torch device the returned tensor is moved to.
        p1: horizontal-flip control (float or one-element array). It is rounded
            to 0 or 1 and used as the flip probability, so the image is flipped
            exactly when round(p1) == 1.
        p2: vertical-flip control, with the same semantics as ``p1``.

    Returns:
        torch.Tensor of shape (1, C, H, W) on ``device``.
    """
    # Rounding makes the flip deterministic, so two calls with the same p1/p2
    # (noisy input and target) receive identical flips.
    flip_h = torch.round(torch.tensor(p1)).item()
    flip_v = torch.round(torch.tensor(p2)).item()
    loader = T.Compose([
        T.RandomHorizontalFlip(flip_h),
        T.RandomVerticalFlip(flip_v),
        T.ToTensor()])
    pil_image = Image.fromarray(image.astype(np.uint8))
    # ToTensor always returns a tensor, so no further type check is needed.
    tensor = loader(pil_image).float()
    # Add the batch dimension.
    return tensor.unsqueeze(0).to(device)
import os

# Create the snapshot directories up front, so the first save at iteration 1000 does not fail.
os.makedirs("./examples/images", exist_ok=True)
os.makedirs("./examples/models", exist_ok=True)

pbar = tqdm(range(500000))
for itr in pbar:
    # Bernoulli sampling, drawn independently for every pixel of every channel:
    # an entry is kept (mask = 1) with probability 1 - p and dropped (mask = 0)
    # with probability p.
    p_mtx = np.random.uniform(size=img.shape)
    mask = (p_mtx > p).astype(np.double)
    # Random flip augmentation: image_loader flips when round(p1) / round(p2) == 1.
    p1 = np.random.uniform(size=1)
    p2 = np.random.uniform(size=1)
    # Input and target receive the identical flip. The mask is i.i.d. noise,
    # so it does not need to be flipped with them.
    img_input_tensor = image_loader(img, device, p1, p2)
    y = image_loader(img, device, p1, p2)
    # (H, W, C) -> (1, C, H, W) float32 mask on the training device.
    mask = np.expand_dims(np.transpose(mask, [2, 0, 1]), 0)
    mask = torch.tensor(mask).to(device, dtype=torch.float32)
    # Forward pass on the Bernoulli-sampled input.
    model.train()
    img_input_tensor = img_input_tensor * mask
    output = model(img_input_tensor, mask)
    # Self-prediction loss: squared error on the dropped entries only (mask == 0),
    # averaged over the number of dropped entries.
    loss = torch.sum((output - y) * (output - y) * (1 - mask)) / torch.sum(1 - mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    pbar.set_description("iteration {}, loss = {:.4f}".format(itr+1, loss.item()*100))
    if (itr + 1) % 1000 == 0:
        # Self2Self test-time ensemble: average NPred predictions, each made from a
        # fresh Bernoulli-masked input and a fresh dropout instance. model.eval()
        # would switch dropout off, so the model stays in train mode. self2self
        # contains no BatchNorm, so train mode only affects nn.Dropout here.
        model.train()
        sum_preds = np.zeros(img.shape)
        # Inference only: without no_grad, every prediction would build an autograd graph.
        with torch.no_grad():
            for _ in range(NPred):
                p_mtx = np.random.uniform(size=img.shape)
                mask = (p_mtx > p).astype(np.double)
                img_input = img * mask
                # 0.1 rounds to 0 inside image_loader, so no flip is applied at test time.
                img_input_tensor = image_loader(img_input, device, 0.1, 0.1)
                mask = np.expand_dims(np.transpose(mask, [2, 0, 1]), 0)
                mask = torch.tensor(mask).to(device, dtype=torch.float32)
                output_test = model(img_input_tensor, mask)
                # (1, C, H, W) -> (H, W, C)
                sum_preds += np.transpose(output_test.cpu().numpy(), [2, 3, 1, 0])[:, :, :, 0]
        # Mean prediction. The network ends in a sigmoid, so the values already lie in
        # [0, 1]; dividing by NPred keeps absolute intensities, whereas a min-max
        # stretch would change brightness and contrast.
        avg_preds = np.squeeze(np.uint8(np.clip(sum_preds / NPred, 0, 1) * 255))
        write_img = Image.fromarray(avg_preds)
        write_img.save("./examples/images/Self2self-" + str(itr + 1) + ".png")
        torch.save(model.state_dict(), './examples/models/model-' + str(itr + 1))
展示不同次数的结果:
1000,10000,20000,30000次迭代
从我自己可能会用到的地方进行 评价 (不是评价啊哈,大佬的工作真的非常棒,就是从我们迁移应用的角度看待)
一些小问题: