基于微软开源深度学习算法,用 Python 实现图像和视频修复

‍‍

基于微软开源深度学习算法,用 Python 实现图像和视频修复_第1张图片

基于微软开源深度学习算法,用 Python 实现图像和视频修复_第2张图片

作者 | 李秋键

编辑 | 夕颜

出品 | AI科技大本营(ID:rgznai100)

图像修复是计算机视觉领域的一个重要任务,在数字艺术品修复、公安刑侦面部修复等种种实际场景中被广泛应用。图像修复的核心挑战在于为缺失区域合成视觉逼真和语义合理的像素,要求合成的像素与原像素具有一致性。

传统的图像修复技术有基于结构和纹理两种方法。基于结构的图像修复算法具有代表性的是 Bertalmio 等提出的BSCB模型和 Shen 等提出的基于曲率扩散的修复模型 CDD。基于纹理的修复算法中具有代表性的有 Criminisi 等提出的基于 patch 的纹理合成算法。这两种传统的修复算法可以修复小块区域的破损,但是在破损区域越来越大时, 修复效果则直线下降, 并且修复结果存在图像模糊、结构扭曲、纹理不清晰和视觉不连贯等问题。

近年来,随着硬件设备等计算能力的不断提升, 以及深度学习技术在图像翻译、图像超分辨率、图 像修复等计算机视觉领域的迅速发展, 采用深度学习技术的修复方法能够捕获图像的高层语义信息, 与传统的修复方法相比, 具有良好的修复效果。故今天我们使用Python实现Bringing Old Photo Back to Life算法实现对图像和视频的修复。得到的模型评估效果如下:

基于微软开源深度学习算法,用 Python 实现图像和视频修复_第3张图片

基本介绍

传统的图像修复技术可以分为基于结构的图像修复技术和基于纹理的图像修复技术两大类。其中,变分偏微分方程模型是基于结构的图像修复技术的典型代表,由变分模型和偏微分方程模型组成。纹理合成是基于纹理的图像修复技术的典型代表。传统数字图像修复技术分类如下图所示。

传统的图像修复方法结果中存在语义信息不完整、图像模糊等问题,无法达到目前对图像修复的要求。而基于深度学习的图像修复算法能够捕获更多图像的高级特征,修复结果较好,所以经常用于图像修复。目前基于生成式对抗网络的图像修复是深度学习图像修复领域的一大研究热点,为图像修复技术的发展奠定了坚实的基础。而我们使用的算法就是基于深度学习的微软开源的Bringing Old Photo Back to Life去修复图像。

1.1 环境要求

本次环境使用的是Python3.6.5+windows平台。主要用的库有:

  • PyTorch模块。PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序。它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速,同时还支持动态神经网络,这一点是现在很多主流框架如TensorFlow都不支持的。PyTorch提供了两个高级功能:1.具有强大的GPU加速的张量计算(如Numpy) 2.包含自动求导系统的深度神经网络 除了Facebook之外,Twitter、GMU和Salesforce等机构都采用了PyTorch。

  • pillow模块。Pillow是Python里的图像处理库(PIL:Python Image Library),提供了了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。

  • Numpy模块。Numpy是应用Python进行科学计算时的基础模块。它是一个提供多维数组对象的Python库,除此之外,还包含了多种衍生的对象(比如掩码式数组(masked arrays)或矩阵)以及一系列的为快速计算数组而生的例程,包括数学运算,逻辑运算,形状操作,排序,选择,I/O,离散傅里叶变换,基本线性代数,基本统计运算,随机模拟等等。

  • collections这个模块实现了特定目标的容器,以提供Python标准内建容器 dict、list、set、tuple 的替代选择。Counter:字典的子类,提供了可哈希对象的计数功能;defaultdict:字典的子类,提供了一个工厂函数,为字典查询提供了默认值;OrderedDict:字典的子类,保留了他们被添加的顺序;namedtuple:创建命名元组子类的工厂函数;deque:类似列表容器,实现了在两端快速添加(append)和弹出(pop);ChainMap:类似字典的容器类,将多个映射集合到一个视图里面。

修复模型算法

本文所使用的Bringing Old Photo Back to Life算法流程分别为全局修复、脸部检测、脸部特征加强和特征融合。其中隐空间修复网络采用局部-全局视野融合,其中全局支路采用 nonlocal 模块大大增强处理视野。我们对局部破损图片建立了数据集,训练网络预测破损区域,该破损区域显式的送入 nonlocal 模块,并设置模块感受野为非破损区域

2.1 全局视野修复

本文的模型主要由三个部分组成两个变分自编码器(variational-autoencoder,VAE)和一个latent space 映射网络,每个部分都可以看作是单独的一个模块。下面将介绍网络设计的思想和不同部分的作用。

基于微软开源深度学习算法,用 Python 实现图像和视频修复_第4张图片

模型使用了两个 VAE:

第一个 VAE 用于将合成的老照片(模糊、磨损)进行编码到隐空间。

第二个 VAE 用于将对应的干净的老照片进行编码。

然后,在隐空间学习从污损的老照片到干净照片的映射。

就这样,实现了一个老照片的修复算法。

这个有点像在学习控制图片清晰、磨损的一个特征表示,通过控制这个特征,可以达到修复破损照片的目的。

关键代码如下:

model = networks.UNet(in_channels=1, out_channels=1, depth=4, conv_num=2, wf=6, padding=True, batch_norm=True, up_mode="upsample",with_tanh=False, sync_bn=True, antialiasing=True,
)
for image_name in imagelist:
    idx += 1
    print("processing", image_name)
    results = []
    scratch_image = Image.open(os.path.join(config.test_path, image_name)).convert("RGB")
    w, h = scratch_image.size
    transformed_image_PIL = data_transforms(scratch_image, config.input_size)
    scratch_image = transformed_image_PIL.convert("L")
    scratch_image = tv.transforms.ToTensor()(scratch_image)
    scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
    scratch_image = torch.unsqueeze(scratch_image, 0)
    scratch_image = scratch_image.to(config.GPU)
    P = torch.sigmoid(model(scratch_image))
    P = P.data.cpu()
    tv.utils.save_image(
        (P >= 0.4).float(),
        os.path.join(output_dir, image_name[:-4] + ".png",),
        nrow=1,
        padding=0,
        normalize=True,
    )
    transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + ".png"))

2.2 局部脸部修复加强

脸部特征的加强使用pixpix2模型对脸部二次修复。其中, Pix2Pix模型由Isola等于2017年提出, 它由U-Net和PatchGAN组成, 分别充当Pix2Pix模型中的生成器和判别器。该模型使用户只需提供一个草图便能生成一个与之对应的高质量图像; 对应到图像着色工作中, 网络接收真实图像的亮度信息, 对亮度信息进行特征提取并预测图像颜色值。

关键代码:

def create_optimizers(self, opt):
    G_params = list(self.netG.parameters())
    if opt.use_vae:
        G_params += list(self.netE.parameters())
    if opt.isTrain:
        D_params = list(self.netD.parameters())
    beta1, beta2 = opt.beta1, opt.beta2
    if opt.no_TTUR:
        G_lr, D_lr = opt.lr, opt.lr
    else:
        G_lr, D_lr = opt.lr / 2, opt.lr * 2
    optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
    optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))
    return optimizer_G, optimizer_D
def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False):
    z = None
    KLD_loss = None
    if self.opt.use_vae:
        z, mu, logvar = self.encode_z(real_image)
        if compute_kld_loss:
            KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
    fake_image = self.netG(input_semantics, degraded_image, z=z)
    assert (
        not compute_kld_loss
    ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False"
    return fake_image, KLD_loss
def discriminate(self, input_semantics, fake_image, real_image):
    if self.opt.no_parsing_map:
        fake_concat = fake_image
        real_concat = real_image
    else:
        fake_concat = torch.cat([input_semantics, fake_image], dim=1)
        real_concat = torch.cat([input_semantics, real_image], dim=1)
    fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
    discriminator_out = self.netD(fake_and_real)
    pred_fake, pred_real = self.divide_pred(discriminator_out)
    return pred_fake, pred_real

源代码:https://pan.baidu.com/s/1lAzmWvAEyxi6RFsLpA5l_Q

提取码:osuh

推荐阅读
  • 后疫情时代,RTC期待新的场景大爆

  • 蓝色起源载人火箭7月首飞,贝索斯即将实现儿时愿望

  • 干货!机器学习中,如何优化数据性

  • 你的 AI 算法模型安全吗?来 AI 安全测试基准平台测试

基于微软开源深度学习算法,用 Python 实现图像和视频修复_第5张图片

点个“在看”,宠我一下

你可能感兴趣的:(算法,神经网络,大数据,机器学习,人工智能)