from keras.models import Model
from keras.layers import Input,Conv2D,Reshape,GlobalAvgPool2D
from keras.layers import Lambda
import tensorflow as tf
import keras
from keras import backend
from keras.layers import multiply
class Mutiply(keras.layers.Layer):
    """Keras layer that multiplies ``source`` element-wise into ``target``.

    Inputs are a pair ``[source, target]``. ``source`` is broadcast against
    ``target`` by ``tf.multiply``. For example, a single-channel map of
    shape ``(batch, H, W, 1)`` or ``(batch, 1, 1, 1)`` scales every channel
    of a ``(batch, H, W, C)`` target. A ``source`` with exactly the same
    shape as ``target`` is multiplied element-wise.
    """

    def __init__(self, **kwargs):
        super(Mutiply, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        """Return ``source * target`` using TensorFlow broadcasting.

        Args:
            inputs: a ``[source, target]`` pair of tensors whose shapes
                must be broadcast-compatible.

        Returns:
            The element-wise product.
        """
        source, target = inputs
        # Broadcasting replaces the former tf.tile over the channel axis.
        # Tiling by the target's channel count produced C_src * C_tgt
        # channels, which only lined up with ``target`` when C_src == 1.
        # In that case broadcasting gives the identical result anyway.
        return tf.multiply(source, target)

    def compute_output_shape(self, input_shape):
        # The product takes the target's (4-D) shape; this assumes source
        # broadcasts to target rather than the other way round.
        return (input_shape[1][0],) + input_shape[1][1:3] + (input_shape[1][-1],)
class UpsampleLike(keras.layers.Layer):
    """Keras layer that resizes ``source`` to the spatial size of ``target``.

    Inputs are a pair ``[source, target]``; only the runtime shape of
    ``target`` is used. Channel count and batch size come from ``source``.

    NOTE(review): the ``channels_last`` path resizes bilinearly while the
    ``channels_first`` path uses nearest-neighbour, so the two data formats
    do not produce equivalent results — confirm this difference is intended.
    """

    def call(self, inputs, **kwargs):
        src, ref = inputs
        ref_dims = keras.backend.shape(ref)

        if keras.backend.image_data_format() != 'channels_first':
            # channels_last: spatial axes are 1 (height) and 2 (width).
            return tf.image.resize_bilinear(src, (ref_dims[1], ref_dims[2]))

        # channels_first: the tf.image resize ops expect NHWC, so move the
        # channel axis last, resize to (H, W) of the reference, and move it back.
        nhwc = backend.transpose(src, (0, 2, 3, 1))
        resized = tf.image.resize_nearest_neighbor(nhwc, (ref_dims[2], ref_dims[3]))
        return backend.transpose(resized, (0, 3, 1, 2))

    def compute_output_shape(self, input_shape):
        src_shape = input_shape[0]
        ref_shape = input_shape[1]
        if keras.backend.image_data_format() != 'channels_first':
            # (batch, H_ref, W_ref, C_src)
            return (src_shape[0],) + ref_shape[1:3] + (src_shape[-1],)
        # (batch, C_src, H_ref, W_ref)
        return (src_shape[0], src_shape[1]) + ref_shape[2:4]
def Interp(x, shape):
    """Resize an image tensor to a fixed spatial size.

    Meant to be wrapped in a ``keras.layers.Lambda`` with
    ``arguments={'shape': (height, width)}``.

    Args:
        x: image batch passed to ``tf.image.resize_images``; presumably
            channels-last ``(batch, H, W, C)`` — TODO confirm for all callers.
        shape: ``(new_height, new_width)``; each value is cast with ``int``.

    Returns:
        The resized tensor, computed with ``align_corners=True`` and the
        default resize method of ``tf.image.resize_images``.
    """
    # Import locally, presumably so the function stays self-contained if a
    # Lambda using it is serialized and rebuilt from its code. Import
    # TensorFlow directly instead of the private `keras.backend.tf` alias,
    # which newer Keras releases no longer expose.
    import tensorflow as tf

    new_height, new_width = shape
    return tf.image.resize_images(
        x,
        [int(new_height), int(new_width)],
        align_corners=True)
if __name__ == '__main__':
    # Smoke test: build a small graph and print its layer summary.
    height, width, channels = 12, 12, 3
    upsampled_size = (24, 24)

    inputs = Input(shape=(height, width, channels))

    # Scale input values: z / 127.5 - 1 maps [0, 255] onto [-1, 1].
    scaled = Lambda(lambda t: t / 127.5 - 1.,
                    output_shape=(height, width, channels),
                    name='lambda1')(inputs)

    upsampled = Lambda(Interp, arguments={'shape': upsampled_size})(scaled)

    # Per-channel mean, reshaped so it broadcasts over the spatial axes.
    pooled = GlobalAvgPool2D()(upsampled)
    channel_weights = Reshape(target_shape=(1, 1, channels))(pooled)

    weighted = multiply([channel_weights, upsampled])

    model = Model(inputs, weighted)
    model.summary()