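# B3 is probably about the right size to run on an ordinary server; do not set
# the batch size too large, because it can easily exceed the available GPU memory.
# This script uses the same dataset as the earlier AlexNet example; a link to
# the dataset will be added later.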
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import math
import EfficientNet_Model
import numpy as np
from Data_Channel import Data_Channel
import random
# ---- Training configuration -------------------------------------------------
# Number of images per training step; also used as the Data_Channel pool size.
BatchSize = 50
# Image resolution handed to every Data_Channel.
Resolution = 260
# Number of passes of the outer training loop.
EPOCHS = 3000
# A checkpoint is written whenever epoch % Save_n_Epoch == 0.
Save_n_Epoch = 50
# Nominal number of images per epoch; steps per epoch = ALL_images / BatchSize.
ALL_images = 10000
# Prefix (directory) under which weight checkpoints are written.
Save_DIR = "./ModelLog/"
# Length of the validation label vectors.
# NOTE(review): presumably this must equal the size of Data_Channel.Valid_pool
# -- confirm against the Data_Channel implementation.
Valid_BatchSize = 20

# Integer class index of each weather category. This is the single source of
# truth for both label tables below, so training and validation labels cannot
# drift apart.
Class_Index = {
    "cloudy": 1,
    "haze": 2,
    "rainy": 3,
    "snow": 4,
    "sunny": 5,
    "thunder": 0,
}

# Despite the "_OH" suffix these are sparse labels (one integer class index per
# image), which is what SparseCategoricalCrossentropy expects -- they are not
# one-hot vectors. Each entry is a constant vector because a batch is drawn
# from a single category. All vectors share the same integer dtype.
Labels_OH = {
    name: np.full((BatchSize,), index, dtype=np.int64)
    for name, index in Class_Index.items()
}
Valid_OH = {
    name: np.full((Valid_BatchSize,), index, dtype=np.int64)
    for name, index in Class_Index.items()
}
# Category names; the training loop draws one of these at random per step.
DC_list = ["cloudy", "haze", "rainy", "snow", "sunny", "thunder"]

# One Data_Channel per category, keyed by category name. Every channel is
# created with a pool of BatchSize images at the configured Resolution.
DC_Dic = {
    category_name: Data_Channel(
        category=category_name,
        pool_size=BatchSize,
        resolution=Resolution,
    )
    for category_name in DC_list
}
# Disabled legacy helper: the triple-quoted string below is a bare string
# expression, so it is never executed. It sketches how a batch of raw images
# and labels was unpacked from a `features` record, and it refers to
# `load_and_preprocess_image`, which is not defined in the visible part of
# this file. The training loop below feeds the model from Data_Channel pools
# and does not use it.
'''def process_features(features, data_augmentation):
image_raw = features['image_raw'].numpy()
image_tensor_list = []
for image in image_raw:
image_tensor = load_and_preprocess_image(image, data_augmentation=data_augmentation)
image_tensor_list.append(image_tensor)
images = tf.stack(image_tensor_list, axis=0)
labels = features['label'].numpy()
return images, labels'''
if __name__ == '__main__':
    # Script entry point: build the EfficientNet model, train it on batches
    # drawn from one randomly chosen weather category per step, validate once
    # per epoch, and write weight checkpoints under Save_DIR.

    # GPU settings: let TensorFlow grow GPU memory on demand for every
    # visible GPU (the loop is a no-op when no GPU is present).
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)

    # create model
    model = EfficientNet_Model.efficient_net_b2()

    # Loss and optimizer. The sparse loss/metric expect labels as integer
    # class indices (one per image), not one-hot vectors.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.RMSprop()

    # Running metrics, accumulated over an epoch and reset at its end.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

    # @tf.function  (left disabled, as in the original, so the step runs eagerly)
    def train_step(image_batch, label_batch):
        """Run one optimisation step and update the training metrics."""
        with tf.GradientTape() as tape:
            predictions = model(image_batch, training=True)
            loss = loss_object(y_true=label_batch, y_pred=predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))
        train_loss.update_state(values=loss)
        train_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # @tf.function  (left disabled, as in the original, so the step runs eagerly)
    def valid_step(image_batch, label_batch):
        """Evaluate one batch in inference mode and update the validation metrics."""
        predictions = model(image_batch, training=False)
        v_loss = loss_object(label_batch, predictions)
        valid_loss.update_state(values=v_loss)
        valid_accuracy.update_state(y_true=label_batch, y_pred=predictions)

    # Loop-invariant step count, computed once. It is also the total shown in
    # the per-step progress line (previously a hard-coded, incorrect "5").
    steps_per_epoch = int(round(ALL_images / BatchSize))

    # start training
    for epoch in range(EPOCHS):
        for step in range(steps_per_epoch):
            # Each training batch comes from a single, randomly chosen category.
            Category = random.choice(DC_list)
            Channel_now = DC_Dic[Category]
            Channel_now.Renew_dataset()
            train_step(Channel_now.RF_pool, Labels_OH[Category])
            print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(epoch,
                                                                                     EPOCHS,
                                                                                     step,
                                                                                     steps_per_epoch,
                                                                                     train_loss.result().numpy(),
                                                                                     train_accuracy.result().numpy()))

        # Validate on every category. Validating only on whichever category
        # the last training step happened to draw would report metrics for a
        # single, random class.
        for valid_category in DC_list:
            valid_channel = DC_Dic[valid_category]
            valid_channel.Renew_Valid_ds()
            valid_step(valid_channel.Valid_pool, Valid_OH[valid_category])

        print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, "
              "valid loss: {:.5f}, valid accuracy: {:.5f}".format(epoch,
                                                                  EPOCHS,
                                                                  train_loss.result().numpy(),
                                                                  train_accuracy.result().numpy(),
                                                                  valid_loss.result().numpy(),
                                                                  valid_accuracy.result().numpy()))

        # Start the next epoch with clean running averages.
        # NOTE(review): newer TensorFlow/Keras releases name this method
        # reset_state() -- confirm against the installed version.
        train_loss.reset_states()
        train_accuracy.reset_states()
        valid_loss.reset_states()
        valid_accuracy.reset_states()

        # Periodic checkpoint (this also fires at epoch 0).
        if epoch % Save_n_Epoch == 0:
            model.save_weights(filepath=Save_DIR+"epoch-{}".format(epoch), save_format='tf')

    # save the final weights once training has finished
    model.save_weights(filepath=Save_DIR+"model", save_format='tf')