\qquad 我们在使用Tensorflow的时候,有时候自带的激活函数和损失函数不够用,我们就要自己定义自己的函数。下面我给出一种方法,我试验可行,当然我也是参考的官方文档和一些博客。基于tf2.0
\qquad 这是我自己做实验的时候用到的一个损失函数,我需要把输出的图片和标签图片计算SSIM的值,然后用1-SSIM的值作为损失函数值。
def tf_ssim_loss(y_true, y_pred):
total_loss = 1 - tf.image.ssim(y_true, y_pred, max_val=1)
return total_loss
\qquad 这是第一步。
\qquad 这是我实验的时候定义的激活函数。具体每个人按自己需要来做。需要说明的是在tf里面,我们使用的一些像指数函数exp()之类的东西,需要时K.的。因为这样才算是对张量tensor进行计算。因为默认网络里面传递的都是tensor.
\qquad 当然需要在前面import.不过这个问题在python可以自动解决。
from tensorflow.keras import backend as K
def custom_activation_1(x):
cond = K.greater(x, 0)
return K.switch(cond, 1-(1-x**2)*K.exp(-x**2/2), x*K.exp(-x*x/2))
\qquad 完成了上面的步骤之后,基本已经弄好了,但是运行的话,tf识别不了我们自定义的激活函数和损失函数。我么需要在模型建立之前加上这句:
get_custom_objects().update({'custom_activation': Activation(custom_activation_1)})
\qquad 这样就可以使用我们的激活函数了。
\qquad 但是,我们在保存完模型之后,加载模型又会出现问题。应该这样加载模型。
# 保存模型
model.save('./saved/my_model.h5')
print("saved total mdoel.")
model = tf.keras.models.load_model('./saved/my_model.h5',
custom_objects={'tf_ssim_loss': tf_ssim_loss,
'custom_activation_1': Activation(custom_activation_1)})