一、结构分布
先介绍一下代码的结构分布吧
1、tain.py文件是训练的时候首先执行的文件,里面的函数有eval()评估函数,train()训练函数
2、trainer.py文件是网络的流图,关于如何forward,如何计算loss,如何反向计算,如何保存模型,如何控制权重更新等等,这个里面的函数会在train.py中的train()函数开始的时候调用,先构建fasterrcnn的网络,然后将网络作为参数传给trainer的构造函数
3、data文件夹,下面都是数据读取的方法
(1)dataset.py文件是批量加载数据,这个也会在train()函数开始初始化了一个dataset对象。
(2)voc_dataset.py文件是针对VOC数据集格式准备的,批量加载VOC数据集,解析XML文件都在这个类中,而且他是在dataset.py文件中调用,加载VOC数据集用的。
(3)util.py文件是定义了一些图像预处理工具,包括read_image,resize_box、crop_bbox这些会在其他py文件中调用,比如voc_dataset.py文件中调用了read_image数据进行读取数据。
(4)__init__.py文件好像是python要求类下必须有这个文件,需要确定一下???
4、model文件夹,下面都是网络构建的一些py。
(1)faster_rcnn_vgg16.py文件,是用来构建FasterRCNN-vgg16网络的,该网络分三部分创建,extractor特征提取网络,是利用torchvision.model模块创建的VGG16,然后是RPN网络创建,再就是ROIHeader网络创建,这个网络构建对象在train()函数开始的时候就创建了,用于先构建网络,然后再传入trainer。
(2)faster_rcnn.py文件,是一个base-class,faster-rcnn-vgg16类继承了这个类
(3)region_proposal_network.py文件,用于构建rpn模块,在faster_rcnn_vgg16.py文件中调用生成网络结构。
(4)roi_module.py文件,这个暂时没研究,在faster_rcnn_vgg16.py中调用了,初始化的时候ROIPooling。??
(5)utils文件夹,下面是一些工具,nms文件夹是非极大值抑制,其他的没仔细看,后期研究,主要是在faster_rcnn.py文件中调用了。
6、utils文件夹,这里面是工具
二、训练流程
1、首先调用train.py文件,输入相关参数进行训练。输入的控制台参数用**kwargs来表示,学习了python的控制台参数知道这是个接受字典形参数。https://www.cnblogs.com/zhangzhuozheng/p/8053045.html可参见这个地址有详细说明。
2、进行参数解析,利用了utils文件夹下的config.py文件进行了参数获取,这个文件中自定义一些默认参数,主要返回的是学习率学习策略,数据集地址等。
3、构造数据集对象,包括标签、图像名称列表。
4、根据batch_size,num_workers进行数据加载对象声明。数据是不是这个时候加载的还待定,感觉这个loader就像一个占位符,先占个坑,等运行网络的时候就开始读入了。具体DataLoader的用法需要查看pytorch
5、然后构建网络faster_rcnn_vgg16,构建方式见前面说的faster_rcnn_vgg16.py。
6、构建traner对象,将网络输入进行。
7、判断opt.load_path是否存在,这个load_path是在config.py中定义的,是model的地址,默认值是none,也就是如果在控制台输入没有指定--model这个参数,那么就没有了。如果有model即预训练模型,则调用trainer中的load函数加载预训练模型。
(1)trainer.load()函数解析,首先利用torch.load()函数加载模型,然后判断‘model’字符串是否存在,来判断是单纯加载参数还是加载带模型的参数(这是我个人理解的,具体要看pytorch的load_state_dict函数),最后判断参数是否修改,默认没改,最后判断优化器是否在加载的网络里,是的话加载预训练模型中的优化器。
8、可视化训练数据的label,调用trainer.vis.text函数,函数解析待会。
9、best_map参数干什么用的不知道待定,lr_是学习率获取。
10、下面就是循环训练啦,循环条件是epoch数,这个是opt超参数规定的。
(1)trainer.reset_meters()先重置界面上所有的数据,相当于一个epoch更新一次显示数据。
(2)开启一个for循环,枚举数据啦,从dataloader中按照batch-size循环读取数据,循环条件是把数据取完,tqdm模块是进度条模块具体可以百度。
(3)然后调用array_tool.py文件(在utils文件夹下)中的scalar()函数,传入参数是scale,这个参数是什么意思
(4)把数据传到cuda中,用来加速计算,返回转换后的cuda版本的数据,下面调用的都是cuda版的。
(5)利用trainer.train_step函数进行计算,前面介绍过这个函数,是用来更新一次权重的。
(6)图像进行归一化处理.
(7)然后是显示
(8)预测bboxes,label,这个predict函数是哪来的呢,首先是trainer,而trainer中调用的网络是fasterrcnnvgg16,这个网络继承 的是fasterrcnn的类,fasterrcnn类中有一个predict函数。
(9)下面就是一些可视化操作了,然后跳出了枚举数据的循环。这就是1个epoch完成了
(10)模型评估,利用测试集来做,前面已经加载了测试集,test_dataloader
(11)得到优化器中学习率的数值,并显示日志相关内容,包括lr,map,loss
(12)根据评测结果判断map是否是大于阈值best_map,如果是保存模型
(13)判断当前的epoch是否=9,如果是就加载最好的map和改变学习率
(14)判断如果epoch=13就跳出迭代循环???这个是这个实验里设计的具体原因不清楚。待定