DCGAN Application: Semantic Image Inpainting with Deep Generative Models

This blog post walks through corrupted-image completion as described in Semantic Image Inpainting with Deep Generative Models (CVPR 2017), using the CelebA dataset of 202,599 face images.


I. Dataset

Each row of the overview figure contains five images: column 1 is the original image from the dataset; column 2 has 80% of its pixels randomly removed; column 3 is the result of repairing column 2 with the inpainting method; column 4 is the original image with a large central block cut out; column 5 is the result of repairing column 4.


II. Image Completion with DCGAN

Completing an image with the generative network requires two conditions, which the losses in the next section formalize:

  • after training on a large image dataset, the DCGAN must be able to generate images that "fool" the discriminator;
  • the generated image must differ as little as possible from the original image over the non-missing region.


III. Loss Functions

The paper defines two loss terms:

  • A contextual loss tied to the corrupted image: the distance between the generated image and the original image over the non-missing region, where W is a weighting mask computed from the binary mask M over a 7-pixel neighbourhood (a short sketch of computing W follows this list).

                                                         L_c(z\mid y,M)=\left\| W\odot (G(z)-y) \right\|_1

where \odot denotes element-wise multiplication.

  • The perceptual loss of the DCGAN itself: the loss the DCGAN generator was trained with on the large face dataset, identical to the generator loss of a standard GAN.

                                                         L_p(z)=\log(1-D(G(z)))
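
To make the weighting concrete, here is a minimal NumPy sketch of computing W from the binary mask (the function name and the toy mask are illustrative; the actual implementation appears as createWeightedMask in Section V):

import numpy as np
from scipy.signal import convolve2d

def weighted_mask(M, nsize=7):
    # W_i = fraction of known pixels in the nsize x nsize window around pixel i,
    # zeroed out at missing pixels; known pixels near the hole get lower weight.
    ker = np.ones((nsize, nsize), np.float32) / (nsize * nsize)
    return M * convolve2d(M, ker, mode='same', boundary='symm')

M = np.ones((64, 64), np.float32)   # binary mask: 1 = known, 0 = missing
M[16:48, 16:48] = 0                 # a central hole
W = weighted_mask(M)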

The final loss function is

                                                         L(z)=L_c(z)+\lambda L_p(z)

                                                                        \hat z=\mathop{\arg\min}_{z}L(z)

where \lambda is a hyperparameter balancing the two loss terms and \hat z is the generator input we solve for. The completed image is then composited from the known pixels of the input and the generated content:

                                                     x_{reconstructed}=M\odot y+(1-M)\odot G(\hat z)
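
A compact TensorFlow 1.x sketch of this optimization, using toy stand-ins for the pre-trained generator and discriminator (in the real pipeline these tensors come from the loaded graph, as in Section V; all names, sizes, and the value of lambda here are illustrative):

import numpy as np
import tensorflow as tf

# Toy stand-ins for the pre-trained DCGAN; only z is optimized here.
z = tf.placeholder(tf.float32, [1, 100], name='z')
y = tf.placeholder(tf.float32, [1, 64, 64, 3], name='y')  # corrupted image
W = tf.placeholder(tf.float32, [1, 64, 64, 3], name='W')  # weighted mask

G_z = tf.reshape(tf.tanh(tf.layers.dense(z, 64 * 64 * 3)), [1, 64, 64, 3])
D_Gz = tf.sigmoid(tf.layers.dense(tf.layers.flatten(G_z), 1))

lam = 0.003                                   # lambda (illustrative value)
L_c = tf.reduce_sum(tf.abs(W * (G_z - y)))    # contextual loss
L_p = tf.log(1.0 - D_Gz)                      # perceptual loss
L = L_c + lam * L_p                           # L(z)
grad = tf.gradients(L, z)[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    z_hat = np.random.randn(1, 100)
    y_np = np.zeros((1, 64, 64, 3), np.float32)
    W_np = np.ones((1, 64, 64, 3), np.float32)
    for _ in range(100):                      # plain gradient descent on z
        g = sess.run(grad, {z: z_hat, y: y_np, W: W_np})
        z_hat = np.clip(z_hat - 0.01 * g, -1, 1)  # keep z in DCGAN's range

The full implementation in Section V additionally uses Nesterov momentum for the update and Poisson blending for the final composite.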


IV. Experiments

1. Filling randomly missing pixels

The paper compares several alternative methods against the DCGAN-based completion, and the latter comes out ahead. In the figure below, column 1 is the original image, column 2 the image after random pixel loss, columns 3 and 4 the results of TV minimization and low-rank minimization (both visibly blurry), and column 5 the DCGAN completion.

[Figure 1: comparison of completion methods on images with randomly missing pixels]


2. Filling a central hole

The DCGAN result is clearly more natural, while the other methods leave obvious stitching artifacts.

[Figure 2: comparison of completion methods on images with a central hole]


V. Complete Code

1. Initializing parameters

import tensorflow as tf
import numpy as np
import external.poissonblending as blending
from scipy.signal import convolve2d


class ModelInpaint():
    def __init__(self, modelfilename, config,
                 model_name='dcgan',
                 gen_input='z:0', gen_output='Tanh:0', gen_loss='Mean_2:0',
                 disc_input='real_images:0', disc_output='Sigmoid:0',
                 z_dim=100, batch_size=64):
        self.config = config

        self.batch_size = batch_size
        self.z_dim = z_dim
        self.graph, self.graph_def = ModelInpaint.loadpb(modelfilename,
                                                         model_name)

        self.gi = self.graph.get_tensor_by_name(model_name+'/'+gen_input)
        self.go = self.graph.get_tensor_by_name(model_name+'/'+gen_output)
        self.gl = self.graph.get_tensor_by_name(model_name+'/'+gen_loss)
        self.di = self.graph.get_tensor_by_name(model_name+'/'+disc_input)
        self.do = self.graph.get_tensor_by_name(model_name+'/'+disc_output)

        self.image_shape = self.go.shape[1:].as_list()

        self.l = config.lambda_p

        self.sess = tf.Session(graph=self.graph)

        self.init_z()

    def init_z(self):
        """Initializes latent variable z"""
        self.z = np.random.randn(self.batch_size, self.z_dim)

    def sample(self, z=None):
        """GAN sampler. Useful for checking if the GAN was loaded correctly"""
        if z is None:
            z = self.z
        sample_out = self.sess.run(self.go, feed_dict={self.gi: z})
        return sample_out

2. Pre-processing and post-processing

    def preprocess(self, images, imask, useWeightedMask = True, nsize=7):
        """Default preprocessing pipeline
        Prepare the data to be fed to the network. Weighted mask is computed
        and images and masks are duplicated to fill the batch.

        Arguments:
            image - input image
            mask - input mask

        Returns:
            None
        """
        images = ModelInpaint.imtransform(images)
        if useWeightedMask:
            mask = ModelInpaint.createWeightedMask(imask, nsize)
        else:
            mask = imask
        mask = ModelInpaint.create3ChannelMask(mask)
        
        bin_mask = ModelInpaint.binarizeMask(imask, dtype='uint8')
        self.bin_mask = ModelInpaint.create3ChannelMask(bin_mask)

        self.masks_data = np.repeat(mask[np.newaxis, :, :, :],
                                    self.batch_size,
                                    axis=0)

        #Generate multiple candidates for completion if single image is given
        if len(images.shape) == 3:
            self.images_data = np.repeat(images[np.newaxis, :, :, :],
                                         self.batch_size,
                                         axis=0)
        elif len(images.shape) == 4:
            #Ensure batch is filled
            num_images = images.shape[0]
            self.images_data = np.repeat(images[np.newaxis, 0, :, :, :],
                                         self.batch_size,
                                         axis=0)
            ncpy = min(num_images, self.batch_size)
            self.images_data[:ncpy, :, :, :] = images[:ncpy, :, :, :].copy()

    def postprocess(self, g_out, blend = True):
        """Default post processing pipeline
        Applies poisson blending using binary mask. (default)

        Arguments:
            g_out - generator output
            blend - Use poisson blending (True) or alpha blending (False)
        """
        images_out = ModelInpaint.iminvtransform(g_out)
        images_in = ModelInpaint.iminvtransform(self.images_data)

        if blend:
            for i in range(len(g_out)):
                images_out[i] = ModelInpaint.poissonblending(
                    images_in[i], images_out[i], self.bin_mask
                )
        else:
            images_out = np.multiply(images_out, 1-self.masks_data) \
                         + np.multiply(images_in, self.masks_data)

        return images_out

3. Inpainting the missing region

    def build_inpaint_graph(self):
        """Builds the context and prior loss objective"""
        with self.graph.as_default():
            self.masks = tf.placeholder(tf.float32,
                                        [None] + self.image_shape,
                                        name='mask')
            self.images = tf.placeholder(tf.float32,
                                         [None] + self.image_shape,
                                         name='images')
            self.context_loss = tf.reduce_sum(
                    tf.contrib.layers.flatten(
                        tf.abs(tf.multiply(self.masks, self.go) -
                               tf.multiply(self.masks, self.images))), 1
                )

            self.perceptual_loss = self.gl
            self.inpaint_loss = self.context_loss + self.l*self.perceptual_loss
            self.inpaint_grad = tf.gradients(self.inpaint_loss, self.gi)

    def inpaint(self, image, mask, blend=True):
        """Perform inpainting with the given image and mask with the standard
        pipeline as described in paper. To skip steps or try other pre/post
        processing, the methods can be called separately.

        Arguments:
            image - input 3 channel image
            mask - input binary mask, single channel. Nonzero values are
                   treated as 1
            blend - Flag to apply Poisson blending on output, Default = True

        Returns:
            post processed image (merged/blended), raw generator output
        """
        self.build_inpaint_graph()
        self.preprocess(image, mask)

        imout = self.backprop_to_input()

        return self.postprocess(imout, blend), imout

    def backprop_to_input(self, verbose=True):
        """Main worker function. To be called after all initilization is done.
        Performs backpropagation to input using (accelerated) gradient descent
        to obtain latent space representation of target image

        Returns:
            generator output image
        """
        v = 0
        for i in range(self.config.nIter):
            out_vars = [self.inpaint_loss, self.inpaint_grad, self.go]
            in_dict = {self.masks: self.masks_data,
                       self.gi: self.z,
                       self.images: self.images_data}

            loss, grad, imout = self.sess.run(out_vars, feed_dict=in_dict)

            # Nesterov accelerated gradient: update the velocity v, take the
            # look-ahead step on z, then project z back into [-1, 1]
            v_prev = np.copy(v)
            v = self.config.momentum*v - self.config.lr*grad[0]
            self.z += (-self.config.momentum * v_prev +
                       (1 + self.config.momentum) * v)
            self.z = np.clip(self.z, -1, 1)

            if verbose:
                print('Iteration {}: {}'.format(i, np.mean(loss)))

        return imout

    @staticmethod
    def loadpb(filename, model_name='dcgan'):
        """Loads pretrained graph from ProtoBuf file

        Arguments:
            filename - path to ProtoBuf graph definition
            model_name - prefix to assign to loaded graph node names

        Returns:
            graph, graph_def - as per Tensorflow definitions
        """
        with tf.gfile.GFile(filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def,
                                input_map=None,
                                return_elements=None,
                                op_dict=None,
                                producer_op_list=None,
                                name=model_name)

        return graph, graph_def

    @staticmethod
    def imtransform(img):
        """Helper: Rescale pixel value ranges to -1 and 1"""
        return np.array(img) / 127.5 - 1

    @staticmethod
    def iminvtransform(img):
        """Helper: Rescale pixel value ranges to 0 and 1"""
        return (np.array(img) + 1.0) / 2.0

    @staticmethod
    def poissonblending(img1, img2, mask):
        """Helper: interface to external poisson blending"""
        return blending.blend(img1, img2, 1 - mask)

    @staticmethod
    def createWeightedMask(mask, nsize=7):
        """Takes binary weighted mask to create weighted mask as described in 
        paper.

        Arguments:
            mask - binary mask input. numpy float32 array
            nsize - pixel neighbourhood size. default = 7
        """
        ker = np.ones((nsize,nsize), dtype=np.float32)
        ker = ker/np.sum(ker)
        wmask = mask * convolve2d(mask, ker, mode='same', boundary='symm')
        return wmask
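    # Example: with nsize=3, a known pixel whose 3x3 window contains two
    # missing pixels gets weight 7/9; known pixels far from the hole keep 1.0.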

    @staticmethod
    def binarizeMask(mask, dtype=np.float32):
        """Helper function, ensures mask is 0/1 or 0/255 and single channel
        If dtype specified as float32 (default), output mask will be 0, 1
        if required dtype is uint8, output mask will be 0, 255
        """
        assert(np.dtype(dtype) == np.float32 or np.dtype(dtype) == np.uint8)
        bmask = np.array(mask, dtype=np.float32)
        bmask[bmask>0] = 1.0
        bmask[bmask<=0] = 0
        if dtype == np.uint8:
            bmask = np.array(bmask*255, dtype=np.uint8)
        return bmask
    
    @staticmethod
    def create3ChannelMask(mask):
        """Helper function, repeats single channel mask to 3 channels"""
        assert(len(mask.shape)==2)
        return np.repeat(mask[:,:,np.newaxis], 3, axis=2)
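
A minimal usage sketch of the class (the .pb path, file names, and config values are assumptions; the config fields are exactly the ones the class reads; on newer setups, replace scipy.misc.imread with imageio.imread):

import scipy.misc

class Config:
    lambda_p = 0.003   # weight of the perceptual loss
    nIter = 1500       # gradient-descent iterations on z
    lr = 0.01          # learning rate for backprop_to_input
    momentum = 0.9     # Nesterov momentum

m = ModelInpaint('dcgan-100.pb', Config())              # hypothetical .pb file
image = scipy.misc.imread('face.png')                   # 3-channel input image
mask = scipy.misc.imread('mask.png', mode='L') / 255.0  # 1 = known, 0 = missing
completed, g_out = m.inpaint(image, mask, blend=True)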

4. Utility function: poissonblending

import numpy as np
import scipy.sparse
import PIL.Image
import pyamg

# pre-process the mask array: collapse multi-channel masks (e.g. from cv2.imread)
# to a single-channel 0/1 array
def prepare_mask(mask):
    if type(mask[0][0]) is np.ndarray:
        result = np.ndarray((mask.shape[0], mask.shape[1]), dtype=np.uint8)
        for i in range(mask.shape[0]):
            for j in range(mask.shape[1]):
                if sum(mask[i][j]) > 0:
                    result[i][j] = 1
                else:
                    result[i][j] = 0
        mask = result
    return mask
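
# Note: the double loop above can be vectorized. A sketch with the same 0/1
# uint8 output convention (prepare_mask_vectorized is not part of the original):
def prepare_mask_vectorized(mask):
    if mask.ndim == 3:
        mask = (mask.sum(axis=2) > 0).astype(np.uint8)
    return mask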

def blend(img_target, img_source, img_mask, offset=(0, 0)):
    # compute regions to be blended
    region_source = (
            max(-offset[0], 0),
            max(-offset[1], 0),
            min(img_target.shape[0]-offset[0], img_source.shape[0]),
            min(img_target.shape[1]-offset[1], img_source.shape[1]))
    region_target = (
            max(offset[0], 0),
            max(offset[1], 0),
            min(img_target.shape[0], img_source.shape[0]+offset[0]),
            min(img_target.shape[1], img_source.shape[1]+offset[1]))
    region_size = (region_source[2]-region_source[0], region_source[3]-region_source[1])

    # clip and normalize mask image
    img_mask = img_mask[region_source[0]:region_source[2], region_source[1]:region_source[3]]
    img_mask = prepare_mask(img_mask)
    img_mask[img_mask==0] = False
    img_mask[img_mask!=False] = True

    # create coefficient matrix
    A = scipy.sparse.identity(np.prod(region_size), format='lil')
    for y in range(region_size[0]):
        for x in range(region_size[1]):
            if img_mask[y,x]:
                index = x+y*region_size[1]
                A[index, index] = 4
                if index+1 < np.prod(region_size):
                    A[index, index+1] = -1
                if index-1 >= 0:
                    A[index, index-1] = -1
                if index+region_size[1] < np.prod(region_size):
                    A[index, index+region_size[1]] = -1
                if index-region_size[1] >= 0:
                    A[index, index-region_size[1]] = -1
    A = A.tocsr()
    
    # create poisson matrix for b
    P = pyamg.gallery.poisson(img_mask.shape)

    # for each layer (ex. RGB)
    for num_layer in range(img_target.shape[2]):
        # get subimages
        t = img_target[region_target[0]:region_target[2],region_target[1]:region_target[3],num_layer]
        s = img_source[region_source[0]:region_source[2], region_source[1]:region_source[3],num_layer]
        t = t.flatten()
        s = s.flatten()

        # create b
        b = P * s
        for y in range(region_size[0]):
            for x in range(region_size[1]):
                if not img_mask[y,x]:
                    index = x+y*region_size[1]
                    b[index] = t[index]

        # solve Ax = b
        x = pyamg.solve(A,b,verb=False,tol=1e-10)

        # assign x to target image
        x = np.reshape(x, region_size)
        x[x>255] = 255
        x[x<0] = 0
        x = np.array(x, img_target.dtype)
        img_target[region_target[0]:region_target[2],region_target[1]:region_target[3],num_layer] = x

    return img_target


def test():
    img_mask = np.asarray(PIL.Image.open('./testimages/test1_mask.png'))
    img_mask.flags.writeable = True
    img_source = np.asarray(PIL.Image.open('./testimages/test1_src.png'))
    img_source.flags.writeable = True
    img_target = np.asarray(PIL.Image.open('./testimages/test1_target.png'))
    img_target.flags.writeable = True
    img_ret = blend(img_target, img_source, img_mask, offset=(40,-30))
    img_ret = PIL.Image.fromarray(np.uint8(img_ret))
    img_ret.save('./testimages/test1_ret.png')


if __name__ == '__main__':
    test()

For the full project, see moodoki/semantic_image_inpainting.
