# Keras自定义层(常用代码) — Keras custom layers (commonly used code)


from keras.models import Model
from keras.layers import Input,Conv2D,Reshape,GlobalAvgPool2D
from keras.layers import Lambda
import tensorflow as tf
import keras
from keras import backend
from keras.layers import multiply
class Mutiply(keras.layers.Layer):
    """Multiply ``target`` element-wise by a broadcast ``source`` tensor.

    The layer is called on a pair ``[source, target]`` and returns
    ``source * target``. ``source`` must be broadcast-compatible with
    ``target`` (e.g. a 1-channel map or a ``(batch, 1, 1, C)`` weight
    tensor); the output has ``target``'s shape.

    The (misspelled) class name is kept as-is so existing code and saved
    models that refer to ``Mutiply`` keep working.
    """

    def __init__(self, **kwargs):
        super(Mutiply, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """Return ``source * target`` for ``inputs == [source, target]``."""
        source, target = inputs
        # Rely on tf.multiply's broadcasting rather than tiling `source` by
        # the target's channel count. Tiling only produced the right shape
        # for a 1-channel source (a C-channel source became C * C_target
        # channels and the multiply failed) and hard-coded channels-last
        # axis 3. Wherever tiling worked, broadcasting gives the same result.
        return tf.multiply(source, target)

    def compute_output_shape(self, input_shape):
        # The output has exactly the target's shape.
        return tuple(input_shape[1])
class UpsampleLike(keras.layers.Layer):
    """Resize a tensor to the spatial size of another tensor.

    Called on a pair ``[source, target]``; resizes ``source`` so that its
    height and width match ``target``'s while keeping ``source``'s channels.

    Args:
        method: Interpolation method, ``'nearest'`` or ``'bilinear'``. The
            default ``None`` keeps the layer's historical behaviour, which
            differs by data format: nearest-neighbour for ``channels_first``
            and bilinear for ``channels_last``. Pass a value explicitly to
            get the same interpolation in both formats.
        **kwargs: Standard Keras ``Layer`` keyword arguments.

    Raises:
        ValueError: If ``method`` is not ``None``, ``'nearest'`` or
            ``'bilinear'``.
    """

    _METHODS = ('nearest', 'bilinear')

    def __init__(self, method=None, **kwargs):
        if method is not None and method not in self._METHODS:
            raise ValueError(
                "method must be None, 'nearest' or 'bilinear', got %r" % (method,))
        super(UpsampleLike, self).__init__(**kwargs)
        self.method = method

    def _resize(self, images, size, default_method):
        # Resize NHWC `images` to `size` using the configured method, or
        # `default_method` when no method was configured (legacy behaviour).
        method = self.method or default_method
        if method == 'nearest':
            return tf.image.resize_nearest_neighbor(images, size)
        return tf.image.resize_bilinear(images, size)

    def call(self, inputs, **kwargs):
        source, target = inputs
        target_shape = keras.backend.shape(target)
        if keras.backend.image_data_format() == 'channels_first':
            # tf.image resize ops work on NHWC, so move channels last,
            # resize, then move them back to NCHW.
            source = backend.transpose(source, (0, 2, 3, 1))
            output = self._resize(
                source, (target_shape[2], target_shape[3]), 'nearest')
            return backend.transpose(output, (0, 3, 1, 2))
        return self._resize(
            source, (target_shape[1], target_shape[2]), 'bilinear')

    def compute_output_shape(self, input_shape):
        # Batch and channels come from `source`; spatial dims from `target`.
        if keras.backend.image_data_format() == 'channels_first':
            return (input_shape[0][0], input_shape[0][1]) + input_shape[1][2:4]
        else:
            return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],)

    def get_config(self):
        config = super(UpsampleLike, self).get_config()
        config['method'] = self.method
        return config
def Interp(x, shape):
    """Resize ``x`` to ``shape`` with ``tf.image.resize_images``.

    Intended to be wrapped in a ``Lambda`` layer, e.g.
    ``Lambda(Interp, arguments={'shape': (24, 24)})``.

    Args:
        x: Image tensor to resize (``tf.image.resize_images`` expects NHWC).
        shape: ``(new_height, new_width)``; each value is coerced with
            ``int``.

    Returns:
        The resized tensor. It uses ``resize_images``' default method
        (bilinear) with ``align_corners=True``.
    """
    # Import inside the function so it stays self-contained when Keras
    # serialises/deserialises it as a Lambda layer. Import TensorFlow
    # directly: `from keras.backend import tf` relied on an internal
    # re-export that newer Keras releases no longer provide (ImportError).
    import tensorflow as tf
    new_height, new_width = shape
    return tf.image.resize_images(
        x,
        [int(new_height), int(new_width)],
        align_corners=True)
if __name__ == '__main__':
    # Smoke-test graph: a 12x12x3 input is rescaled, resized to 24x24, then
    # re-weighted per channel by its own global average.
    inputs = Input(shape=(12, 12, 3))

    # Map pixel values from [0, 255] to [-1, 1].
    rescale = Lambda(lambda t: t / 127.5 - 1.,
                     output_shape=(12, 12, 3),
                     name='lambda1')
    scaled = rescale(inputs)

    resize = Lambda(Interp, arguments={'shape': (24, 24)})
    resized = resize(scaled)

    # Per-channel averages reshaped to (1, 1, 3) so they broadcast over H, W.
    pooled = GlobalAvgPool2D()(resized)
    channel_weights = Reshape(target_shape=(1, 1, 3))(pooled)

    outputs = multiply([channel_weights, resized])
    demo_model = Model(inputs, outputs)
    demo_model.summary()

 

# 你可能感兴趣的 (You may also be interested in): keras