首先,要搭建的训练代码是实验室师兄在CVPR2019上发表的已开源文章SRFBN,非常感谢师兄的搭建的深度学习模板。
paper:Feedback Network for Image Super-Resolution
github:https://github.com/Paper99/SRFBN_CVPR19
在搭建一个深度学习训练模板之前,我们 要明白,主要要做的事情有以下几件,下面算是几个较为宏观的步骤,这几个大的步骤里面会包含很多细节需要实现,这里记录的只是一个大概的模型搭建以及其中涉及到的代码知识,记下来以备自用自查
这一步省略,下面需要的时候会说
#train.py
#创建参数的解析
parser=argparse.ArgumentParser(description='Zheng is training SR models')
parser.add_argument('-opt',type=str,required=True,help='path of JSON file')
opt=option.parse(parser.parse_args().opt)
其中设计到的知识点:
(1)import argparse 首先导入模块
(2)parser = argparse.ArgumentParser() 创建一个解析对象
(3)parser.add_argument() 向该对象中添加你要关注的命令行参数和选项
(4)parser.parse_args() 进行解析
有关argparse模块,详见相关解析
需要注意的是我们这里并没有一个一个地去添加命令行参数,而是利用json文件来直接创建一个类似于字典的超参数集合,并将其实例化为名为opt
的参数,以供接下来调用,这样做能够使代码很大简化,可读性更高,但有一个弊端就是要写另外的代码对json文件里面的参数进行解析,有关json文件里面具体参数的作用,详见上面的源代码链接。
#train.py
seed=opt['solver']['manual_seed']
if seed is None: seed=random.randint(1,10000)
print("==>Random seed is [%s]"%seed)
random.seed(seed)
torch.manual_seed(seed)
其中torch.manual_seed()函数的使用请见相关解析
#train.py
for phase,dataset_opt in sorted(opt['datasets'].items()):
if phase=='train':
train_set=create_dataset(dataset_opt)
train_loader=create_dataloader(train_set,dataset_opt)
print('Train Dataset:%s, Number of Images:%d'%(train_set.name(),len(train_set)))
if train_loader is None:raise ValueError("还训练啥啊,训练集都没有")
elif phase=='val':
val_set=create_dataset(dataset_opt)
val_loader=create_dataloader(val_set,dataset_opt)
print('Val Dataset:%s, Number of Images:%d'%(val_set.name(),len(val_set)))
else:
raise NotImplementedError("[Error] Dataset phase [%s] in *.json is not recognized." % phase)
这一块的create_dataloader(),create_dataset()
函数全是自己定义的,在/data/__init__.py
文件下,下面对其进行解析
#/data/__init__.py
import torch.utils.data
'''有关dataloader的实现很简单,直接调用torch,。
def create_dataloader(dataset,dataset_opt):
phase=dataset_opt['phase']
if phase=='train':
batch_size=dataset_opt['batch_size']
shuffle=True
num_workers=dataset_opt['n_workers']
else:
batch_size = 1
shuffle = False
num_workers = 1
return torch.utils.data.DataLoader(
dataset,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=True
)
def create_dataset(dataset_opt):
mode=dataset_opt['mode'].upper()
if mode=='LR':
from data.LR_dataset import LRDataset as D
elif mode=='LRHR':
from data.LRHR_dataset import LRHRDataset as D
else:
raise NotImplementedError("Dataset [%s] is not recoginized by zhengzheng"%mode )
dataset=D(dataset_opt)
print('===>[%s] Dataset is created by zhengzheng' %mode)
return dataset
有关LRHRDataset()
类,其具体实现在/data/LRHR_dataset.py
文件下,由于类里面涉及的操作较多,在此不对其进行源代码的解析,大致对其功能进行一个梳理,该类是在训练和验证阶段用的类,读取所有LR和HR图片的地址,并且 我们对其__getitem__
魔法函数进行了改写, 它能够在该类被索引时返回一个字典{'LR': lr_tensor, 'HR': hr_tensor, 'LR_path': lr_path, 'HR_path': hr_path}
,即获取数据集内第idx张图像patch对。对于其他方面,如何读取npy类型的数据集,并将其转化为tensor文件;如何进行数据增强;如何得到图像对应的lr-hr的patch对,该类的实现具体来说相对复杂,有时间具体解析。
#/train.py
solver=create_solver(opt)
print('芜湖,起飞==>Start training')
print('='*40)
solver_log=solver.get_current_log()
scale=opt['scale']
model_name=opt['networks']['which_model'].upper()
NUM_EPOCH=int(opt['solver']['num_epochs'])
start_epoch=solver_log['epoch']
print('Method: %s || Scale: %d || Epoch Range(%d~%d)'%(model_name,scale,start_epoch,NUM_EPOCH))
# 开始迭代训练,验证
这里又挖了一个坑,create_solver()
类是什么,先说它的功能吧,它不仅能够创建我们的网络模型,而且包含保存、加载已有网络模型;初始化网络;打印网络结构;计算网络参数量;self_ensemble等等你能想到的一系列只要跟网络模型相关的功能,有关其具体实现,等有时间一定具体解析。
#/train.py
# 开始迭代训练,验证
for epoch in range(start_epoch,NUM_EPOCH+1):
print('\n===> Training Epoch: [%d/%d], Learning Rate: %f'%(epoch,NUM_EPOCH,solver.get_current_learning_rate()))
solver_log['epoch']=epoch
train_loss_list=[]
# 设置进度条
with tqdm(total=len(train_loader),desc='Epoch:[%d/%d]'%(epoch,NUM_EPOCH),miniters=1) as t:
# 通过迭代对train_set每一个batch训练
for iter,batch in enumerate(train_loader):
# copy源数据集训练数据
solver.feed_data(batch)
# 通过训练得到loss
iter_loss=solver.train_step()
batch_size=batch['LR'].size(0)
train_loss_list.append(iter*batch_size)
t.set_postfix_str("batch loss: %.4f "%iter_loss)
t.update()
# 记录当前loss,lr
solver_log['records']['train_loss'].append(sum(train_loss_list)/len(train_set))
solver_log['records']['lr'].append(solver.get_current_learning_rate())
print('\nEpoch: [%d/%d] Avg train loss: %.6f'%(epoch,NUM_EPOCH,sum(train_loss_list)/len(train_set)))
print('==>validating')
psnr_list=[]
ssim_list=[]
val_loss_list=[]
for iter,batch in enumerate(val_loader):
solver.feed_data(batch)
iter_loss=solver.test()
val_loss_list.append(iter_loss)
#
visuals=solver.get_current_visual()
psnr,ssim=util.calc_metrics(visuals['SR'],visuals['HR'],crop_border=scale)
psnr_list.append(psnr)
ssim_list.append(ssim)
if opt['save_img']:
solver.save_current_visual(epoch,iter)
solver_log['records']['val_loss'].append(sum(val_loss_list)/len(val_loss_list))
solver_log['records']['psnr'].append(sum(psnr_list)/len(psnr_list))
solver_log['records']['ssim'].append(sum(ssim_list)/len(ssim_list))
# 保存最好的epoch
epoch_is_best=False
if solver_log['best_pred']<(sum(psnr_list)/len(psnr_list)):
solver_log['best_pred']=sum(psnr_list)/len(psnr_list)
epoch_is_best=True
solver_log['best_epoch']=epoch
print("[%s] PSNR: %.2f SSIM: %.4f Loss: %.6f Best PSNR: %.2f in Epoch: [%d]" % (val_set.name(),
sum(psnr_list)/len(psnr_list),
sum(ssim_list)/len(ssim_list),
sum(val_loss_list)/len(val_loss_list),
solver_log['best_pred'],
solver_log['best_epoch']))
solver.set_current_log(solver_log)
solver.save_checkpoint(epoch,epoch_is_best)
solver.save_current_log()
solver.update_learning_rate(epoch)