LLNet Model Implementation: Model Training (Final Part)
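
This final post in the series walks through the LLNet training script: pre-training the deep autoencoder, then fine-tuning the full network. The original script is shown first, followed by a revised version that fixes checkpoint saving on Ubuntu.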

# Ref: LLNet: Deep Autoencoders for Low-light Image Enhancement
#
# Author: HSW
# Date: 2018-05-11 
#

import numpy as np
import tensorflow as tf

from prepare_data import * 
from LLNet        import * 

# number of training / test samples
TRAIN_NUM_SAMPLES = 14584
TEST_NUM_SAMPLES  = 14584

def read_batch_data(batch_size, root_dir, split="training"):
    ''' read batch data '''
    train_startIdx = 0

    readObj = LLNet_Data(root_dir, split)
    
    while train_startIdx < TRAIN_NUM_SAMPLES:
        batch_data = []
        batch_label = []

        idx = 0

        while idx < batch_size and train_startIdx < TRAIN_NUM_SAMPLES:
            data, label = readObj.read_interface(train_startIdx)
            train_startIdx += 1

            # skip samples that failed to load
            if (data is None) or (label is None):
                continue

            batch_data.append(data)
            batch_label.append(label)
            idx += 1

        # only yield full batches so the training loops can index exactly
        # batch_size per-sample losses; this also stops cleanly at end of data
        if idx == batch_size:
            yield np.array(batch_data, dtype=np.float32), np.array(batch_label, dtype=np.float32)
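
# Usage sketch (hypothetical, for a quick shape check): with the LLnet_Shape
# used below, each sample is a flattened 17x17 = 289-dimensional patch.
#
# for data, label in read_batch_data(4, "/home/hsw/LLNet"):
#     print(data.shape, label.shape)   # expected: (4, 289) (4, 289)
#     break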
    
    


def train_pretrain(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epochs=1001):
    ''' train pre-train '''
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune, transfer_function=tf.nn.sigmoid, LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_pretrain()
    
    for epoch in range(epochs):
        avg_loss = 0
        idx = 1
        for (batch_data, batch_label) in  read_batch_data(batch_size, root_dir, split):

            pretrain_loss = model.run_fitting_pretrain(batch_data, batch_label)

            # print("pretrain_loss = {}".format(pretrain_loss))
            
            sum_loss = 0
            for i in range(batch_size): 
            	sum_loss += pretrain_loss[i]
            
            avg_loss += sum_loss / TRAIN_NUM_SAMPLES * batch_size
            
            print("current avg_loss(Use Sample Count = {}): {}".format(idx * batch_size, avg_loss))
            
            idx += 1
     		
        if epoch > 0 and epoch % 10 == 0:
            print("Epoch: ", "%04d" %(epoch + 1), "pretrain_loss: ", "{:.9f}".format(avg_loss))
            model.saver.save(model.sess, "LLNet_Model_", str(epoch), ".ckpt")
    

        

def train_finetune(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epoch_th=200, epochs=1001):
    ''' train fine-tune '''
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune, transfer_function=tf.nn.sigmoid, LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_finetune()
    
    for epoch in range(epochs):
        avg_loss = 0
        idx = 1
        for (batch_data, batch_label) in  read_batch_data(batch_size, root_dir, split):

            if epoch < epoch_th: 
                finetune_loss = model.run_fitting_finetune_first(batch_data, batch_label)
            else:
                finetune_loss = model.run_fitting_finetune_last(batch_data, batch_label)
            
            # accumulate the running average of the per-sample loss; the data
            # come from the training split, so normalize by TRAIN_NUM_SAMPLES
            sum_loss = 0
            for i in range(batch_size):
                sum_loss += finetune_loss[i]

            avg_loss += sum_loss / TRAIN_NUM_SAMPLES

            print("avg_loss after {} samples: {}".format(idx * batch_size, avg_loss))

            idx += 1

        if epoch > 0 and epoch % 10 == 0:
            print("Epoch: ", "%04d" % (epoch + 1), "finetune_loss: ", "{:.9f}".format(avg_loss))
            # NOTE: same broken save call as in train_pretrain; fixed below
            model.saver.save(model.sess, "LLNet_Model_", str(epoch), ".ckpt")


if __name__ == '__main__':

    train_batch_size = 1
    train_root_dir   = "/home/hsw/LLNet"
    beta_pretrain    = 1000
    lambda_pretrain  = 1000
    lambda_finetune  = 1000
    train_split      = "training"
    
    # pre-train 
    train_pretrain(batch_size = train_batch_size, root_dir = train_root_dir, beta_pretrain = beta_pretrain, lambda_pretrain = lambda_pretrain, lambda_finetune = lambda_finetune, split = train_split)

    # finetune
    # train_finetune(batch_size = train_batch_size, root_dir = train_root_dir, beta_pretrain = beta_pretrain, lambda_pretrain = lambda_pretrain, lambda_finetune = lambda_finetune, split = train_split)

    
    

You can also download the samples I prepared here:

https://download.csdn.net/download/hit1524468/10434217 (the code before the fix)

Note: when running on Ubuntu, saving the checkpoint raises an error, so the script was modified (note that the LLNet model definition was also modified accordingly).
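
The root cause is the signature of tf.train.Saver.save(sess, save_path, global_step=None, ...): it expects a single checkpoint path string. The original call passed the filename in fragments, so str(epoch) landed in the global_step slot and ".ckpt" in the next positional argument, which is why the save errors out. The revised script below builds the path first:

    save_PATH = "./" + "LLNet_Model_" + str(epoch) + ".ckpt"
    model.saver.save(model.sess, save_PATH)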

# Ref: LLNet: Deep Autoencoders for Low-light Image Enhancement
#
# Author: HSW
# Date: 2018-05-11 
#

import numpy as np
import tensorflow as tf

from prepare_data import * 
from LLNet        import * 

# number of training / test samples
TRAIN_NUM_SAMPLES = 14584
TEST_NUM_SAMPLES  = 14584

def read_batch_data(batch_size, root_dir, split="training"):
    ''' read batch data '''
    train_startIdx = 0

    readObj = LLNet_Data(root_dir, split)
    
    while train_startIdx < TRAIN_NUM_SAMPLES:
        batch_data = []
        batch_label = []

        idx = 0

        while idx < batch_size and train_startIdx < TRAIN_NUM_SAMPLES:
            data, label = readObj.read_interface(train_startIdx)
            train_startIdx += 1

            # skip samples that failed to load
            if (data is None) or (label is None):
                continue

            batch_data.append(data)
            batch_label.append(label)
            idx += 1

        # only yield full batches so the training loops can index exactly
        # batch_size per-sample losses; this also stops cleanly at end of data
        if idx == batch_size:
            yield np.array(batch_data, dtype=np.float32), np.array(batch_label, dtype=np.float32)
    
    


def train_pretrain(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epochs=1001):
    ''' train pre-train '''
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune, transfer_function=tf.nn.sigmoid, LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_pretrain()
    
    for epoch in range(epochs):
        avg_loss = 0
        idx = 1
        for (batch_data, batch_label) in  read_batch_data(batch_size, root_dir, split):

            pretrain_loss = model.run_fitting_pretrain(batch_data, batch_label)

            # print("pretrain_loss = {}".format(pretrain_loss))
            
            sum_loss = 0
            for i in range(batch_size): 
            	sum_loss += pretrain_loss[i]
            
            avg_loss += sum_loss / TRAIN_NUM_SAMPLES * batch_size
            
            print("current avg_loss(Epoch = {}, Use Sample Count = {}): {}".format(epoch + 1, idx * batch_size, avg_loss))
            
            idx += 1
     		
        if  epoch > 1 and epoch % 50 == 0:
        	print("Epoch: {}, avg_loss: {}".format(epoch + 1, avg_loss))
        	save_PATH = "./" + "LLNet_Model_" + str(epoch) + ".ckpt"
        	model.saver.save(model.sess, save_PATH)
    

        

def train_finetune(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epoch_th=200, epochs=1001):
    ''' train fine-tune '''
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune, transfer_function=tf.nn.sigmoid, LLnet_Shape=(289, 847, 578, 289), sparseCoef=0.05)
    model.build_graph_finetune()
    
    for epoch in range(epochs):
        avg_loss = 0
        idx = 1
        for (batch_data, batch_label) in  read_batch_data(batch_size, root_dir, split):

            if epoch < epoch_th: 
                finetune_loss = model.run_fitting_finetune_first(batch_data, batch_label)
            else:
                finetune_loss = model.run_fitting_finetune_last(batch_data, batch_label)
            
            # accumulate the running average of the per-sample loss; the data
            # come from the training split, so normalize by TRAIN_NUM_SAMPLES
            sum_loss = 0
            for i in range(batch_size):
                sum_loss += finetune_loss[i]

            avg_loss += sum_loss / TRAIN_NUM_SAMPLES

            print("avg_loss (epoch {}, after {} samples): {}".format(epoch + 1, idx * batch_size, avg_loss))

            idx += 1

        if epoch > 0 and epoch % 50 == 0:
            # "{%4d}" is not a valid str.format field; use "{:04d}" instead
            print("Epoch: {:04d}, avg_loss: {:.9f}".format(epoch + 1, avg_loss))
            save_PATH = "./" + "LLNet_Model_" + str(epoch) + ".ckpt"
            model.saver.save(model.sess, save_PATH)


if __name__ == '__main__':

    train_batch_size = 1
    train_root_dir   = "/home/hsw/LLNet"
    beta_pretrain    = 1000
    lambda_pretrain  = 1000
    lambda_finetune  = 1000
    train_split      = "training"
    
    # pre-train 
    train_pretrain(batch_size = train_batch_size, root_dir = train_root_dir, beta_pretrain = beta_pretrain, lambda_pretrain = lambda_pretrain, lambda_finetune = lambda_finetune, split = train_split)

    # finetune
    # train_finetune(batch_size = train_batch_size, root_dir = train_root_dir, beta_pretrain = beta_pretrain, lambda_pretrain = lambda_pretrain, lambda_finetune = lambda_finetune, split = train_split)
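
    # Hypothetical sketch: fine-tuning is meant to start from the pre-trained
    # weights, which would be restored from a pre-training checkpoint before
    # fine-tuning begins, e.g. (assuming model.saver / model.sess are exposed
    # as above):
    #
    #   model.saver.restore(model.sess, "./LLNet_Model_50.ckpt")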

    
    

