# Ref: LLNet: Deep Autoencoders for Low-light Image Enhancement
#
# Author: HSW
# Date: 2018-05-11
#
from prepare_data import *
from LLNet import *
# Sample counts for the training and test sets.
# NOTE(review): both are declared with the same size — confirm against the
# prepared dataset.
TRAIN_NUM_SAMPLES, TEST_NUM_SAMPLES = 14584, 14584
def read_batch_data(batch_size, root_dir, split="training", num_samples=None):
    """Yield ``(data, label)`` mini-batches read from an ``LLNet_Data`` source.

    Samples are requested by consecutive index, starting at 0. Any index for
    which ``read_interface`` returns ``None`` for the data or the label is
    skipped. Only full batches of exactly ``batch_size`` valid samples are
    yielded; a trailing partial batch is dropped, because the training loops
    read every position ``0 .. batch_size - 1`` of the per-batch loss.

    Args:
        batch_size: number of valid samples per yielded batch (must be >= 1).
        root_dir: dataset root, passed to ``LLNet_Data``.
        split: split name, passed to ``LLNet_Data``.
        num_samples: number of indices to scan. Defaults to
            ``TRAIN_NUM_SAMPLES``, the bound the original code always used,
            regardless of ``split``.

    Yields:
        ``(batch_data, batch_label)`` built with ``np.array(..., dtype=np.float32)``.

    Raises:
        ValueError: if ``batch_size`` < 1. The old loop never advanced in that
            case and produced empty batches forever.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    if num_samples is None:
        num_samples = TRAIN_NUM_SAMPLES
    reader = LLNet_Data(root_dir, split)
    sample_idx = 0
    while sample_idx < num_samples:
        batch_data = []
        batch_label = []
        # Stop at num_samples as well: otherwise trailing invalid samples
        # would make this loop request indices past the end of the dataset.
        while len(batch_data) < batch_size and sample_idx < num_samples:
            data, label = reader.read_interface(sample_idx)
            sample_idx += 1
            if data is None or label is None:
                continue
            batch_data.append(data)
            batch_label.append(label)
        if len(batch_data) < batch_size:
            # Dataset exhausted before the batch filled up: drop it.
            break
        yield (np.array(batch_data, dtype=np.float32),
               np.array(batch_label, dtype=np.float32))
def train_pretrain(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epochs=1001):
    """Run LLNet pre-training and checkpoint the session every 10 epochs.

    Builds an ``LLNet_Model`` with ``transfer_function=tf.nn.sigmoid``,
    ``LLnet_Shape=(289, 847, 578, 289)`` and ``sparseCoef=0.05``, then builds
    the pre-training graph. Each epoch streams batches from ``read_batch_data``
    and feeds each one to ``run_fitting_pretrain``.

    Args:
        batch_size: samples per batch. It is also the number of entries read
            from each returned loss.
        root_dir: dataset root, forwarded to ``read_batch_data``.
        beta_pretrain, lambda_pretrain, lambda_finetune: forwarded positionally
            to ``LLNet_Model``.
        split: dataset split, forwarded to ``read_batch_data``.
        epochs: number of passes over the data.
    """
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune,
                        transfer_function=tf.nn.sigmoid,
                        LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_pretrain()
    for epoch in range(epochs):
        avg_loss = 0
        batches = read_batch_data(batch_size, root_dir, split)
        for idx, (batch_data, batch_label) in enumerate(batches, start=1):
            pretrain_loss = model.run_fitting_pretrain(batch_data, batch_label)
            sum_loss = 0
            for i in range(batch_size):
                sum_loss += pretrain_loss[i]
            # NOTE(review): if pretrain_loss[i] is the loss of sample i, the
            # "* batch_size" factor over-weights batches when batch_size > 1.
            # Confirm against LLNet_Model.run_fitting_pretrain.
            avg_loss += sum_loss / TRAIN_NUM_SAMPLES * batch_size
            print("current avg_loss(Use Sample Count = {}): {}".format(idx * batch_size, avg_loss))
        if epoch > 0 and epoch % 10 == 0:
            print("Epoch: ", "%04d" % (epoch + 1), "pretrain_loss: ", "{:.9f}".format(avg_loss))
            # Saver.save(sess, save_path, global_step=None, latest_filename=None, ...)
            # takes ONE path. Passing the pieces separately sent str(epoch) to
            # global_step and ".ckpt" to latest_filename. The explicit "./"
            # directory is used because some TF 1.x releases reject a bare
            # filename ("Parent directory ... doesn't exist").
            save_path = "./LLNet_Model_{}.ckpt".format(epoch)
            model.saver.save(model.sess, save_path)
def train_finetune(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epoch_th=200, epochs=1001):
    """Run LLNet fine-tuning and checkpoint the session every 10 epochs.

    Builds an ``LLNet_Model`` with ``transfer_function=tf.nn.sigmoid``,
    ``LLnet_Shape=(289, 847, 578, 289)`` and ``sparseCoef=0.05``, then builds
    the fine-tuning graph. Batch routing by epoch:

    - epochs below ``epoch_th``: ``run_fitting_finetune_first``
    - later epochs: ``run_fitting_finetune_last``

    Args:
        batch_size: samples per batch. It is also the number of entries read
            from each returned loss.
        root_dir: dataset root, forwarded to ``read_batch_data``.
        beta_pretrain, lambda_pretrain, lambda_finetune: forwarded positionally
            to ``LLNet_Model``.
        split: dataset split, forwarded to ``read_batch_data``.
        epoch_th: first epoch that uses ``run_fitting_finetune_last``.
        epochs: number of passes over the data.
    """
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune,
                        transfer_function=tf.nn.sigmoid,
                        LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_finetune()
    for epoch in range(epochs):
        avg_loss = 0
        batches = read_batch_data(batch_size, root_dir, split)
        for idx, (batch_data, batch_label) in enumerate(batches, start=1):
            if epoch < epoch_th:
                finetune_loss = model.run_fitting_finetune_first(batch_data, batch_label)
            else:
                finetune_loss = model.run_fitting_finetune_last(batch_data, batch_label)
            sum_loss = 0
            for i in range(batch_size):
                sum_loss += finetune_loss[i]
            # read_batch_data scans TRAIN_NUM_SAMPLES indices, so normalize by
            # the training count. NOTE(review): as in pre-training, the
            # "* batch_size" factor looks off for batch_size > 1; confirm.
            avg_loss += sum_loss / TRAIN_NUM_SAMPLES * batch_size
            print("current avg_loss(Use Sample Count = {}): {}".format(idx * batch_size, avg_loss))
        if epoch > 0 and epoch % 10 == 0:
            print("Epoch: ", "%04d" % (epoch + 1), "finetune_loss: ", "{:.9f}".format(avg_loss))
            # Saver.save takes a single save_path; passing the pieces as
            # separate positional args sent str(epoch) to global_step and
            # ".ckpt" to latest_filename. "./" gives the path an explicit
            # directory, which some TF 1.x releases require.
            save_path = "./LLNet_Model_{}.ckpt".format(epoch)
            model.saver.save(model.sess, save_path)
if __name__ == '__main__':
    # Hyper-parameters and data location for this training run.
    train_batch_size = 1
    train_root_dir = "/home/hsw/LLNet"
    beta_pretrain = lambda_pretrain = lambda_finetune = 1000
    train_split = "training"
    common_kwargs = dict(
        batch_size=train_batch_size,
        root_dir=train_root_dir,
        beta_pretrain=beta_pretrain,
        lambda_pretrain=lambda_pretrain,
        lambda_finetune=lambda_finetune,
        split=train_split,
    )
    # Phase 1: pre-training.
    train_pretrain(**common_kwargs)
    # Phase 2: fine-tuning (disabled; uncomment to run it).
    # train_finetune(**common_kwargs)
您也可以在这里下载我自己处理好的样本：
https://download.csdn.net/download/hit1524468/10434217 （修改前的代码）
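注意：在 Ubuntu 上运行时，保存 checkpoint 会出错。原代码 `model.saver.save(model.sess, "LLNet_Model_", str(epoch), ".ckpt")` 把文件名拆成了多个位置参数，因此下面的代码改为先拼接出完整路径（`"./" + "LLNet_Model_" + str(epoch) + ".ckpt"`）再保存。LLNet 模型定义部分也做了相应修改。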
# Ref: LLNet: Deep Autoencoders for Low-light Image Enhancement
#
# Author: HSW
# Date: 2018-05-11
#
from prepare_data import *
from LLNet import *
# Sample counts for the training and test sets.
# NOTE(review): both are declared with the same size — confirm against the
# prepared dataset.
TRAIN_NUM_SAMPLES, TEST_NUM_SAMPLES = 14584, 14584
def read_batch_data(batch_size, root_dir, split="training", num_samples=None):
    """Yield ``(data, label)`` mini-batches read from an ``LLNet_Data`` source.

    Samples are requested by consecutive index, starting at 0. Any index for
    which ``read_interface`` returns ``None`` for the data or the label is
    skipped. Only full batches of exactly ``batch_size`` valid samples are
    yielded; a trailing partial batch is dropped, because the training loops
    read every position ``0 .. batch_size - 1`` of the per-batch loss.

    Args:
        batch_size: number of valid samples per yielded batch (must be >= 1).
        root_dir: dataset root, passed to ``LLNet_Data``.
        split: split name, passed to ``LLNet_Data``.
        num_samples: number of indices to scan. Defaults to
            ``TRAIN_NUM_SAMPLES``, the bound the original code always used,
            regardless of ``split``.

    Yields:
        ``(batch_data, batch_label)`` built with ``np.array(..., dtype=np.float32)``.

    Raises:
        ValueError: if ``batch_size`` < 1. The old loop never advanced in that
            case and produced empty batches forever.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    if num_samples is None:
        num_samples = TRAIN_NUM_SAMPLES
    reader = LLNet_Data(root_dir, split)
    sample_idx = 0
    while sample_idx < num_samples:
        batch_data = []
        batch_label = []
        # Stop at num_samples as well: otherwise trailing invalid samples
        # would make this loop request indices past the end of the dataset.
        while len(batch_data) < batch_size and sample_idx < num_samples:
            data, label = reader.read_interface(sample_idx)
            sample_idx += 1
            if data is None or label is None:
                continue
            batch_data.append(data)
            batch_label.append(label)
        if len(batch_data) < batch_size:
            # Dataset exhausted before the batch filled up: drop it.
            break
        yield (np.array(batch_data, dtype=np.float32),
               np.array(batch_label, dtype=np.float32))
def train_pretrain(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epochs = 1001):
    """Pre-train LLNet, checkpointing the session every 50 epochs.

    The model is an ``LLNet_Model`` with ``transfer_function=tf.nn.sigmoid``,
    ``LLnet_Shape=(289, 847, 578, 289)`` and ``sparseCoef=0.05``. For each
    epoch, every batch from ``read_batch_data`` goes through
    ``run_fitting_pretrain`` and the running loss is printed.

    When the epoch index is a multiple of 50 and greater than 1:

    - the running loss is printed;
    - the session is saved to ``./LLNet_Model_<epoch>.ckpt``.
    """
    net = LLNet_Model(
        beta_pretrain,
        lambda_pretrain,
        lambda_finetune,
        transfer_function=tf.nn.sigmoid,
        LLnet_Shape=(289, 847, 578, 289),
        sparseCoef=0.05,
    )
    net.build_graph_pretrain()
    for epoch in range(epochs):
        running_loss = 0
        batches = read_batch_data(batch_size, root_dir, split)
        for step, (batch_data, batch_label) in enumerate(batches, 1):
            losses = net.run_fitting_pretrain(batch_data, batch_label)
            batch_total = 0
            for k in range(batch_size):
                batch_total += losses[k]
            # NOTE(review): if losses[k] is a per-sample loss, the
            # "* batch_size" factor over-weights batches when batch_size > 1.
            running_loss += batch_total / TRAIN_NUM_SAMPLES * batch_size
            print("current avg_loss(Epoch = {}, Use Sample Count = {}): {}".format(epoch + 1, step * batch_size, running_loss))
        # Only checkpoint on every 50th epoch (epoch 0 excluded).
        if epoch % 50 != 0 or epoch <= 1:
            continue
        print("Epoch: {}, avg_loss: {}".format(epoch + 1, running_loss))
        checkpoint_path = "./LLNet_Model_{}.ckpt".format(epoch)
        net.saver.save(net.sess, checkpoint_path)
def train_finetune(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epoch_th=200, epochs=1001):
    """Run LLNet fine-tuning and checkpoint the session every 50 epochs.

    Builds an ``LLNet_Model`` with ``transfer_function=tf.nn.sigmoid``,
    ``LLnet_Shape=(289, 847, 578, 289)`` and ``sparseCoef=0.05``, then builds
    the fine-tuning graph. Batch routing by epoch:

    - epochs below ``epoch_th``: ``run_fitting_finetune_first``
    - later epochs: ``run_fitting_finetune_last``

    Args:
        batch_size: samples per batch. It is also the number of entries read
            from each returned loss.
        root_dir: dataset root, forwarded to ``read_batch_data``.
        beta_pretrain, lambda_pretrain, lambda_finetune: forwarded positionally
            to ``LLNet_Model``.
        split: dataset split, forwarded to ``read_batch_data``.
        epoch_th: first epoch that uses ``run_fitting_finetune_last``.
        epochs: number of passes over the data.
    """
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune,
                        transfer_function=tf.nn.sigmoid,
                        LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_finetune()
    for epoch in range(epochs):
        avg_loss = 0
        batches = read_batch_data(batch_size, root_dir, split)
        for idx, (batch_data, batch_label) in enumerate(batches, start=1):
            if epoch < epoch_th:
                finetune_loss = model.run_fitting_finetune_first(batch_data, batch_label)
            else:
                finetune_loss = model.run_fitting_finetune_last(batch_data, batch_label)
            sum_loss = 0
            for i in range(batch_size):
                sum_loss += finetune_loss[i]
            # read_batch_data scans TRAIN_NUM_SAMPLES indices, so normalize by
            # the training count. NOTE(review): the "* batch_size" factor looks
            # off for batch_size > 1 if losses are per-sample; confirm.
            avg_loss += sum_loss / TRAIN_NUM_SAMPLES * batch_size
            print("current avg_loss(Epoch = {}, Use Sample Count = {}): {}".format(epoch + 1, idx * batch_size, avg_loss))
        if epoch > 0 and epoch % 50 == 0:
            # "{:4d}" is the str.format spelling of printf "%4d". The former
            # "{%4d}" was parsed as a keyword field named "%4d" and raised
            # KeyError before the checkpoint below could be written.
            print("Epoch: {:4d}, avg_loss: {:.9f}".format(epoch + 1, avg_loss))
            save_PATH = "./" + "LLNet_Model_" + str(epoch) + ".ckpt"
            model.saver.save(model.sess, save_PATH)
if __name__ == '__main__':
    # Hyper-parameters and data location for this training run.
    train_batch_size = 1
    train_root_dir = "/home/hsw/LLNet"
    beta_pretrain = lambda_pretrain = lambda_finetune = 1000
    train_split = "training"
    common_kwargs = dict(
        batch_size=train_batch_size,
        root_dir=train_root_dir,
        beta_pretrain=beta_pretrain,
        lambda_pretrain=lambda_pretrain,
        lambda_finetune=lambda_finetune,
        split=train_split,
    )
    # Phase 1: pre-training.
    train_pretrain(**common_kwargs)
    # Phase 2: fine-tuning (disabled; uncomment to run it).
    # train_finetune(**common_kwargs)