numpy、os、torch、cv2
该部分的主要工作是完成数据的预处理、训练集测试集的划分以及数据集的读取,即得到train_dataloader、test_dataloader;
首先是数据的预处理部分,由于FCN不限制输入图片的尺寸大小,所以预处理部分较为精简,只需要转换为tensor格式,再进行一个标准化即可
之后便需要定义一个onehot函数,用于对标签图的处理,其作用是将标签转化为一维向量
先定义一个类,用于数据集的读取,其中,用于获取处理后的图片与标签的函数定义如下:
该函数在对象实例化后会自动调用,其中定义了原图与标签的打开、缩放、打开方式(指定标签以灰度图方式打开)、归一化、onehot化、调整长宽以及深度的位置、原图的transfrom调用
再对该类进行实例化,并对数据集进行训练与测试的划分,然后再使用DataLoader对数据集进行载入,得到处理好的数据集迭代对象,用于后续训练过程
此次编写的FCN采用的前置网络为VGG16,关于make_layer函数的编写只涉及全连接层以前,与VGG的写法完全一致,这里不多赘述,下面说一下vggnet类中有变化的地方:
首先定义参数:
pretrained——决定是否预训练, model——即模型名称,这里选用vgg16, requires_grad——是否保存梯度信息,决定了是否能进行反向传播, remove_fc——用于移除全连接层, show_params——用于展示网络层名与参数
首先通过继承、调用make_layer函数得到vgg前面的卷积池化层
super().__init__(make_layers(cfg[model]))
预训练和保存梯度信息的代码如下,这里选择的方案为只导入网络结构,不导入参数,所以pretrained设置为False,同时,给出不保存梯度信息的选择
去除全连接层十分简单,如下:
接下来定义正向传播过程,大致思路与编写vgg时一致,但需要考虑到后续再FCN类中的使用,于是在每一个池化层结束后中断一次,得到一个output,由于vgg中有五个池化层,所以最后output字典中会有五个output结构,分别为output[x1]到output[x5]
对于上代码中提到的ranges,解释如下:
ranges定义如下,是一个自定义的字典,主要是便于对各层网络的遍历及调用,在每个池化层结束后分割
分割依据为vgg网络结构,元组中所盛放数字的意义如下图(以vgg16参数为例),序号一一对应网络层结构
而关于上面提到的show_params参数,涉及一个.named_parameters()方法,该方法可返回网络层名与其中的具体参数,示例如下(并非本网络结构中代码,只是作为示例介绍方法):
下面是网络层名:
下面是具体参数:
其初始化函数定义如下:
其中几个deconv层是按照网络结构直接写出,bn层为对feature map的标准化处理,在resnet中有涉及,此处不赘述,需要注意的是最后一个classifier层,该层采用1*1卷积,目的为不改变长宽,而改变通道数为给出的类别数
接下来是正向传播层,直接对上面返回的output进行调用,再一次一次经过上采样,并与池化后的x相加再重复该过程(该方法可提高准确度,具体原因在另一篇博文里有提及,此处不赘述),最后经过一个classifier层,得到最终输出
首先进行模型实例化、设置device用于gpu加速,这些常规写法不做赘述;
接着是损失函数和优化器的选择,这里选择的是比较适用于二分类的问题的损失函数BCELoss,优化器选择SGD,使用准备好的参数组,实现随机梯度下降
接下来便可以进行训练过程,首先定义epoch个数,再遍历处理好的数据并载入,再进行梯度置0,然后将读取的数据丢进网络得到输出,通过torch.sigmoid()得到0~1之间的输出,再用损失函数计算loss,并反向传播,再使用优化器,具体代码如下:
其中原图片与标签是由上述遍历过程所得到的一组组对应的图片与标签
对于测试集的处理,过程基本一致,只需在最开始停止计算梯度即可
最后在每个epoch中计算出每次的loss值并print,再每5次保存一个权重文件
首先进行的数据预处理与前文相一致,再用权重文件实例化model,上述过程与CNN基本一致,此处不细讲;
但在将待预测图片放进网络之前,需对其进行一个 .unsqueeze(0) 操作,即在其第一个维度前再增加一个维度,其维度数为1;其目的是为了能顺利放进网络中,该1的实质是batchsize;
放进网络后,与train.py的处理相同,归一化之后便有所不同,代码如下:
第一行是停止对其进行梯度运算,再转换为numpy格式的副本;第二行是按行求含最小值的下标,目的是去掉对应了类别数2的那个维度,将输出由4维变为3维
最后利用np.squeeze()将最终输出的维度降为2维,并以灰度图形式读取,具体代码如下: