这几天在折腾一项大作业
使用深度学习来实现图像修复
图像修复这块在CV界还是有很多坑需要去填的
从2016年横空出世的自解编码器,到英伟达退出的Pconv模型,再到注意力机制
各种玩的很花的模型都能进一步提升修复能力
下面我就介绍一项自编解码器,即AutoEncoderDecoder
之前一直不知道该怎么去做
然后参考了一篇基于pytorch的
https://blog.csdn.net/qq_39938666/article/details/88363042
以该文提出的结构为基础
我造了一些轮子,使用mxnet进行实现
首先先通过一系列卷积层提取特征
这是一个编码器即Encoder
然后再通过一系列反卷积进行上采样
恢复到原始图像的形状
这是一个解码器,即decoder
损失值通过计算原始图像和解码器解码出来的图像的均方误差值
对此进行梯度计算
对自编解码网络进行优化
我这里把整个目录结构放到Github上
https://github.com/MARD1NO/Inpaint_AutoEncoderDecoder
整体代码如下
import mxnet
from mxnet import gluon, init, nd, autograd, image
from mxnet.gluon import nn, data as gdata, loss as gloss, utils as gutils
import matplotlib.pyplot as plt
import time
import os
class AutoEncoder(nn.Block):
def __init__(self, **kwargs):
super(AutoEncoder, self).__init__(**kwargs)
self.encoder_net = nn.Sequential()
self.encoder_net.add(
nn.Conv2D(16, kernel_size=3, strides=1, padding=1),
nn.PReLU(),
nn.Conv2D(16, kernel_size=4, strides=2, padding=1),
nn.PReLU(),
nn.Conv2D(32, kernel_size=3, strides=1, padding=1),
nn.PReLU(),
nn.Conv2D(32, kernel_size=4, strides=2, padding=1),
nn.PReLU(),
nn.Conv2D(64, kernel_size=3, strides=1, padding=1),
nn.PReLU(),
nn.Conv2D(64, kernel_size=4, strides=2, padding=1),
nn.PReLU(),
nn.Conv2D(128, kernel_size=4, strides=2, padding=1),
nn.PReLU(),
nn.Conv2D(128, kernel_size=4, strides=2, padding=1),
nn.PReLU(),
nn.Conv2D(256, kernel_size=3, strides=1, padding=1),
nn.PReLU(),
nn.Conv2D(256, kernel_size=4, strides=2, padding=1),
) # return ,
self.decoder_net = nn.Sequential()
self.decoder_net.add(
nn.Conv2DTranspose(256, kernel_size=3, strides=3, padding=2),
nn.PReLU(),
nn.Conv2DTranspose(128, kernel_size=3, strides=1, padding=1),
nn.PReLU(),
nn.Conv2DTranspose(128, kernel_size=2, strides=2),
nn.PReLU(),
nn.Conv2DTranspose(64, kernel_size=2, strides=2),
nn.PReLU(),
nn.Conv2DTranspose(64, kernel_size=2, strides=2),
nn.PReLU(),
nn.Conv2DTranspose(32, kernel_size=1, strides=1),
nn.PReLU(),
nn.Conv2DTranspose(32, kernel_size=2, strides=2),
nn.PReLU(),
nn.Conv2DTranspose(16, kernel_size=1, strides=1),
nn.PReLU(),
nn.Conv2DTranspose(16, kernel_size=2, strides=2),
nn.PReLU(),
nn.Conv2DTranspose(3, kernel_size=1, strides=1),
nn.Activation('tanh')
) # return )
def forward(self, x):
encoded = self.encoder_net(x)
decoded = self.decoder_net(encoded)
return decoded
rgb_std = [0.485, 0.456, 0.406]
rgb_mean = [0.229, 0.224, 0.225]
def postpreprocess(img):
# return (img.transpose((1, 2, 0))*rgb_std + rgb_mean).clip(0, 1)
return (img.transpose((1, 2, 0))).clip(0, 1)
net = AutoEncoder()
net.initialize()
train_data = gdata.vision.ImageFolderDataset(r'./train')
# normalize = gdata.vision.transforms.Normalize(
# [0.485, 0.456, 0.406],
# [0.229, 0.224, 0.225]
# )
train_augs = gdata.vision.transforms.Compose([
gdata.vision.transforms.Resize(256),
# gdata.vision.transforms.RandomFlipLeftRight(),
gdata.vision.transforms.ToTensor(),
# normalize
])
batch_size = 20
max_epochs = 100
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':0.001})
ctx = mxnet.cpu()
def square_loss(Y_hat, Y):
return (Y_hat[:, :, :, :] - Y[:, :, :, :]).square().mean()
class GetData(gdata.Dataset):
def __init__(self, path1):
super(GetData, self).__init__()
self.path1 = path1
# self.path2 = path2
self.dataset1 = []
# self.dataset2 = []
self.dataset1.extend(os.listdir(path1))
# self.dataset2.extend(os.listdir(path2))
def __getitem__(self, idx):
img_path1 = self.dataset1[idx]
# img_path2 = self.dataset2[idx]
img_1 = image.imread(os.path.join(self.path1, img_path1))
# img_2 = image.imread(os.path.join(self.path2, img_path2))
return img_1
# , img_2
def __len__(self):
return len(self.dataset1)
dataset = GetData('./train/true')
dataset_2 = GetData('./train/masked')
batch_size = 20
max_epochs = 4000
train_iter = gdata.DataLoader(
dataset.transform_first(train_augs), batch_size
)
masked_iter = gdata.DataLoader(
dataset_2.transform_first(train_augs), batch_size
)
for epoch in range(max_epochs):
train_l_sum, train_acc_sum, n, m, start = 0.0, 0.0, 0, 0, time.time()
for true, masked in zip(train_iter, masked_iter):
m += 1
Xs, ys = masked, true
with autograd.record():
decoded = net(Xs)
print("decoded!!!", decoded.shape)
print(ys.shape)
ls = square_loss(decoded, ys)
print(ls.shape)
ls.backward()
trainer.step(batch_size)
nd.waitall()
train_l_sum += ls.sum().asscalar()
n += len(ls)
print('epoch %d, loss %.4f, time %.1f sec'
% (epoch + 1, train_l_sum / n, time.time() - start))
if m == 2:
decode_img = decoded[0]
plt.imsave('./decoded/decode_img_%d'%epoch, postpreprocess(decode_img.asnumpy()))
true_img = ys[0]
plt.imsave('./decoded/true_img_%d'%epoch, postpreprocess(true_img.asnumpy()))
if epoch % 500 == 0:
net.save_parameters('./ae_params/autoencoder_net_%d'%epoch)
虽然mxnet内部提供了一个数据导入的类
但它是从文件夹流入,分配的分别是图片以及标签
而我们进行图像修复并不需要用到标签,我们需要计算的是内容的loss值
因此我们只能实现一个数据集类
其中这段代码
class GetData(gdata.Dataset):
def __init__(self, path1):
super(GetData, self).__init__()
self.path1 = path1
# self.path2 = path2
self.dataset1 = []
# self.dataset2 = []
self.dataset1.extend(os.listdir(path1))
# self.dataset2.extend(os.listdir(path2))
def __getitem__(self, idx):
img_path1 = self.dataset1[idx]
# img_path2 = self.dataset2[idx]
img_1 = image.imread(os.path.join(self.path1, img_path1))
# img_2 = image.imread(os.path.join(self.path2, img_path2))
return img_1
# , img_2
def __len__(self):
return len(self.dataset1)
继承自mxnet.data.Dataset类
我们需要实现其中的get_item和 len方法
所以我们自然而然地可以在init初始方法里,初始化我们图片的路径
然后用os.listdir列出目录下所有图片的名字
get_item里面我们有个参数是idx,是一个默认的索引
我们要返回我们的图片对象
就利用image的imread方法,用os.path.join将目录与图片名组合起来,进行读取,然后返回该对象
len方法是来计算长度的,我们就可以用len返回我们的dataset1列表的长度
我们的网络定义就是 class net这个类,继承了Block类来完成一系列的前向计算
首先是定义了一系列的卷积操作,后接一个Prelu,我这里不采用池化层,因为这样会丢失部分信息,我中间夹杂了一些步长为2的卷积操作来对特征图进行缩放。
编码这一块,我输入的图像假设为1 x 3 x 256 x 256
编码完成后输出向量为 1 x 256 x 4 x 4
然后开始进行解码器编写
其中操作也比较类似,一系列的转置卷积层后接Prelu,最后一层使用tanh激活
输出向量还原回我们输入图像的shape 即 1 x 3 x 256 x 256
前向计算我们是这样定义的
首先输入图像x,经过编码器获取中间输出,再将这个中间输入传入到解码器进行解码
这个过程十分简单,当然你也可以在网络结构里增加一些SqueezeExcitation结构,但是这可能会略微增加一些计算开销
自带的loss里面没有均方误差损失,当然你也可以使用其他的损失函数
这里均方误差损失常用于图像风格迁移的内容损失计算
mxnet的另一个好处就是与numpy结合程度很高
可以不用继承很复杂的类
用numpy就能进行实现
def square_loss(Y_hat, Y):
return (Y_hat[:, :, :, :] - Y[:, :, :, :]).square().mean()
我们计算平方值后利用mean方法进行平均
这里我还没有对其他激活函数做实验
看了那么多关于图像修复以及博客
发现大部分都是用Prelu进行一个gamma值的学习
在解码器最后使用一层tanh进行激活
大家也可以自己改一下模型
使用Mish或者Swish函数进行激活
其他方面就是一些图像预处理,以及图像输出,这方面代码里面写的很清楚。最后就是一个mxnet的参数保存,我每隔400保存一次param
这个自编码解码器训练起来还是比较麻烦的
因为也没有用到什么新兴的结构
但效果有博主实测了还不错
另外我会继续进行尝试不同激活函数组合能不能再提高模型