tf2自定义损失函数测试

main.py

import tensorflow as tf
from custom_loss import focal_loss


# Fetch the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide both image arrays by 255.0
# (presumably rescales 8-bit pixel intensities into [0, 1] -- TODO confirm).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Assemble the classifier layer by layer:
# flatten 28x28 inputs -> 128-unit ReLU -> dropout(0.2) -> 10-way softmax.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Train with the custom loss object instead of a built-in loss name.
model.compile(
    optimizer='adam',
    loss=focal_loss(),  # custom loss function
    metrics=['accuracy'],
)

# Five passes over the training split, then score the held-out test split.
model.fit(x_train, y_train, epochs=5)
print()
model.evaluate(x_test, y_test, verbose=2)

custom_loss.py

from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import binary_crossentropy
import tensorflow as tf


class focal_loss(Loss):
    """Focal-style loss for sparse integer labels, built on binary cross-entropy.

    The integer labels are one-hot encoded to the width of ``y_pred``'s last
    axis, ``binary_crossentropy`` is evaluated between the encoded labels and
    the predictions, and the result is re-weighted as
    ``alpha * (1 - exp(-BCE)) ** gamma * BCE``.

    Args:
        alpha: Constant multiplier applied to the loss.
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        **kwargs: Forwarded unchanged to the ``Loss`` base-class constructor.
    """

    def __init__(self, alpha=0.25, gamma=2, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        """Return the focal loss averaged over all entries as a scalar tensor.

        Args:
            y_true: Integer class indices, one label per prediction, e.g.
                shape ``(batch,)`` or ``(batch, 1)``. Already one-hot encoded
                labels are not supported.
            y_pred: Predictions whose last axis indexes the classes.

        Returns:
            A scalar tensor: the mean of the re-weighted cross-entropy values.
        """
        y_pred = tf.convert_to_tensor(y_pred)

        # Convert y_true to the same shape as y_pred.
        # Reshape the labels to y_pred's leading axes instead of using an
        # axis-less squeeze: the squeeze also dropped a batch axis of size 1,
        # leaving the shapes to line up only through broadcasting.
        leading_shape = tf.shape(y_pred)[:-1]
        labels = tf.reshape(tf.cast(y_true, tf.int32), leading_shape)

        # Take the class count from the predictions rather than hard-coding 10,
        # preferring the static size so the encoded labels keep a known width.
        num_classes = y_pred.shape[-1]
        if num_classes is None:
            num_classes = tf.shape(y_pred)[-1]
        y_true = tf.one_hot(labels, depth=num_classes)

        # binary_crossentropy averages over the class axis, so BCE holds one
        # value per prediction and pt = exp(-BCE) is derived from that average.
        BCE = binary_crossentropy(y_true, y_pred)
        pt = tf.math.exp(-BCE)

        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE
        # Reduced to a scalar here, matching the value this class has always
        # returned from call().
        return tf.reduce_mean(F_loss)


class MeanSquaredError(tf.keras.losses.Loss):
    """Loss that averages the squared difference between predictions and targets."""

    def call(self, y_true, y_pred):
        """Return the mean of ``(y_pred - y_true) ** 2`` over all elements."""
        residual = y_pred - y_true
        squared_error = tf.square(residual)
        # No axis is given, so the mean collapses every dimension to a scalar.
        return tf.reduce_mean(squared_error)


Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0011 - accuracy: 0.8727
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 3.9561e-04 - accuracy: 0.9392  
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 2.8197e-04 - accuracy: 0.9538  
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 2.2211e-04 - accuracy: 0.9626  
Epoch 5/5
1875/1875 [==============================] - 3s 2ms/step - loss: 1.8357e-04 - accuracy: 0.9680  

313/313 - 0s - loss: 1.8853e-04 - accuracy: 0.9706 - 414ms/epoch - 1ms/step

你可能感兴趣的:(tensorflow2,tensorflow,深度学习,keras)