PixelShuffler实现上采样的原理

一,普通的上采样采用邻近像素填充算法,主要考虑空间因素,没有考虑channel因素,上采样得到的特征图人为修改痕迹明显,在图像分割与GAN图像生成中效果不好。

二,PixelShuffler结合channel维度的信息来填充像素,可以实现超高分辨率,生成的图像更逼真,更完美。

三,PixelShuffler实际原理
PixelShuffler实现上采样的原理_第1张图片
pixelshuffle算法的实现流程如上图,其实现的功能是:将一个 H × W 的低分辨率输入图像(Low Resolution),通过Sub-pixel操作变为 rH × rW 的高分辨率图像(High Resolution)。
但是其实现过程不是直接通过插值等方式产生这个高分辨率图像,而是先通过卷积得到 $r^2$ 个通道的特征图(特征图大小和输入低分辨率图像一致),然后通过周期筛选(periodic shuffling)的方法得到这个高分辨率的图像,其中 $r$ 为上采样因子(upscaling factor),也就是图像的扩大倍率。

四,keras实现代码:

# PixelShuffler layer for Keras
# by t-ae
# https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981
from keras.utils import conv_utils
from keras.engine.topology import Layer
import keras.backend as K
class PixelShuffler(Layer):
    """Keras layer that rearranges channel blocks into spatial positions.

    For ``size=(rh, rw)``, a rank-4 input with ``c`` channels and spatial
    size ``h x w`` is reshaped and permuted into an output with
    ``c // (rh * rw)`` channels and spatial size ``(h * rh) x (w * rw)``.

    Args:
        size: Upscaling factors ``(rh, rw)``; normalized with
            ``conv_utils.normalize_tuple(size, 2, 'size')``.
        data_format: ``'channels_first'`` or ``'channels_last'``. ``None``
            falls back to ``K.image_data_format()``.
        **kwargs: Forwarded unchanged to ``Layer.__init__``.

    Raises:
        ValueError: If ``data_format`` is neither ``None`` nor one of the
            two supported strings.
    """

    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(PixelShuffler, self).__init__(**kwargs)
        # Honor an explicitly requested data_format. It is serialized by
        # get_config(), so ignoring it here would silently drop the value
        # when the layer is rebuilt from its config.
        if data_format is None:
            data_format = K.image_data_format()
        data_format = str(data_format).lower()
        if data_format not in ('channels_first', 'channels_last'):
            raise ValueError('The `data_format` argument must be one of '
                             '"channels_first", "channels_last". '
                             'Received: ' + data_format)
        self.data_format = data_format
        self.size = conv_utils.normalize_tuple(size, 2, 'size')

    def call(self, inputs):
        """Apply the pixel shuffle to a rank-4 tensor.

        Raises:
            ValueError: If the static shape of ``inputs`` is not rank 4.
        """
        input_shape = K.int_shape(inputs)
        if len(input_shape) != 4:
            # Build one message string; passing the shape as a second
            # positional argument would render the error as a tuple.
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape: ' +
                             str(input_shape))

        rh, rw = self.size

        if self.data_format == 'channels_first':
            batch_size, c, h, w = input_shape
            if batch_size is None:
                # Let the backend reshape infer the batch dimension.
                batch_size = -1
            oh, ow = h * rh, w * rw
            oc = c // (rh * rw)

            # (b, rh, rw, oc, h, w) -> (b, oc, h, rh, w, rw) -> (b, oc, oh, ow)
            out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w))
            out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
            out = K.reshape(out, (batch_size, oc, oh, ow))
            return out

        # 'channels_last' (the only other value __init__ accepts).
        batch_size, h, w, c = input_shape
        if batch_size is None:
            batch_size = -1
        oh, ow = h * rh, w * rw
        oc = c // (rh * rw)

        # (b, h, w, rh, rw, oc) -> (b, h, rh, w, rw, oc) -> (b, oh, ow, oc)
        out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc))
        out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
        out = K.reshape(out, (batch_size, oh, ow, oc))
        return out

    def compute_output_shape(self, input_shape):
        """Return the output shape tuple for a rank-4 ``input_shape``.

        Unknown (``None``) height or width stay ``None`` in the result.

        Raises:
            ValueError: If ``input_shape`` is not rank 4, or the channel
                count is not divisible by ``size[0] * size[1]``.
        """
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank ' +
                             str(4) +
                             '; Received input shape: ' +
                             str(input_shape))

        if self.data_format == 'channels_first':
            channel_axis, height_axis, width_axis = 1, 2, 3
        else:
            channel_axis, height_axis, width_axis = 3, 1, 2

        in_height = input_shape[height_axis]
        in_width = input_shape[width_axis]
        in_channels = input_shape[channel_axis]

        height = in_height * self.size[0] if in_height is not None else None
        width = in_width * self.size[1] if in_width is not None else None
        channels = in_channels // self.size[0] // self.size[1]

        # The floor divisions above must be exact for the shuffle to work.
        if channels * self.size[0] * self.size[1] != in_channels:
            raise ValueError('channels of input and size are incompatible')

        if self.data_format == 'channels_first':
            return (input_shape[0], channels, height, width)
        return (input_shape[0], height, width, channels)

    def get_config(self):
        """Return the base layer config extended with size and data_format."""
        config = {'size': self.size,
                  'data_format': self.data_format}
        base_config = super(PixelShuffler, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
from .pixel_shuffler import PixelShuffler
def upscale_ps(input_tensor, f, use_norm=False, w_l2=w_l2, norm='none'):
    """Build a conv -> LeakyReLU -> (optional norm) -> PixelShuffler stack.

    Args:
        input_tensor: Tensor fed to the 3x3 ``Conv2D`` layer.
        f: Base filter count; the convolution is created with ``f * 4``
            filters.
        use_norm: When true, the activation is passed through
            ``normalization(x, norm, f)`` before the shuffle.
        w_l2: Factor handed to ``regularizers.l2`` for the conv kernel.
        norm: Forwarded as the second argument of ``normalization``.

    Returns:
        The result of applying ``PixelShuffler()`` to the processed tensor.
    """
    # NOTE(review): f * 4 presumably pairs with a default 2x2 shuffle so the
    # output keeps f channels - confirm against the imported PixelShuffler.
    conv_layer = Conv2D(
        f * 4,
        kernel_size=3,
        kernel_regularizer=regularizers.l2(w_l2),
        kernel_initializer=icnr_keras,
        padding='same',
    )
    features = conv_layer(input_tensor)
    features = LeakyReLU(0.2)(features)
    if use_norm:
        features = normalization(features, norm, f)
    return PixelShuffler()(features)

你可能感兴趣的:(AI,keras,图像的处理)