基于 Keras 的 EfficientNet 复现（二）：训练模块

B3 大概适合在一般的服务器上跑（本脚本实际构建的是 B2，输入分辨率为 260）；batch size 不要设置得太大，否则很容易超出显存。
这里使用的是前面 AlexNet 一文中的数据集（cloudy、haze、rainy、snow、sunny、thunder 六个天气类别），数据集地址改天再补充。

from __future__ import absolute_import, division, print_function
import tensorflow as tf
import math
import EfficientNet_Model
import numpy as np
from Data_Channel import Data_Channel
import random

BatchSize = 50          # training images per step; keep small to avoid running out of GPU memory
Resolution = 260        # image resolution handed to every Data_Channel
EPOCHS = 3000
Save_n_Epoch = 50       # write a weights checkpoint every this many epochs
ALL_images = 10000      # nominal images per epoch; steps per epoch derive from ALL_images / BatchSize
Save_DIR = "./ModelLog/"
Valid_Size = 20         # labels per validation batch; assumed to match Data_Channel's Valid_pool size — TODO confirm

# Category names, in the order their Data_Channels are created.
DC_list = ["cloudy", "haze", "rainy", "snow", "sunny", "thunder"]

# Integer class index of each category. Labels are sparse class indices,
# not one-hot vectors (despite the "_OH" suffix on the tables below).
Class_Index = {"cloudy": 1, "haze": 2, "rainy": 3, "snow": 4,
               "sunny": 5, "thunder": 0}

# Every batch is drawn from a single category, so each label vector is just
# that category's index repeated. A fixed integer dtype keeps all categories
# consistent (np.ones/np.zeros would otherwise yield float64 labels).
Labels_OH = {name: np.full((BatchSize,), idx, dtype=np.int64)
             for name, idx in Class_Index.items()}
Valid_OH = {name: np.full((Valid_Size,), idx, dtype=np.int64)
            for name, idx in Class_Index.items()}

# One data source per category.
DC_Dic = {name: Data_Channel(category=name, pool_size=BatchSize, resolution=Resolution)
          for name in DC_list}


# A disabled `process_features(features, data_augmentation)` helper used to sit
# here as a string literal. It was removed because it referenced an undefined
# `load_and_preprocess_image`. Training data now comes from the Data_Channel
# objects in DC_Dic instead.


if __name__ == '__main__':
    # GPU settings: enable memory growth so TensorFlow allocates GPU memory
    # as needed instead of reserving all of it up front.
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)

    # Create the model.
    model = EfficientNet_Model.efficient_net_b2()

    # Loss and optimizer. Labels must be integer class indices, not one-hot.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.RMSprop()

    # Metrics accumulate over one epoch and are reset after the epoch summary.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    # Both step functions run eagerly (not wrapped in @tf.function).
    def train_step(image_batch, label_batch):
        """Apply one gradient update on a batch and update the training metrics.

        Args:
            image_batch: batch of input images fed to ``model``.
            label_batch: integer class index for each image in the batch.
        """
        with tf.GradientTape() as tape:
            predictions = model(image_batch, training=True)
            loss = loss_object(y_true=label_batch, y_pred=predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))

        train_loss.update_state(values=loss)
        train_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    def valid_step(image_batch, label_batch):
        """Evaluate a batch in inference mode and update the validation metrics.

        Args:
            image_batch: batch of input images fed to ``model``.
            label_batch: integer class index for each image in the batch.
        """
        predictions = model(image_batch, training=False)
        v_loss = loss_object(label_batch, predictions)

        valid_loss.update_state(values=v_loss)
        valid_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # ceil so ALL_images is covered even when it is not a multiple of
    # BatchSize (identical to round() for the default 10000 / 50 = 200).
    steps_per_epoch = math.ceil(ALL_images / BatchSize)

    # Start training.
    for epoch in range(EPOCHS):
        for step in range(1, steps_per_epoch + 1):
            # Each step trains on one fresh batch from a randomly chosen category.
            category = random.choice(DC_list)
            channel = DC_Dic[category]
            channel.Renew_dataset()
            train_step(channel.RF_pool, Labels_OH[category])
            print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(
                epoch, EPOCHS, step, steps_per_epoch,
                train_loss.result().numpy(),
                train_accuracy.result().numpy()))

        # Validate once per epoch on a fresh validation batch from every
        # category, so the validation metrics cover all classes.
        # Assumes Valid_pool holds len(Valid_OH[category]) images — TODO confirm.
        for category in DC_list:
            channel = DC_Dic[category]
            channel.Renew_Valid_ds()
            valid_step(channel.Valid_pool, Valid_OH[category])

        print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, "
              "valid loss: {:.5f}, valid accuracy: {:.5f}".format(
                  epoch, EPOCHS,
                  train_loss.result().numpy(),
                  train_accuracy.result().numpy(),
                  valid_loss.result().numpy(),
                  valid_accuracy.result().numpy()))

        # Reset only after the summary so it reports whole-epoch values.
        for metric in (train_loss, train_accuracy, valid_loss, valid_accuracy):
            metric.reset_states()

        if epoch % Save_n_Epoch == 0:
            model.save_weights(filepath=Save_DIR + "epoch-{}".format(epoch), save_format='tf')

    # Save the final weights.
    model.save_weights(filepath=Save_DIR + "model", save_format='tf')

相关标签：深度学习、CV、TensorFlow、机器学习