记录一个深度学习pytorch训练模板的搭建心得(自查用)

首先,要搭建的训练代码是实验室师兄在CVPR2019上发表的已开源文章SRFBN,非常感谢师兄的搭建的深度学习模板。
paper:Feedback Network for Image Super-Resolution
github:https://github.com/Paper99/SRFBN_CVPR19
在搭建一个深度学习训练模板之前,我们 要明白,主要要做的事情有以下几件,下面算是几个较为宏观的步骤,这几个大的步骤里面会包含很多细节需要实现,这里记录的只是一个大概的模型搭建以及其中涉及到的代码知识,记下来以备自用自查

主要操作

  • 导入各种包
  • 解析输入的命令行参数
  • 设置随机种子
  • 创建dataset和dataloader
  • 搭建网络模型
  • 开始迭代训练并验证及保存模型

导入各种包

这一步省略,下面需要的时候会说

解析输入的命令行参数

#train.py
#创建参数的解析
parser=argparse.ArgumentParser(description='Zheng is training SR models')
parser.add_argument('-opt',type=str,required=True,help='path of JSON file')
opt=option.parse(parser.parse_args().opt)	

其中设计到的知识点:
(1)import argparse 首先导入模块
(2)parser = argparse.ArgumentParser() 创建一个解析对象
(3)parser.add_argument() 向该对象中添加你要关注的命令行参数和选项
(4)parser.parse_args() 进行解析
有关argparse模块,详见相关解析
需要注意的是我们这里并没有一个一个地去添加命令行参数,而是利用json文件来直接创建一个类似于字典的超参数集合,并将其实例化为名为opt的参数,以供接下来调用,这样做能够使代码很大简化,可读性更高,但有一个弊端就是要写另外的代码对json文件里面的参数进行解析,有关json文件里面具体参数的作用,详见上面的源代码链接。

设置随机种子

#train.py
seed=opt['solver']['manual_seed']
if seed is None: seed=random.randint(1,10000)
print("==>Random seed is [%s]"%seed)
random.seed(seed)
torch.manual_seed(seed)

其中torch.manual_seed()函数的使用请见相关解析

创建dataset和dataloader

#train.py
for phase,dataset_opt in sorted(opt['datasets'].items()):
    if phase=='train':
        train_set=create_dataset(dataset_opt)
        train_loader=create_dataloader(train_set,dataset_opt)
        print('Train Dataset:%s, Number of Images:%d'%(train_set.name(),len(train_set)))
        if train_loader is None:raise ValueError("还训练啥啊,训练集都没有")
    elif phase=='val':
        val_set=create_dataset(dataset_opt)
        val_loader=create_dataloader(val_set,dataset_opt)
        print('Val Dataset:%s, Number of Images:%d'%(val_set.name(),len(val_set)))

    else:
        raise NotImplementedError("[Error] Dataset phase [%s] in *.json is not recognized." % phase)

这一块的create_dataloader(),create_dataset()函数全是自己定义的,在/data/__init__.py文件下,下面对其进行解析

#/data/__init__.py
import torch.utils.data
'''有关dataloader的实现很简单,直接调用torch,。
def create_dataloader(dataset,dataset_opt):
    phase=dataset_opt['phase']
    if phase=='train':
        batch_size=dataset_opt['batch_size']
        shuffle=True
        num_workers=dataset_opt['n_workers']
    else:
        batch_size = 1
        shuffle = False
        num_workers = 1
    return torch.utils.data.DataLoader(
        dataset,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=True
    )

def create_dataset(dataset_opt):
    mode=dataset_opt['mode'].upper()
    if mode=='LR':
        from data.LR_dataset import  LRDataset as D
    elif mode=='LRHR':
        from data.LRHR_dataset import LRHRDataset as D
    else:
        raise NotImplementedError("Dataset [%s] is not recoginized by zhengzheng"%mode )
    dataset=D(dataset_opt)
    print('===>[%s] Dataset is created by zhengzheng' %mode)
    return dataset

有关LRHRDataset()类,其具体实现在/data/LRHR_dataset.py文件下,由于类里面涉及的操作较多,在此不对其进行源代码的解析,大致对其功能进行一个梳理,该类是在训练和验证阶段用的类,读取所有LR和HR图片的地址,并且 我们对其__getitem__魔法函数进行了改写, 它能够在该类被索引时返回一个字典{'LR': lr_tensor, 'HR': hr_tensor, 'LR_path': lr_path, 'HR_path': hr_path},即获取数据集内第idx张图像patch对。对于其他方面,如何读取npy类型的数据集,并将其转化为tensor文件;如何进行数据增强;如何得到图像对应的lr-hr的patch对,该类的实现具体来说相对复杂,有时间具体解析

搭建网络模型

#/train.py
solver=create_solver(opt)
print('芜湖,起飞==>Start training')
print('='*40)
solver_log=solver.get_current_log()

scale=opt['scale']
model_name=opt['networks']['which_model'].upper()
NUM_EPOCH=int(opt['solver']['num_epochs'])
start_epoch=solver_log['epoch']

print('Method: %s || Scale: %d || Epoch Range(%d~%d)'%(model_name,scale,start_epoch,NUM_EPOCH))
    # 开始迭代训练,验证

这里又挖了一个坑,create_solver()类是什么,先说它的功能吧,它不仅能够创建我们的网络模型,而且包含保存、加载已有网络模型;初始化网络;打印网络结构;计算网络参数量;self_ensemble等等你能想到的一系列只要跟网络模型相关的功能,有关其具体实现,等有时间一定具体解析

开始迭代训练并验证及保存模型

#/train.py
# 开始迭代训练,验证
for epoch in range(start_epoch,NUM_EPOCH+1):
    print('\n===> Training Epoch: [%d/%d], Learning Rate: %f'%(epoch,NUM_EPOCH,solver.get_current_learning_rate()))

    solver_log['epoch']=epoch

    train_loss_list=[]
    # 设置进度条
    with tqdm(total=len(train_loader),desc='Epoch:[%d/%d]'%(epoch,NUM_EPOCH),miniters=1) as t:
        # 通过迭代对train_set每一个batch训练
        for iter,batch in enumerate(train_loader):
            # copy源数据集训练数据
            solver.feed_data(batch)
            # 通过训练得到loss
            iter_loss=solver.train_step()
            batch_size=batch['LR'].size(0)
            train_loss_list.append(iter*batch_size)
            t.set_postfix_str("batch loss: %.4f "%iter_loss)
            t.update()

    # 记录当前loss,lr
    solver_log['records']['train_loss'].append(sum(train_loss_list)/len(train_set))
    solver_log['records']['lr'].append(solver.get_current_learning_rate())

    print('\nEpoch: [%d/%d] Avg train loss: %.6f'%(epoch,NUM_EPOCH,sum(train_loss_list)/len(train_set)))

    print('==>validating')
    psnr_list=[]
    ssim_list=[]
    val_loss_list=[]

    for iter,batch in enumerate(val_loader):
        solver.feed_data(batch)
        iter_loss=solver.test()
        val_loss_list.append(iter_loss)

        #
        visuals=solver.get_current_visual()
        psnr,ssim=util.calc_metrics(visuals['SR'],visuals['HR'],crop_border=scale)
        psnr_list.append(psnr)
        ssim_list.append(ssim)

        if opt['save_img']:
            solver.save_current_visual(epoch,iter)

    solver_log['records']['val_loss'].append(sum(val_loss_list)/len(val_loss_list))
    solver_log['records']['psnr'].append(sum(psnr_list)/len(psnr_list))
    solver_log['records']['ssim'].append(sum(ssim_list)/len(ssim_list))
    #         保存最好的epoch
    epoch_is_best=False
    if solver_log['best_pred']<(sum(psnr_list)/len(psnr_list)):
        solver_log['best_pred']=sum(psnr_list)/len(psnr_list)
        epoch_is_best=True
        solver_log['best_epoch']=epoch

    print("[%s] PSNR: %.2f   SSIM: %.4f   Loss: %.6f   Best PSNR: %.2f in Epoch: [%d]" % (val_set.name(),
                                                                                          sum(psnr_list)/len(psnr_list),
                                                                                          sum(ssim_list)/len(ssim_list),
                                                                                          sum(val_loss_list)/len(val_loss_list),
                                                                                          solver_log['best_pred'],
                                                                                          solver_log['best_epoch']))
    solver.set_current_log(solver_log)
    solver.save_checkpoint(epoch,epoch_is_best)
    solver.save_current_log()

    solver.update_learning_rate(epoch)

你可能感兴趣的:(深度学习,神经网络,计算机视觉)