tf.extract_image_patches以及pytorch的extract_patches

前言

由于最近一直在研究上下文注意力机制,对特征图(feature map,FM)提取patch时需要用到函数extract_image_patches,特此记录一下。温馨提示:动笔计算一下,更容易理解。

1、 tensorflow的extract_image_patches函数提取patch

参考1 https://blog.csdn.net/qq_38675570/article/details/81239774
参考2 http://www.voidcn.com/article/p-svnrphay-btg.html

def extract_image_patches(  
    images,      # input feature map, laid out as [b, h, w, c]
    ksizes=None, # patch size, given as [1, k_size, k_size, 1]; 1, 3 and 5 are the usual choices for images
    strides=None,# step between consecutive patches in pixels, given as [1, stride, stride, 1]
    rates=None,  # see dilated convolution: controls the gap between sampled pixels inside a patch
    padding=None,# 'VALID': every patch must lie fully inside the image; 'SAME': incomplete patches are allowed
    name=None,
    sizes=None):

这里以特征图bg=fg_in=[8,32,32,192]为例:

# First extract patches from the background region ('VALID': no padding)
patch1 = tf.extract_image_patches(bg, [1, k_size, k_size, 1], [1, stride, stride, 1], [1, 1, 1, 1], 'VALID')
print("patch1 size",patch1.shape) # [8,30,30,1728] for the [8,32,32,192] example with k_size=3, stride=1

# Then extract patches from the foreground region; 'SAME' adds padding
patch2 = tf.extract_image_patches(fg_in, [1,k_size,k_size,1], [1,1,1,1], [1,1,1,1], 'SAME')
print("patch2 size", patch2.shape) # [8,32,32,1728] for the same example

输出维度分析:
h=w=32,stride=1,k_size=3。padding='VALID'(不补零,相当于padding=0)时,每个方向可取 (32-3)/1+1=30 个位置,计算出的块个数为c=30*30=900;padding='SAME'(四周各补1圈0,相当于padding=1)时,块的个数为32*32=1024。
提取出的每个patch的维度为1728 = 192 * 3 * 3,也就是输出的最后一维是c_in * k_size * k_size。在对patch计算余弦相似度时,处理的就是这一部分。

# Example values taken from the worked example in this article:
# a 32x32 feature map, 3x3 patches, stride 1.
h, w, k_size, stride = 32, 32, 3, 1

kn = (k_size - 1) // 2  # half-width of the patch window
c = 0  # number of patches
# Walk over every patch centre that keeps the whole window inside the map.
for p in range(kn, h - kn, stride):
    for q in range(kn, w - kn, stride):
        c += 1

2、pytorch的fold和unfold

参考1 https://blog.csdn.net/LoseInVain/article/details/88139435
上面这条链接将网络层torch.nn.Unfold讲得很透彻,看明白之后大概就理解了。注意下面的代码用的是张量方法Tensor.unfold(dim, size, step),它沿指定维度做滑窗,和nn.Unfold不是同一个接口。

# extract patches
def extract_patches(x, kernel=3, stride=1):
    """Extract sliding ``kernel x kernel`` patches from a batch of feature maps.

    The input is zero-padded by ``kernel // 2`` on every side before the
    windows are taken, so with ``stride=1`` and an odd ``kernel`` the output
    keeps the spatial size of the input. For ``kernel`` of 1, 2 or 3 this is
    the same padding the function always used (0, 1 and 1 respectively);
    larger kernels previously received a padding of only 1.

    Args:
        x: tensor indexed as [batch, channels, height, width].
        kernel: side length of the square patch.
        stride: step between neighbouring patches along height and width.

    Returns:
        Tensor of shape [batch, out_h, out_w, channels, kernel, kernel].
    """
    pad = kernel // 2
    if pad > 0:
        x = nn.ZeroPad2d(pad)(x)
    # Move channels last so that dims 1 and 2 are the spatial axes to window.
    x = x.permute(0, 2, 3, 1)
    # Each unfold replaces one spatial axis by the window positions and
    # appends the window contents as a new trailing axis.
    all_patches = x.unfold(1, kernel, stride).unfold(2, kernel, stride)
    return all_patches

调用函数,输入x2的维度为[8,192,32,32],ksize=3,stride=1,提取patch后输出的维度为[8,32,32,192,3,3],和tensorflow的用法差不多。tensorflow是把每个patch按(k_size, k_size, c)的顺序展平到最后一维,所以这里先permute(0,1,2,4,5,3),再reshape成[8,32,32,1728],就可以达到一样的效果。

# Extract patches from x2 with the kernel size and stride stored on self.
w = extract_patches(x2, kernel=self.ksize, stride=self.stride)
print("w",w.size()) # [8,32,32,192,3,3] for x2 = [8,192,32,32], ksize=3, stride=1

下面给出2019 CVPR论文 PEPSI: Fast Image Inpainting With Parallel Decoding Network 中的注意力模块,做图像修复的朋友可以瞅一瞅;我在原代码的基础上做了一些修改,并添加了一些代码。
论文:地址
翻译:https://blog.csdn.net/baidu_33256174/article/details/102570361
code:https://github.com/Forty-lock/PEPSI

from __future__ import division
# from ops import *
import tensorflow as tf
from tensorflow.keras.layers import Conv2D,ELU


def softmax(input):
    """Softmax along axis 3 of ``input``.

    The values are shifted by their maximum along axis 3 before the
    exponential. Softmax is invariant to such a shift, so the result is
    mathematically unchanged, but the largest exponent becomes exp(0) = 1
    and cannot overflow regardless of the magnitude of ``input``.

    Args:
        input: tensor with at least 4 dimensions; normalisation runs over axis 3.

    Returns:
        Tensor of the same shape as ``input`` whose entries along axis 3 sum to 1.
    """
    # stop_gradient: the shift is only a numerical device and must not
    # contribute its own gradient path.
    shift = tf.stop_gradient(tf.reduce_max(input, axis=3, keepdims=True))
    exps = tf.exp(input - shift)
    return exps / tf.reduce_sum(exps, axis=3, keepdims=True)

def reduce_var(x, axis=None, keepdims=False):
    """Return the variance of `x` along `axis`.

    The variance is the mean of the squared deviations from the mean
    (i.e. it divides by the number of elements, not by N - 1).

    # Arguments
        x: A tensor or variable.
        axis: An integer, the axis to reduce over.
        keepdims: A boolean. When `False` the reduced axis is dropped from
            the result; when `True` it is kept with length 1.

    # Returns
        A tensor holding the variance of the elements of `x`.
    """
    # The mean always keeps its dims so it broadcasts against `x`.
    centered = x - tf.reduce_mean(x, axis=axis, keepdims=True)
    return tf.reduce_mean(tf.square(centered), axis=axis, keepdims=keepdims)

def reduce_std(x, axis=None, keepdims=False):
    """Return the standard deviation of `x` along `axis`.

    Computed as the square root of `reduce_var` with the same arguments.

    # Arguments
        x: A tensor or variable.
        axis: An integer, the axis to reduce over.
        keepdims: A boolean. When `False` the reduced axis is dropped from
            the result; when `True` it is kept with length 1.

    # Returns
        A tensor holding the standard deviation of the elements of `x`.
    """
    variance = reduce_var(x, axis=axis, keepdims=keepdims)
    return tf.sqrt(variance)

def contextual_block(bg_in, fg_in, mask, k_size, lamda, name, stride=1):
    """Contextual attention: rebuild foreground features from background patches.

    Every k_size x k_size patch of ``fg_in`` is compared (squared distance)
    with every patch taken from the masked background. The distances are
    normalised per position, squashed with tanh, turned into attention
    scores by a softmax, and the background patches are pasted back with a
    transposed convolution weighted by those scores.

    Args:
        bg_in: background feature map [b, h, w, dims]; its static shape must
            be fully defined because b, h, w and dims are used as Python ints.
        fg_in: foreground feature map, same shape as ``bg_in``.
        mask: 4-D mask tensor; it is resized to (h, w) and only channel 0 is
            used. A value of 0 marks a missing position.
        k_size: side length of the square patches. The patch-count formula
            matches the 'VALID' extraction for odd values.
        lamda: scale applied to the similarity before the softmax.
        name: variable scope under which the block is built.
        stride: step between neighbouring background patches.

    Returns:
        Tensor [b, h, w, dims]: 1x1 convolution + ELU over the concatenation
        of ``bg_in`` and the reconstructed features.
    """
    with tf.variable_scope(name):
        # as_list() returns plain ints, so it does not depend on the
        # Dimension.value attribute.
        b, h, w, dims = bg_in.get_shape().as_list()
        temp = tf.image.resize_nearest_neighbor(mask, (h, w))  # resize mask to (h, w)
        temp = tf.expand_dims(temp[:, :, :, 0], 3)  # [b, h, w, 1]
        mask_r = tf.tile(temp, [1, 1, 1, dims])  # [b, h, w, dims]: repeated dims times
        bg = bg_in * mask_r

        kn = (k_size - 1) // 2
        # Number of background patches: valid centre positions along each
        # axis, multiplied together (len(range(...)) is O(1)).
        c = len(range(kn, h - kn, stride)) * len(range(kn, w - kn, stride))

        # Shape comments below refer to the example bg_in = fg_in =
        # [8, 32, 32, 192], k_size = 3, stride = 1.
        # First extract patches from the background region.
        patch1 = tf.extract_image_patches(bg, [1, k_size, k_size, 1], [1, stride, stride, 1], [1, 1, 1, 1], 'VALID')
        print("patch1 size",patch1.shape) # [8,30,30,1728]
        # Flatten the patch grid, then put the patch index last so each
        # image's slice can serve as a 1x1 convolution kernel.
        patch1 = tf.reshape(patch1, (b, 1, 1, c, k_size * k_size * dims))
        patch1 = tf.transpose(patch1, [0, 1, 2, 4, 3])
        print("patch1 size", patch1.shape) # [8,1,1,1728,900]
        # Then extract patches from the foreground region, with padding.
        patch2 = tf.extract_image_patches(fg_in, [1,k_size,k_size,1], [1,1,1,1], [1,1,1,1], 'SAME')
        print("patch2 size", patch2.shape) # [8,32,32,1728]

        acl_per_image = []
        for ib in range(b):  # process one image at a time

            k1 = patch1[ib, :, :, :, :]  # background patches of this image
            print("k1",k1.shape) # [1,1,1728,900]
            k1d = tf.reduce_sum(tf.square(k1), axis=2)  # squared norm of each background patch
            print("k1d",k1d.shape)  # [1,1,900]
            k2 = tf.reshape(k1, (k_size, k_size, dims, c))
            print("k2",k2.shape) # (3, 3, 192, 900)

            ww = patch2[ib, :, :, :]
            print("ww", ww.shape) # [32,32,1728]
            wwd = tf.reduce_sum(tf.square(ww), axis=2, keepdims=True)  # squared norm of each foreground patch
            print("wwd",wwd.shape) # [32,32,1]
            ft = tf.expand_dims(ww, 0) # [1,32,32,1728]

            # 1x1 convolution = inner product of every foreground patch
            # with every background patch.
            CS = tf.nn.conv2d(ft, k1, strides=[1, 1, 1, 1], padding='SAME')
            print("CS",CS.shape) # [1,32,32,900]

            tt = k1d + wwd
            print("tt",tt.shape) # [32,32,900]
            # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b; expand_dims adds the batch axis.
            DS1 = tf.expand_dims(tt, 0) - 2 * CS
            # Normalise the distances per position. The small epsilon keeps
            # the division finite when all distances are equal (std == 0),
            # e.g. for a constant feature map.
            DS2 = (DS1 - tf.reduce_mean(DS1, 3, True)) / (reduce_std(DS1, 3, True) + 1e-8)
            print('DS2',DS2.shape) # [1,32,32,900]
            DS2 = -1 * tf.nn.tanh(DS2)

            CA = softmax(lamda * DS2)  # similarity scores
            print("CA",CA.shape) # [1,32,32,900]
            # k2 holds the background patches and acts as the kernel.
            ACLt = tf.nn.conv2d_transpose(CA, k2, output_shape=[1, h, w, dims], strides=[1, 1, 1, 1], padding='SAME')
            print("ACLt",ACLt.shape) # (1, 32, 32, 192)
            ACLt = ACLt / (k_size ** 2)

            acl_per_image.append(ACLt)

        # Stack the per-image results back into a batch with a single concat.
        ACL = tf.concat(acl_per_image, 0)

        ACL = bg + ACL * (1.0 - mask_r)  # mask value 0 marks a missing position

        con1 = tf.concat([bg_in, ACL], 3)  # merge background and reconstructed foreground
        ACL2 = Conv2D(dims, [1, 1], strides=[1, 1], padding='VALID')(con1)
        ACL2 = tf.nn.elu(ACL2)

        return ACL2

def ca_test(args):
    """Smoke-test ``contextual_block`` on an all-zero feature map and a mask image.

    Builds the block for a fixed [8, 32, 32, 192] input, loads the mask
    image, and runs one forward pass in a TF session. The result is not
    returned or saved.

    Args:
        args: object with a ``mask`` attribute holding the path of the mask
            image (the other command-line options are not read here).
    """
    from PIL import Image
    import os
    import numpy as np
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    def default_loader(path):
        # convert() forces the pixel data to be read while the file is open.
        with open(path, 'rb') as f:
            image = Image.open(f).convert('RGB')
        return image

    # Fixed input geometry used for the placeholders and the test data.
    batch_size, Height, Width = 8, 32, 32

    b = np.zeros([batch_size, Height, Width, 192], dtype=np.float32)
    print(type(b))

    mask = default_loader(args.mask)
    print("mask",mask.size)
    # Resize straight to the placeholder resolution so a mask of any size
    # can be fed. Image.LANCZOS is the current name of the filter formerly
    # exposed as Image.ANTIALIAS.
    mask = mask.resize((Width, Height), Image.LANCZOS)
    mask = mask.convert("L")
    print("Size of mask:{}".format(mask.size))

    bg_var = tf.placeholder(tf.float32, [batch_size, Height, Width, 192])
    fg_var = tf.placeholder(tf.float32, [batch_size, Height, Width, 192])
    mask_var = tf.placeholder(tf.float32, [batch_size, Height, Width, 1])

    y = contextual_block(bg_in=bg_var, fg_in=fg_var, mask=mask_var, k_size=3, lamda=80, name='CA')

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9  # cap GPU memory use at 90%
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand

    # Scale the 8-bit grey levels to [0, 1]: contextual_block multiplies by
    # the mask and by (1 - mask), so raw 0..255 values would be wrong.
    mask = np.asarray(mask, dtype=np.float32) / 255.0
    # [H, W] -> [batch, H, W, 1], done in NumPy so no extra ops are added
    # to the graph.
    mask = np.tile(mask[np.newaxis, :, :, np.newaxis], [batch_size, 1, 1, 1])
    print(type(mask))
    print("b:",b.shape,"f:",b.shape,"mask:",mask.shape)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # The same zero array is fed as both background and foreground.
        y = sess.run([y],feed_dict={bg_var:b,fg_var:b,mask_var:mask})


if __name__ == '__main__':
    # Command-line entry point: parse the image paths and run the smoke test.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--imageA', default='../image/000147.jpg.jpg', type=str,
                        help='Image A as background patches to reconstruct image B.')
    parser.add_argument('--imageB', default='../image/000201.jpg.jpg', type=str,
                        help='Image B is reconstructed with image A.')
    parser.add_argument('--mask', default='../image/mask_0001.jpg', type=str,
                        help='Path of the mask image.')
    # The help text used to be a copy of the --imageB description.
    parser.add_argument('--imageOut', default='../image/result.png', type=str,
                        help='Output path for the reconstructed image.')
    args = parser.parse_args()
    ca_test(args)

你可能感兴趣的:(图像修复,tensorflow,pytorch)