基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现

之前读过这篇论文,导师说要复现,这里记录一下。废话不多说,再重读一下论文。
注:非一字一句翻译。个人理解,一定偏颇。

基于深度强化学习的车道检测和定位

官方源码下载:https://github.com/tuzixini/DQLL
论文原文:https://www.sciencedirect.com/science/article/pii/S0925231220310833
如需引用:
@article{zhao2020deep,
title={Deep reinforcement learning based lane detection and localization},
author={Zhao, Zhiyuan and Wang, Qi and Li, Xuelong},
journal={Neurocomputing},
volume={413},
number={6},
pages={328-338},
doi={10.1016/j.neucom.2020.06.094},
year={2020}
}

摘要

基于深度学习的车道检测方法只检测带有粗略边框的车道线,而忽略了特定曲线车道的形状。针对上述问题,本文将深度强化学习引入粗车道检测模型中,以实现精确的车道检测和定位。该模型由**边界盒探测器(bounding box detector)和地标点定位器(landmark point localizer)**两个阶段组成。边界盒级卷积神经网络车道检测器以边界盒的形式输出车道的初始位置。然后,基于强化学习的深度Q-Learning定位器(Deep Q-Learning Localizer,DQLL)将车道作为一组地标进行精确定位,以更好地表征曲线车道。构造并发布了一个像素级车道检测数据集NWPU车道数据集。它包含了各种真实的交通场景和精确的车道线遮罩。该方法在发布数据集和存储数据集上都取得了较好的性能。

1 引言

避免事故发生和引导车辆沿着适当的车道行驶是辅助系统的两项基本任务,实现这两个目标的几个技术手段:车道检测,道路检测,前方车辆碰撞预警,交通标志检测,交通拥堵检测,道路标记检测。车道检测在上述任务和其他高级驾驶辅助目标中有着不可替代的作用,如可行驶区域检测和自动泊车。图1显示了表示车道的不同的方法,包括直线、边框、地标和像素掩码。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第1张图片
在深度卷积神经网络(Deep Convolutional Neural Networks, DCNN)广泛研究和应用之前,很多工作都是使用低级特征提取器来检测车道线,并使用多条直线来表示车道。直线在直线车道上很好,但在曲线车道上就不行了。为了解决曲线车道的表示问题,在车道检测中引入边界盒和像素级掩码。但是,边界盒的精度不够高,像素级掩模的预测需要复杂的计算。
为了解决上述问题,我们提出了一种基于深度强化学习的车道检测和定位网络。它由深度卷积巷边界盒检测器(deep convolutional lane bounding box detector)和深度q学习定位器(Deep Q-Learning localizer)组成。所提网络的结构示意图如图2所示。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第2张图片
它是一个两阶段的顺序处理架构。具体来说,第一阶段是一个改进的Faster R-CNN[27],它以包围盒的形式检测道。第二阶段为轻量化深度Q-learning地标定位器,由五层卷积层和三层全连接层组成。
在检测阶段得到边界盒后,初始地标沿边界盒对角线均匀分布。然后车道定位任务变成了一个点移动博弈游戏。车道定位器在游戏中扮演Agent的角色。agent需要做的是根据当前环境状态将地标向特定方向移动当前环境状态包括当前点位置、动作历史向量和已编码的图像特征。最后,当agent决定不再移动路标点时,将所有路标点的位置输出为车道的位置。
为了验证所提方法的有效性,我们建立了一个名为NWPU Lanes dataset的像素级车道数据集,该数据集包含1964个交通场景图像,并带有标记良好的像素级车道遮罩。
contributions:

  • 定义了一种新的车道检测和定位表示方法,达到了精度和计算量的平衡。
  • 深度Q-Learning车道定位器(DQLL)将车路定位为一组地标,对曲线车道进行了较好的表征。
  • 构建一个像素级车道数据集NWPU车道数据集,其中包含精心标注的城市图像,有助于发展交通场景的理解。

2 相关工作

2.1 传统车道线检测方法

  • 为车道线构造易于识别的特征,根据其相同的特征手工设计特征表示。常用的特征提取器如Hough变换[21]和Dark-Light-Dark
    (DLD)[22]在简单的条件下是有效的,但是在复杂的场景下性能会迅速下降。它们对噪音的敏感导致了这个问题。
  • 逆透视映射(IPM)[28]将原始图像转换为鸟瞰图。然后使用上面的特性提取器在这个视图下生成特性。视图转换有助于减少冗余信息并增强目标表示。但在非常复杂的情况下,其效果明显下降。其根本原因是DLD、霍夫变换等低级提取器提取的特征不够强大。

2.2 基于深度学习的车道线检测

DCNN可以从输入图像中生成具有足够高水平语义信息的特征。此外,它的自动拟合特性节省了大量的特征设计工作。

2.3 强化学习

Mnih等[39]将Q-Learning与deep Q-Learning Network (DQN)中的深度学习方法相结合,即使用神经网络代替Q-table。

3 方法论

3.1 概述

本方法由检测和定位部分组成。检测部分的目的是获得车道的初步边界盒位置。为了配合下一阶段的定位过程,我们仔细考虑了车道的特点,通过观察如图3所示的各种车道,我们总结如下:

  • 边界框框出的车道线总是靠近矩形对角线
  • 左上到右下:视野的左边;左下到右上:视野的右边 (我总感觉这里写反了的样子。。)*
  • 边界框的对角线大致可以用来表示直线的位置。对于弯曲车道,由于车道形状的巨大差异,它失败了。
    车道所经过的对角线将是确定路标点初始位置的关键因素。

基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第3张图片

3.2 车道线检测

首先在车道数据集上对改进后的公共目标检测器进行再训练,并将其用于获得车道的初始位置。从技术上讲,几乎所有典型的对象检测器,如[40-43]都可以在这里使用。**第一阶段我们采用Faster R-CNN作为基线车道边界盒检测器。**通过检测车道坡度,将车道划分为不同类型,并将车道类型与3.1节讨论的内容统一起来。检测阶段完整工作流程如图4所示。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第4张图片
Faster R-CNN使用CNN完成proposal生成、回归和分类。一个输入图像在网络中只传播一次,提高了网络的效率。
用VGG作为CNN的骨架。它从输入的三帧RGB图像中提取卷积特征映射。然后将整个图像特征向量和图像信息发送到RPN,生成区域建议。ROI pooling层有助于将区域建议的特征向量强制为固定大小。建议回归网络和建议分类网络分别使用多个全连通层来得到边界框偏差和分类概率。网络的详细结构如图4所示,其中Conv表示卷积层,Dense表示完全连通层。
车道检测阶段的最终输出为输入图像内所有车道的边界框位置和车道类型。

3.3 车道线定位

我们使用五个地标点来准确定位车道。地标在边界框内统一初始化。这样,定位阶段就变成了一个点移动博弈游戏,目标是将所有的地标移动到正确的位置。应用一种基于强化学习的深度Q-learning车道定位器来进行游戏。与边界框相比,地标有效地提高了曲面车道的表示能力,提供了更精确的位置信息。

3.3.1 游戏定义

如图3所示,经检测阶段的每个包围盒与盒子在水平方向上通过5条截止线分割成6个相等的区域。车道线所沿的对角线与这五条分割线在几个点相交。这些点被用作地标点的初始位置。
我们尝试通过深度强化学习方法来解决点定位博弈问题。这里使用的学习策略是Q-Learning[38]方法。在原来的Q表中,它对每个不同的环境状态进行了重新编码,哪个行为选择会导致最高的回报。初始Q表给出随机的行动决策,它根据以下公式随训练过程更新:
在这里插入图片描述
(关于公式的理解和别的地方一样,这里不再赘述。)
除了Q表之外,环境状态、行动选项和奖励功能共同构成了深度Q学习的过程。下面的小节将详细介绍这三个关键组件。

3.3.2 环境状态

环境状态包含了影响行动决策结果的因素。对于这个移动点游戏,当前选择的地标点的位置信息,以及图像块都有助于找到正确的位置。我们还考虑了之前已经做过的动作,我们称之为动作历史向量。(就是引言里说过的三部分组成:当前点位置、动作历史向量和已编码的图像特征)
在这里插入图片描述
S是当下环境状态,等式右侧第一项是已编码的图像特征(Ib是边界盒框起来的部分),第二项是当前点位置,第三项是动作历史向量。中间的符号表示concatenate操作。
图5表示了环境状态的组成。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第5张图片

3.3.3 动作空间

地标点的纵坐标是一个固定的值,所以它只能水平移动。我们人为地定义了三种可选的操作类型。

  • delete action:agent决定删除当前点或采取其他操作。偏离范围或距离实际车道位置太远的点可能被删除。
  • moving action:对正常范围下的地标点,agent将点移向正确的方向,这些点沿水平线有两个移动方向,因此移动动作包含向左或向右的运动。
  • terminal action:当点与期望位置足够接近时,agent必须判断当前位置是否为最终位置。终端动作决定截断点移动过程或进入下一个动作选择。

所有的动作选择以及相应的实际像素级点移动如表1所示。其中x表示当前地标点的位置。
表1

3.3.4 奖励函数

我们根据行动选择所导致的结果将其分为三种类型。

  • Invalid Action Choices:动作a将地标点移出了适当的图像范围,删除了应该保留的点或保留了应该删除的点。

在这里插入图片描述

  • Regular Action Choices:如果这个动作选择不是前面提到的无效的,而是一个移动的动作,我们称这个选择为常规的动作选择。我们定义当前点位置之间的距离和环境状态下点的真实位置为d(s)。
    d(s’)新距离
    在这里插入图片描述

点离真实距离更近得一分,否则扣一分。

  • Terminal Action Choices:agent做出终止点移动过程的决策。当当前点和地面真实位置足够接近时,agent通过terminal action choices获得正分数。否则,如果代理在不合适的时间停止移动过程,则会得到负的分数。

在这里插入图片描述

3.3.5 概述

完整的车道定位工作流程如图6所示。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第6张图片
首先将检测到的所有边界框从完整图像中截断,然后将其调整为统一的尺寸**[100, 100, 3]**,然后再送到Deep Q-Learning Localizer(DQLL)网络中。定位器根据边界盒的类型初始化5个地标点,即5个地标的水平和垂直坐标均匀分布在0 ~ 100之间。此外,这五个点对于不同的车道可能沿着不同的斜线。初始化完成后,分别对5个地标点进行定位。
图6显示了右侧的决策网络结构。具体的网络架构如表2所示。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第7张图片
“Conv1: (k(3,3),c(3->48),s(1),NoPadding)”:意味着卷积层名为Conv1使用48个卷积核的大小3x3x3与strides=2 (这里我认为该是1)和nopadding。
‘‘FC: 5393 -> 512” :全连接层,输入尺寸为5393,输出尺寸为512。
动作历史向量的长度影响第一完全连接层的输入端形状。这里我们使用四个过去的动作来形成动作历史向量。因此,第一个完全连接层的输入大小为1 x 21 x 256 + 1 + 4 x 4 = 5393。
1 x 21 x 256:特征编码器的输出
1 + 4 x 4:四个步长的动作历史向量
(我认为1是当前点位置)

损失函数:MSE
在这里插入图片描述

4 数据集

本文提出具有像素级别标签的NWPU数据集,它来自于真实驾驶场景下录制的视频。
由于人工采集和标注真实场景数据比较困难,我们从数据集[47]中选择一些虚拟数据来辅助我们的训练过程。这些图像和标签掩码都是由带有精确注释的软件生成的合成数据。

4.1 NWPU车道线数据集

汽车视频数据记录器收集了13个真实驾驶场景的视频。其中12个片段是3分钟长,剩下的一个是1分38秒长。经过每秒1帧的采样,总共得到2258张初始图像。通过对这些图片的观察,我们发现在实际驾驶过程中,由于车辆停车、拥堵、遮挡等问题,仍然有很多图片无法用于初步样本。所以我们又手动删除了一张图片,删除后保留了1964张图片。接下来,我们使用自己开发的打标工具进行准确的像素级线打标。标记完成后,可以通过像素级标记生成包围盒。
图7左侧的部分(A)展示了来自NWPU lane数据集的驾驶场景及其相应的掩码。最终得到的数据分为训练集和测试集,其中测试集占20%。数据分布如表3所示。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第8张图片

4.2 合成数据集

手工采集的数据在标注工作中会不可避免的出现误差。在车道线定位的过程中,输入图像是一个小尺寸的图像,只包含从原始图像中截断的车道线区域。因此,在定位过程中,原始图像级的误差可能会放大。与人工标注相比,软件生成的虚拟数据具有完全准确的像素级标注。因此,我们不仅使用自己构建的真实数据集,还从其他数据集中选择适当场景下的虚拟数据进行训练和测试。SYNTHIA数据集包含大量生成的数据,这些数据是在不同的场景、时间、季节和天气设置中构建的。我们手动选择一些接近NWPU lane数据集的场景。这些虚拟场景数据的加入有效地促进了DQLL的学习过程。我们还将这些数据拆分为列车集和测试集,拆分规则与自构建数据集一致,详细信息如表3所示。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第9张图片
(表格两行好像搞反了)

5 实验和讨论

5.1 评价指标

定义了两个评价指标,这两个指标的定义与第3节车道线定位的具体实现密切相关。这两个指标都是在一个完整的测试过程中定义的。

  • Hit Rate α:反映定位精度。‘‘Hit Point”:最终终止点与ground truth位置点之间的距离小于5像素。

在这里插入图片描述
分子:在特定测试期间命中的总点数
分母N:所有地标点的个数,也等于5倍的边界盒数
0<α<1。愈高愈好

  • Average Step:反应定位速度。

在这里插入图片描述
S:本测试期内所有地标点的行动步骤总数,
N:同上式。

5.2 实验设置

NVIDIA GTX 1080Ti GPU
Inter Core [email protected] GHz CPU
Ubuntu 14.10
TensorFlow [48] or PyTorch
两个阶段依赖于相当独立的实验设置。
边界盒检测阶段:在基于像素级lane数据集的基础上,采用简单的连通分量检测算法生成检测阶段所需的边界盒。Faster R-CNN在训练过程中,批大小为1,边界盒分类的批大小为300。学习率和权值衰减分别设置为0.001和0.0005。
定位阶段:在DQLL的训练过程中,需要与检测阶段不同的地面真实数据。我们进一步对第一步的边界框数据进行处理。首先,对每个矩形盒分别从原始图像中截断;然后,每个盒子被5个分割器分成6个等面积的水平矩形。分隔线与车道线相交于一小段。**最后,将短截面的中点作为对应地标点的ground truth。**直到此时,训练DQLL的数据才准备好。定位阶段的学习率设置为0.0001,批量大小为1024进行训练。

5.3 实验结果

验证DQLL的效果,在NWPU 和 TuSimple 数据集分别进行验证。

  • 于NWPU数据集:表4给出了DQLL与检测、分割等车道检测方法相结合的测试结果。如我们所料,基于分割的初始精度高于基于检测的初始精度。

基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第10张图片

  • 于TuSimple数据集:3626张用于训练,2782张用于测试。我们将标记的数据转换为DQLL训练和测试所需要的形式。表5显示了不同方法的结果。

基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第11张图片
TuSimple Lane数据集的总命中率高于NWPU Lane数据集。这意味着前者有更多的直线。实验结果表明,平均步数随着第一阶段初始检测效果的提高而减少。因为更可靠的检测结果可以在初始化地标点时提供更好的指导。
表6
从表4-6所示的实验结果可以看出,无论使用何种实验设置,DQLL都可以在四步之内完成一个地标点的定位过程,对于NWPU Lanes数据集,平均只需要三步,对于TuSimple Lane数据集,平均只需要不到两步。显然,动作历史向量的长度会影响命中率和平均步长。选择合适的长度变得很重要,我们在5.4小节中进行具体分析。

  • DQLL 可视化
    图8展示了DQLL定位过程的可视化。所有子图由五行组成。每一行对应于一个地标点的定位过程。蓝色圆圈是地面真相地标的置信区间,绿色的点是当前移动的地标点。因此,对于每一个子图,左上角是初始点位置,右下角是定位过程的最终输出。子图(A)和©展示了正常曲线车道线的细节定位过程。与初始地标位置相比,DQLL有效地移动了地标以更好地拟合曲线。子图(B)是一条直线,但偏离了边界盒的中心,因此去掉了两个地标点,只剩下三个点。第四个子图给出了一个反面的例子,五个标志性点的初始位置在预期范围内,但是DQLL错误地将其移出了预期范围。
    基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第12张图片
    基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第13张图片

基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第14张图片
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第15张图片
在这里插入图片描述

5.4 动作历史向量的影响

动作历史向量的长度设置为范围为[0, 10],步长为2。所有实验结果如表6所示。
与没有动作历史向量的DQLL相比,不管历史向量有多长,有了这个向量的DQLL都提高了定位精度,减少了平均需要的步长。这说明过去的动作选择在一定程度上有助于当前状态下的决策。然而,并不是越长越好。历史向量长度为4时,命中率最好;历史向量长度为6时,平均步长最好。这两个评估指标都倾向于随着动作历史向量长度的增加而从增加变为减少。因此,选择适当长度的历史步骤是至关重要的。

5.5 与监督学习方法对比

我们手工设计了几种深度监督学习(DSL)方法,这些模型具有与DQLL完全相同的网络结构。唯一的区别是这些DSL方法尝试直接回归拟合五个地标点的位置,并且它们的训练损失函数不同。在此,我们分别使用MSE、L1和 smooth L1损失函数来训练DSL模型。实验结果如表7所示。图9说明了所比较的DSL模型的网络架构。
表7
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第16张图片
比较结果表明,学习如何移动地标点到正确的位置比直接返回地标点的位置更有效。

6 结论

在未来,我们将尝试利用更多的先验知识来提高检测性能。

代码解读

$CODEROOT:放置此代码的路径。
$DATAROOT:放置数据集的路径。

准备

Python 3.x
Pytorch 1.x

确保你的代码目录如下:
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第17张图片
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第18张图片

下载图森数据集

我们需要的是来自“LANE DETECTION CHALLENGE”的数据。
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第19张图片
确保你的**$DATAROOT**目录如下:
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第20张图片
使用genMyData.pygenMeanImg.py生成新数据。

  1. 修改变量DATAROOT(文件getMyData.py , 210行)。到您使用的实际的$DATAPATH(这里是r"/opt/disk/zzy/dataset/TuSimple")。
DATAROOT = r"/opt/disk/zzy/dataset/TuSimple"
  1. 转到$CODEPATH并运行genMyData.py,等待完成。
cd $CODEPATH
python genMyData.py

genMyData.py

# coding=utf-8
# [email protected]
# WIN10 Python3.6.6
# 用途: 处理TuSimpleLane 数据集
# 生成需要的数据
# genMyData.py

# 导入模块
import json
import os
import os.path as osp
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import shutil
from tqdm import tqdm

faileList = []

def doit(sroot, jlistpath, droot,namelist=[],DRL_list=[],DRLcount=0):
    os.makedirs(droot, exist_ok=True) # 创建目标目录
    os.makedirs(osp.join(droot, 'img'), exist_ok=True)
    os.makedirs(osp.join(droot, 'mask'), exist_ok=True)
    os.makedirs(osp.join(droot, 'mask_color'), exist_ok=True)
    os.makedirs(osp.join(droot, 'bbox'), exist_ok=True)
    os.makedirs(osp.join(droot, 'DRL'), exist_ok=True)
    os.makedirs(osp.join(droot, 'DRL','ori'), exist_ok=True)
    os.makedirs(osp.join(droot, 'DRL', 'resize'), exist_ok=True)
    jlist =[]
    with open(jlistpath, 'r') as f: # 把Json文件按行读入
        for line in f.readlines():
            jlist.append(json.loads(line)) # 放到jlist中(第一步共有358行)
    for ins in tqdm(jlist):
        faileIns = dict()
        imgpath = ins['raw_file']
        temp = imgpath.split('/')
        newname = temp[1] + temp[2] + temp[3][:-4]
        namelist.append(newname)
        spath = osp.join(sroot, imgpath) # 源目标文件路径
        # 计算
        try:
            temp = getInfo(spath, ins)
        except:
            faileIns['sroot'] = sroot
            faileIns['jlistpath'] = jlistpath
            faileIns['ins'] = ins
            faileList.append(faileIns) # 若程序运行不正常,faileList存储报错次数(根据笔者半天时间亲测:运行的第一个文件中第281行存在错误。https://blog.csdn.net/songyuc/article/details/109769131   存在一个边界框,对y=91时没有车道线,所以会报错。其他文件也存在一些报错,不过不是很多,不影响运行
            continue
        mask, mask_color, bbox, box, box_mask, gt, rebox, rebox_mask, regt = temp
        if len(box) > 5:
            faileIns['sroot'] = sroot
            faileIns['jlistpath'] = jlistpath
            faileIns['ins'] = ins
            faileIns['Reason']="cot>5"
            faileList.append(faileIns) # 若box(原检测出图片中车道线数量>5(存在误检,实际上不一定大于5,则记录他们源目录、json文件目录、边界框点集、原因)
            continue
        #  copyimg
        dpath = osp.join(droot, 'img', newname + '.jpg')
        shutil.copy(spath, dpath) # 将spath的文件复制到dpath。详见https://www.cnblogs.com/liuqi-beijing/p/6228561.html
        # mask
        dpath = osp.join(droot, 'mask', newname + '.png') 
        mask = Image.fromarray(mask.astype('uint8'))# 实现array到image的转换。详见https://blog.csdn.net/weixin_39450145/article/details/103874310
        mask.save(dpath) # 生成图片并保存(全是黑色)【有车道线地方为1,无车道线地方定义为0】
        # mask_color
        dpath = osp.join(droot, 'mask_color', newname + '.png')
        mask_color = Image.fromarray(mask_color.astype('uint8'))
        mask_color.save(dpath) # 车道线为255,无车道线为1【车道线为白色,无车道线为黑色】
        # bbox
        dpath = osp.join(droot, 'bbox', newname + '.json')
        with open(dpath,'w') as f:
            json.dump(bbox,f) # 注:bbox是四个点的形式存在。有时可能会将一条车道线检测为两条,不过不影响后续定位。
        # box 裁剪出来的图片
        for i in range(len(box)):
            # 获取裁剪出来图片的名称
            DRLname = newname + '_' + str(DRLcount)
            DRL_list.append(DRLname)
            # box
            temp = box[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot,'DRL','ori',DRLname+'.png')
            temp.save(dpath) # 依次检测原图片中的每一条车道线
            # boxmask
            temp = box_mask[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'ori', DRLname+'_mask.png')
            temp.save(dpath) # 掩码,基本黑色(有车道线地方为1,无车道线地方定义为0)
            # boxmask_color
            boxmask = box_mask[i]
            temp = np.zeros(boxmask.shape)
            temp[boxmask == 1] = 255
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'ori', DRLname + '_mask_color.png')
            temp.save(dpath) # 掩码 车道线为255,无车道线为1【车道线为白色,无车道线为黑色】
            # gt
            dpath = osp.join(droot,'DRL','ori',DRLname+'.json')
            with open(dpath, 'w') as f:
                json.dump(gt[i], f) # 真实车道线类别和坐标
            # rebox
            temp = rebox[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'resize', DRLname+'.png')
            temp.save(dpath) # 生成缩放后的边界框图片(从原图像截取后再缩放) 100x100
            # reboxmask
            temp = rebox_mask[i]
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'resize', DRLname+'_mask.png')
            temp.save(dpath) # 掩码,基本黑色(有车道线地方为1,无车道线地方定义为0) 100x100
            # reboxmask_color
            boxmask = rebox_mask[i]
            temp = np.zeros(boxmask.shape)
            temp[boxmask == 1] = 255 # 将rebox_mask为1的像素点置为255
            temp = Image.fromarray(temp.astype('uint8'))
            dpath = osp.join(droot, 'DRL', 'resize',DRLname + '_mask_color.png')
            temp.save(dpath) # 车道线为255,无车道线为1【车道线为白色,无车道线为黑色】 100x100
            # regt
            dpath = osp.join(droot, 'DRL', 'resize', DRLname+'.json')
            with open(dpath, 'w') as f:
                json.dump(regt[i], f) # 经缩放后真实地标点坐标(re ground truth)
            DRLcount += 1
    return namelist,DRL_list,DRLcount # namelist:文件名称列表(图片数量),DRL_list:所有图片的车道线(可能重复),DRLcount:对DRL_list进行计数。


def getInfo(ipath,data): # 传入图片和json文件中的对应行
    img = Image.open(ipath)
    img = np.array(img)
    img_t = np.zeros(img.shape)
    mask = np.zeros((img.shape[0], img.shape[1]))
    mask_color = np.zeros((img.shape[0], img.shape[1]))
    gt_lanes_vis = [[(x, y) for (x, y) in zip(lane, data['h_samples'])if x >= 0] for lane in data['lanes']] # 存储每条车道的坐标信息

# 依每条车道的坐标画图线(如图7的黑白图片)
    for lane in gt_lanes_vis:
        cv2.polylines(img_t, np.int32(
            [lane]), isClosed=False, color=(0, 255, 0), thickness=5) # img_t存贮车道线坐标(拟合一条线)
    mask_color[img_t[:, :, 1] == 255] = 255 # 生成掩码(有车道线的地方置为255,没有则是0)
    mask[img_t[:, :, 1] == 255] = 1  # 生成掩码(有车道线的地方置为1,没有则是0)
    # 计算bbox
    temp = Image.fromarray(mask.astype('uint8')) # 把mask转换为图像格式
    temp = np.array(temp) # 图像 到 矩阵
    bbox = []
    box = []
    box_mask = []
    gt = []
    rebox = []
    rebox_mask = []
    regt = []
    temp, cons, hier = cv2.findContours(temp, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) # 需注意不同版本opencv返回值的不同。**https://www.cnblogs.com/guobin-/p/10842486.html**    博主在这里卡了一天,汇报又要挨骂了。。。OpenCV2和OpenCV4中是两个返回值,OpenCV3中有三个返回值
    for con in cons: # 对轮廓的点集进行遍历
        x, y, w, h = cv2.boundingRect(con) # 边界框
        # 截取原图
        box.append(img[y:y+h, x:x+w, :].astype('uint8'))
        # 截取mask
        temp = np.zeros(img.shape) # temp是掩码
        cv2.drawContours(temp, [con], 0, (0, 0, 255), -1) # 在temp上画上边界框
        ttemp = np.zeros((img.shape[0], img.shape[1])) 
        ttemp[temp[:, :, 2] == 255] = 1# temp有边界框地方置为1,其他地方为0
        box_mask.append(ttemp[y:y+h, x:x+w].astype('uint8')) # 截取边界框区域
        # 计算bbox [{'class':cl,'points',[x1,y1,x2,y2]},{}]
        temp = ttemp[y:y+h, x:x+w]
        [vx, vy, xx, yy] = cv2.fitLine(con, cv2.DIST_L2, 0, 0.01, 0.01) # https://blog.csdn.net/lovetaozibaby/article/details/99482973         前两维代表拟合出的直线的方向,后两位代表直线上的一点。(即通常说的点斜式直线)         (vx,vy)是与直线共线的标准化向量(x,y)是线上的一个点
        slope = -float(vy)/float(vx)
        if slope <= 0:
            # 左上 到右下 (右侧)
            cl = 0
        else:
            # 左下到右上(左侧)
            cl = 1
        ttemp = dict()
        ttemp['points'] = [int(x), int(y), int(x+w), int(y+h)]
        ttemp['class'] = cl
        bbox.append(ttemp) # bbox存贮边界盒
        # 计算gt[{'class':cl,'gt':[x1,x2,x3,x4,x5]},{}]
        ttemp = dict()
        ttemp['class'] = cl
        initY = []
        for i in range(5):
            initY.append(int((i+1)*(h/6))) # 把五个纵坐标找到
        initX = []
        for y in initY: # ground truth的定义见文章5.2节
            xx = temp[y, :]
            xx = np.where(xx == 1)
            x = int((np.max(xx)+np.min(xx))/2)
            initX.append(x)
        ttemp['gt'] = initX
        gt.append(ttemp) # gt存储真实五个点X信息
    # 生成resize的DRL材料**(我认为是把图片给缩放成100x100尺寸【rebox】,掩码也缩放【rebox_mask】,也把ground truth点缩放【regt】)**
    for i in range(len(box)):
        temp = box[i].copy() # box存贮经截取了的图片(一张图片中几条车道线几个box)
        temp = cv2.resize(temp, (100, 100))
        rebox.append(temp)

        temp = box_mask[i].copy()
        temp = cv2.resize(temp, (100, 100))
        rebox_mask.append(temp)
        # pdb.set_trace()

        ttemp = dict()
        ttemp['class'] = gt[i]['class']
        initY = [11, 31, 51, 71, 91]
        initX = []
        for y in initY:
            xx = temp[y, :]
            xx = np.where(xx == 1)
            x = int((np.max(xx)+np.min(xx))/2)
            initX.append(x)
        ttemp['gt'] = initX
        regt.append(ttemp)
    result = [mask, mask_color, bbox, box, box_mask, gt, rebox, rebox_mask, regt]
    return result # 最终的返回值:mask(黑色掩码【背景】),mask_color(白色掩码【车道线】),bbox(边界盒【四个点坐标+类别),box(经截取了的图片),box_mask(截取了的图片的掩码,黑色背景),gt(真实地标点坐标),rebox(经缩放的边界盒),rebox_mask(经缩放的黑色背景掩码),regt(经缩放的真实地标点坐标)

DATAROOT = r"/home/wqf/tusimple" # 你自己下载的tusimple数据集目录所在

# 测试数据集
sroot = 'train_set' 
jlistpath = r'train_set/label_data_0531.json'
droot = 'MyTuSimpleLane/train'
sroot = osp.join(DATAROOT, sroot) # 源目录
jlistpath = osp.join(DATAROOT, jlistpath) # json文件目录
droot = osp.join(DATAROOT,droot) # 目标目录
namelist, DRL_list,DRLcount = doit(sroot, jlistpath, droot, namelist=[], DRL_list=[], DRLcount=0)
print(len(faileList))

sroot = 'train_set'
jlistpath = r'train_set/label_data_0313.json'
droot = 'MyTuSimpleLane/train'
sroot = osp.join(DATAROOT, sroot)
jlistpath = osp.join(DATAROOT, jlistpath)
droot = osp.join(DATAROOT,droot)
namelist, DRL_list, DRLcount = doit(sroot, jlistpath, droot, namelist=namelist, DRL_list=DRL_list, DRLcount=DRLcount)
print(len(faileList))

sroot = 'train_set'
jlistpath = r'train_set/label_data_0601.json'
droot = 'MyTuSimpleLane/train'
sroot = osp.join(DATAROOT, sroot)
jlistpath = osp.join(DATAROOT, jlistpath)
droot = osp.join(DATAROOT,droot)
namelist, DRL_list, DRLcount = doit(sroot, jlistpath, droot, namelist=namelist, DRL_list=DRL_list, DRLcount=DRLcount)
print(len(faileList))
with open(osp.join(DATAROOT,'train_img_list.json'), 'w') as f:
    json.dump(namelist, f) # 文件名称列表(图片数量)我这里是3626张

with open(osp.join(DATAROOT,'train_DRL_list.json'), 'w') as f:
    json.dump(DRL_list,f)  # 所有图片的车道线(可能重复)我这里是13704张

# 训练数据集
sroot = 'test_set'
jlistpath = r'test_label.json'
droot = 'MyTuSimpleLane/test'
sroot = osp.join(DATAROOT, sroot)
jlistpath = osp.join(DATAROOT, jlistpath)
droot = osp.join(DATAROOT,droot)
namelist, DRL_list, DRLcount = doit(sroot, jlistpath, droot, namelist=[], DRL_list=[], DRLcount=0)
print(len(faileList))
with open(osp.join(DATAROOT,'test_img_list.json'), 'w') as f:
    json.dump(namelist, f)

with open(osp.join(DATAROOT,'test_DRL_list.json'), 'w') as f:
    json.dump(DRL_list,f)

# 存储所有无效的数据信息
with open(osp.join(DATAROOT,'failList.json'), 'w') as f:
    json.dump(faileList, f)

注:上述方法保存的文件都是一行

  1. 修改变量DATAROOT(文件genMeanImg.py,第22行)到您使用的实际$DATAPATH(这里是r“/opt/disk/zzy/dataset/TuSimple”)。
DATAROOT = r"/opt/disk/zzy/dataset/TuSimple"
  1. 转到$CODEPATH并运行genMeanImg.py,等待其完成。
cd $CODEPATH
python genMeanImg.py

genMeanImg.py

# coding=utf-8
# [email protected]
# WIN10 Python3.6.6
# 计算 裁剪 resize 之后的TuSimple数据的meanImage
import os.path as osp
import json
from PIL import Image
import numpy as np
from tqdm import tqdm


def genMeanImg(jsonListPath, root): # 传入json文件和路径
    with open(jsonListPath, 'r') as f:
        jsonList = json.load(f) # 13704个list
    meanImg = np.array(Image.open(osp.join(root,jsonList[0]+'.png')))
    for name in tqdm(jsonList):
        temp = osp.join(root, name + '.png')
        img = np.array(Image.open(temp))
        meanImg = (meanImg + img) / 2
    return meanImg

DATAROOT = r"/opt/disk/zzy/dataset/TuSimple"

jsonListPath = osp.join(DATAROOT,'train_DRL_list.json')
root = osp.join(DATAROOT,'MyTuSimpleLane/train/DRL/resize/')
meanImg = genMeanImg(jsonListPath, root)
print(meanImg.shape)
print(meanImg)
savePath = osp.join(DATAROOT,'meanImgTemp.npy')
np.save(savePath,meanImg)
 
 # 将100x100的所有图像加在一起取均值,是为了后面的归一化吧
  • 检查一下。完成以上所有操作后,$DATAPATH的文件夹树应该如下所示:
    基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第21张图片
    命令行输入 tree -L 3确实如此。
    基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第22张图片

训练模型

修改config.py文件

  • 修改__C.DATAROOT变量(config.py文件,39行)为你的真实$DATAPATH(这里是r"/opt/disk/zzy/dataset/TuSimple"
__C.DATAROOT = r'/opt/disk/zzy/datasets/TuSimpleLane'
  • 根据你的电脑配置修改GPU设置(config.py文件,26, 27, 28行)
__C.TRAIN.USE_GPU = True
if __C.TRAIN.USE_GPU:
    __C.TRAIN.GPU_ID = [2, 3]
  • 其它参数设置请可自行修改config.py文件

config.py

# coding=utf-8
# [email protected]
# WIN10 Python3.6.6
# 用途: DRL_Lane Pytorch 实现
# config.py
import os
import os.path as osp
import time
from easydict import EasyDict as edict #可以使得以属性的方式去访问字典的值 https://blog.csdn.net/m0_38082419/article/details/79079516

# 本文件是超参数的配置

# init
__C = edict() # 实例化
cfg = __C
__C.EXP = edict()
__C.DATA = edict()
__C.TRAIN = edict()
__C.TEST =edict() # 构建字典

# train*************************train
__C.TRAIN.LR = 1e-4
__C.TRAIN.WEIGHT_DECAY = 5e-4
__C.TRAIN.MAX_EPOCH = 100
# 每个buffer 训练的epoch数量
__C.TRAIN.INER_EPOCH = 10
# GPU 设置
__C.TRAIN.USE_GPU = True
if __C.TRAIN.USE_GPU:
    __C.TRAIN.GPU_ID = [0] # 我对GPU的配置不是很了解,要学习
# 断点续训
__C.TRAIN.RESUME = False
__C.TRAIN.RESUME_PATH = '20-04-18-10-30_TuSimpleLane/EP_9_HitRat0.60316.pth'

# test
__C.TEST.BS = 1

# data*************************data
__C.DATA.NAME = 'TuSimpleLane'  # SelfLane/TuSimpleLane
if __C.DATA.NAME == 'TuSimpleLane':
    __C.DATAROOT = r'/opt/disk/zzy/datasets/TuSimpleLane'
    __C.DATA.TRAIN_LIST = osp.join(__C.DATAROOT,'train_DRL_list.json')
    __C.DATA.VAL_LIST = osp.join(__C.DATAROOT,'test_DRL_list.json')
    __C.DATA.ROOT = osp.join(__C.DATAROOT,'MyTuSimpleLane')
    # meanImagePath
    __C.DATA.MEAN_IMG_PATH = osp.join(__C.DATAROOT,r'meanImgTemp.npy')
if __C.DATA.NAME == 'SelfLane':
    __C.DATA.TRAIN_LIST =''#  TODO:
    __C.DATA.VAL_LIST = ''  #  TODO:
    # meanImagePath
    __C.DATA.MEAN_IMG_PATH =r''#TODO:
# buffer 的dataloader的设置
__C.DATA.NUM_WORKS = 8
__C.DATA.BS = 2048
__C.DATA.SHUFFLE = True
# img dataloder的设置
__C.DATA.TRAIN_IMGBS = 100  #  TODO:
__C.DATA.VAL_IMGBS =1#  TODO:
__C.DATA.IMGSHUFFLE = True


# DQL*************************DQL
# 最大步数
__C.MAX_STEP = 10
# 距离阈值
__C.DST_THR = 5
# action 数量 确定为4
__C.ACT_NUM = 4
# History数量
__C.HIS_NUM = 8
# epsilon
__C.EPSILON = 1
# gamma
__C.GAMMA = 0.90
# landmark 数量
__C.LANDMARK_NUM = 5
# reward
__C.reward_terminal_action = 3
__C.reward_movement_action = 1
__C.reward_invalid_movement_action = -5
__C.reward_remove_action = 1
# buffer capacity
__C.BUFFER_CAP = 20480*5


# exp*************************exp
__C.SEED = 666 # 随机数种子
__C.EXP.ROOT = 'exp'
now = time.strftime("%y-%m-%d-%H-%M", time.localtime()) # 当下时间。strftime:格式化一个时间字符串
__C.EXP.NAME = now+'_'+__C.DATA.NAME # 每次运行都会生成一个不同的__C
__C.EXP.PATH = os.path.join(__C.EXP.ROOT,__C.EXP.NAME)

训练模型

cd $CODEPATH
python train.py

train.py

# coding=utf-8
# [email protected]
# WIN10 Python3.6.6
# 用途: DRL_Lane Pytorch 实现
# train.py
import os
import pdb
import torch
import scipy
import random
import collections
import numpy as np
import os.path as osp
from tqdm import tqdm
from torchvision import transforms
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from PIL import Image

from config import cfg # 导入自己的模块,一些超参
import utils # 存放一些通用的工具
import datasets # 载入 self_lane数据集
import model
import reward
from utils import Timer

class trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        os.makedirs(self.cfg.EXP.PATH, exist_ok=True)
        os.makedirs(self.cfg.EXP.PATH+'/valimg', exist_ok=True) # 根据生成的时间创建目录
        # logger
        self.writer = SummaryWriter(self.cfg.EXP.PATH) # 实例化一个SummaryWriter
        # 计时器
        self.t = {'iter': Timer(), 'train': Timer(), 'val': Timer()} # 实例化三个计时器
        # 保存实验环境 # TODO: 启用
        temp = os.path.join(self.cfg.EXP.PATH, 'code') # 'exp/20-12-23-09-04_TuSimpleLane/code'
        utils.copy_cur_env('./', temp, exception='exp')
        # 读取数据集
        self.meanImg, self.trainloader, self.valloader = datasets.getData(self.cfg) # 
        # 定义网络
        self.net = model.getModel(cfg) 
        # 损失函数
        self.criterion = torch.nn.MSELoss()
        # 优化器
        self.optimizer = torch.optim.Adam(
            self.net.parameters(),
            lr=self.cfg.TRAIN.LR, #  1e-4
            weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) # 5e-4
        # 初始化一些变量
        self.beginEpoch = 1
        self.batch = 1
        self.bestacc = 0
        # 载入预训练模型
        if self.cfg.TRAIN.RESUME: # 断点续训,默认False
            print('Loading Model..........')
            saved_state = torch.load(self.cfg.TRAIN.RESUME_PATH)
            self.net.load_state_dict(saved_state['weights'])
            self.beginEpoch = saved_state['epoch']
            self.batch = saved_state['batch']
            self.bestacc = saved_state['bestacc']
        # GPU设定
        self.gpu = torch.cuda.is_available() and self.cfg.TRAIN.USE_GPU # True
        self.device = 'cuda' if self.gpu else 'cpu' # cuda
        if self.gpu:
            torch.cuda.set_device(self.cfg.TRAIN.GPU_ID[0])
            self.criterion.cuda() # 把损失函数放到GPU上计算
            if len(self.cfg.TRAIN.GPU_ID) > 1: # torch.cuda.device_count():返回gpu数量。我只有一块,所以无法并行(不太懂)
                self.net = torch.nn.DataParallel(
                    self.net, device_ids=self.cfg.TRAIN.GPU_ID) # 并行计算
            self.net = self.net.cuda()
        else: # 没有gpu,只能在cpu上计算
            self.net = self.net.cpu()
            self.criterion.cpu()

    def train(self):
        for self.epoch in range(self.beginEpoch, self.cfg.TRAIN.MAX_EPOCH):
            # 训练一个Epoch
            self.t['train'].tic() # 开始给train计时
            self.trainEpoch() # 训练一个Epoch
            temp = self.t['train'].toc(average=False) # 不平均
            print('Train time of Epoch {} is : {:.2f}s'.format(self.epoch, temp)) # 训练第几个epoch,用时多少
            # 在验证集上测试
            self.t['val'].tic()
            acc = self.val() # 准确率
            temp = self.t['val'].toc(average=False) # 
            print('Val time of/after Epoch {} is : {:.2f}s'.format(self.epoch, temp)) # 同上(测试第几个epoch,用时多少)
            print('Acc for Epoch {} is : {:.4f}'.format(self.epoch, acc)) # 第几个epoch的准确率是多少
            print('BestAcc is:{:.4f}'.format(self.bestacc)) # 最佳准确率
            self.writer.add_scalar('ValHitRate_PerEpoch', acc, self.epoch) # acc 随 epoch的变化
            # 保存模型
            if acc > self.bestacc: # 保存更好的准确率的模型
                self.bestacc = acc
                temp = "HitRat{:.5f}".format(acc)
                self.save(temp)


    def trainEpoch(self):# 训练一个Epoch
        self.buffer = utils.ExBuffer(self.cfg.BUFFER_CAP) # 实例化一个buffer。cap:102400 replaceInd=0 replaceMax:102400 是个队列
        self.buffer.clean() # 如果不是第一次训练,要把buffer里面的东西给清空
        print('Build Buffer.........')
        for batch_index, (clas, imgs, gts) in tqdm(enumerate(self.trainloader)): # 关于这里的函数调用不太清楚。运行此行,先到datasets.py的__len__方法,返回self.pathList的长度(13704),之后又到datasets.py的__getitem__方法,处理一个batch(100张)图片(该方法的index是由config.py的cfg和utils.py的setup_seed方法共同决定的)batch的大小是由datasets.py中用DataLoader方法载入数据时决定的。
            clas = clas.numpy() # 数据类型的转换。https://blog.csdn.net/u012177700/article/details/106984537/
            imgs = imgs.numpy()
            gts = gts.numpy() # opencv中坐标轴 https://blog.csdn.net/weixin_43124720/article/details/88856234
            for j in range(len(imgs)): 
                self.img = imgs[j]-self.meanImg # 每个图像都减去均值
                self.cl = clas[j]
                self.gt = gts[j] # gt存贮真实点x坐标。从小到大
                if self.cl == 1: # 左下到右上(在genMyData.py里有说明)y是不变的,预测x
                    self.initMarkX = [91.0, 71.0, 51.0, 31.0, 11.0] # 初始点坐标
                else: # 左上到右下
                    self.initMarkX = [11.0, 31.0, 51.0, 71.0, 91.0]
                self.updateBuffer() # 
                if self.trainFlag: # 开始训练情形一:buffer里的数据够了
                    self.trainBuffer()
                    print('Build Buffer.........')
        self.trainBuffer() # 开始训练情形一:trainloader里的数据读完了

    def trainBuffer(self):
        print('Training..........')
        self.net.train() # 告诉网络在训练
        tf = transforms.ToTensor() # 将numpy的ndarray或PIL.Image读的图片转换成形状为(C,H, W)的Tensor格式,且/255归一化到[0,1.0]之间 https://www.pianshen.com/article/6972192583/
        dataset = datasets.bufferLoader(self.buffer.buffer,tf=tf) # 载入buffer数据
        loader = DataLoader(dataset, num_workers=self.cfg.DATA.NUM_WORKS, batch_size=self.cfg.DATA.BS, shuffle=self.cfg.DATA.SHUFFLE) # 载入数据 num_workers:线程数,8 batch_size:1024(2048时GPU内存不够用,故该) shuffle:训练时一般都为ture
        for epoch in tqdm(range(self.cfg.TRAIN.INER_EPOCH)): # 一共10个epoch
            for fea, state, Q in loader:  # fea:未编码特征向量【1024,3,20,100】 state:当下点坐标+动作历史向量【1024,33】 Q:Q表【1024,4】
                fea, state, Q = fea.to(self.device), state.to(self.device), Q.to(self.device) # 数据放到GPU上
                self.optimizer.zero_grad() # 梯度为0
                output = self.net(fea, state) # 神经网络拟合Q表【1024,4】
                loss = self.criterion(output, Q) # 计算loss
                loss.backward() # 误差 反向传播
                self.optimizer.step() # 梯度更新
                self.writer.add_scalar('trian_loss', loss.item(), self.batch) # 横轴:第几个batch;纵轴:loss.item():得到一个元素张量里面的元素值,用于将一个零维张量转换成浮点数
                self.batch += 1 # batch自增

    def val(self):
        self.net.eval() # 告诉网络在预测
        hit_cnt = 0 # (第二阶段)检测对的点(相当于5.1中的Hit point Nh)
        detect_hit_cnt = 0 # (第一阶段)检测对的点
        test_cnt = 0 # 总点数(相当于5.1中的 N)
        sup_cnt = 0 # 第一阶段比第二阶段效果好的点
        steps_cnt = 0 # 本测试周期内所有点的总步数
        for valIndex, (cl, img, gt) in tqdm(enumerate(self.valloader)):
            img = np.squeeze(img.numpy())
            cl = np.squeeze(cl.numpy())
            gt = np.squeeze(gt.numpy()) # 转为ndarray格式,并且去除多余纬度
            img = img - self.meanImg # 图像减去平均值
            if cl == 1:
                initMarkX = [91.0, 71.0, 51.0, 31.0, 11.0]
            else:
                initMarkX = [11.0, 31.0, 51.0, 71.0, 91.0]
            # 循环处理五个landmark point
            xpoints = dict() # 创建字典
            for k in np.arange(self.cfg.LANDMARK_NUM, 0, -1):  # 5.4.3.2.1
                cur_x = [] # 记录点的运动轨迹
                step = 0
                allActList = np.zeros(self.cfg.MAX_STEP) # 动作的历史
                status = 1 # 点是活跃的
                if gt[k - 1] == -1: # 无效点
                    gt[k - 1] = -20
                gt_point = gt[k - 1]
                fea_t = np.array(img[(k - 1) * 20:k * 20,:,:])
                fea_t = np.transpose(fea_t, (2, 0, 1))
                fea_t = fea_t.astype(np.float32)
                fea_t = fea_t.reshape((1,fea_t.shape[0],fea_t.shape[1],fea_t.shape[2]))
                fea_t = torch.from_numpy(fea_t).cuda() # 之前解释过,不赘述
                cur_point = initMarkX[k - 1] # 当下点坐标
                cur_x.append(cur_point) # 记录点的运动轨迹
                if self.cfg.HIS_NUM == 0:
                    hist_vec = []
                else:
                    hist_vec = np.zeros([self.cfg.ACT_NUM * self.cfg.HIS_NUM]) # 动作历史向量
                state = reward.get_state(cur_point, hist_vec) # 得到状态s
                while (status == 1) & (step < self.cfg.MAX_STEP):
                    step += 1 # 步数+1
                    state= state.astype(np.float32).reshape((1,-1))
                    state = torch.from_numpy(state).cuda() # 同
                    qval = np.squeeze(self.net(fea_t, state).detach().cpu().numpy()) # 神经网络拟合Q表。ndarray格式
                    action = (np.argmax(qval)) + 1 # 根据Q表选择动作
                    allActList[step - 1] = action # 记录动作
                    if action != 4: # 根据动作做出行动
                        if action == 1:
                            cur_point = -20
                        elif action == 2:
                            cur_point -= 5
                        elif action == 3:
                            cur_point += 5
                        cur_x.append(cur_point) # 记录点的运动轨迹
                    else:
                        status = 0 # 点已经terminal
                    if self.cfg.HIS_NUM != 0:
                        hist_vec = reward.update_history_vector(
                            hist_vec, action) # 更新动作历史向量
                    state = reward.get_state(cur_point, hist_vec) # 得到状态s'
                steps_cnt += step # 本测试周期内所有点的总步数
                finalPoint = cur_point # 最终点坐标
                finalDist = abs(finalPoint - gt_point) # (第二阶段)|预测点-真实点|
                det_dst = abs(initMarkX[k-1]-gt_point) #(第一阶段)|检测点-真实点|
                if det_dst < self.cfg.DST_THR:
                    detect_hit_cnt += 1 # 第一阶段检测对的点
                test_cnt += 1 # 总点数
                if finalDist <= self.cfg.DST_THR:
                    hit_cnt += 1 # 第二阶段检测对的点
                if finalDist <= det_dst:
                    sup_cnt += 1 # 第二阶段比第一阶段效果好的点
                xpoints[str(k-1)] = cur_x # 记录一张图中所有点和对应运动过程
        finImg = utils.visOneLane(img, self.meanImg, gt, initMarkX, xpoints) # 可视化一条车道线【560,1232,3】
        finImg = utils.catFinalImg(finImg) # 拼接那些图
        tempPath = osp.join(self.cfg.EXP.PATH,'valimg','val'+str(self.epoch)+'.png') # 路径
        Image.fromarray(finImg.astype('uint8')).save(tempPath) # 保存
        finImg=np.transpose(finImg, (2,0,1)) # 改变H,W,C的顺序 【3,560,1232】
        self.writer.add_image('Val_Vis',finImg,self.epoch) # 向wirter中添加图像。横轴:epoch 纵轴:图像
        self.writer.add_scalar('Val_RL_HR', float(hit_cnt) / test_cnt, self.epoch) # 第二阶段检测对的点/总点数(Nh/N)hit rate阿尔法 【阿尔法 随epoch的变化情况】
        self.writer.add_scalar('Val_Hit_Cnt',hit_cnt,self.epoch) # 第二阶段检测对的点 随epoch的变化情况
        self.writer.add_scalar('Val_Det_HR', float(detect_hit_cnt) / test_cnt, self.epoch) # 第一阶段检测对的点/总点数
        self.writer.add_scalar('Val_Det_Hit_Cnt',detect_hit_cnt,self.epoch) # 第一阶段检测对的点 随epoch的变化情况
        self.writer.add_scalar('Val_RLsupDet_HR', float(sup_cnt)/test_cnt, self.epoch) # 第二阶段比第一阶段效果好多少 随epoch的变化情况
        self.writer.add_scalar('Val_Average_Step', float(steps_cnt) / ((valIndex + 1) * 5), self.epoch) # 本测试周期内所有点的总步数/所有地标点个数 随epoch的变化情况贝塔
        return float(hit_cnt) / test_cnt # 返回 最终的阿尔法

    def updateBuffer(self): # 更新buffer
    '''
    函数顺序:
    for 5个点中依次选1点
        
    
    '''
        self.trainFlag = False
        buf = collections.namedtuple('buf', field_names=['fea', 'state', 'Q']) # namedtuple('名称', [属性list])。增加可读性。属于Special Variables
        # generateExpReplay
        for k in np.arange(cfg.LANDMARK_NUM, 0, -1):  # [5,4,3,2,1]。对每个点依次进行。这个写法可以借鉴
            if self.gt[k - 1] == -1: # 如果是-1
                self.gt[k - 1] = -20 # 赋值为20
            gt_point = self.gt[k - 1] # 直接赋值
            # generate actions
            # status indicates whether the agent is still alive and has not triggered the terminal action
            status = 1 # Agent仍然在运动
            step = 0 # 记录步数
            cur_point = self.initMarkX[k - 1] # 当下点x坐标
            landmark_fea = np.array(self.img[(k - 1) * 20:k * 20, :, :]) # 截取100x100中的一个片段(【80:100,100,3】
            landmark_fea_trans = np.reshape(landmark_fea, (1, 20, 100, 3)) # reshape
            if self.cfg.HIS_NUM == 0: # 加上动作历史向量
                hist_vec = []
            else:
                hist_vec = np.zeros([self.cfg.HIS_NUM*self.cfg.ACT_NUM]) # 32。one-hot编码
            state = reward.get_state(cur_point, hist_vec) # 状态33(cur_point:1 hist_vec:32
            cur_dst = reward.get_dst(gt_point, cur_point) # 两个坐标差的绝对值
            last_point = cur_point # 最终点 
            last_dst = cur_dst # 最终差值
            while (status == 1) & (step < self.cfg.MAX_STEP):
                rew = []
                qval = np.array(self.predict(landmark_fea, state)) # 数据类型 ndarray,(4,)
                step += 1 # 相当于走了一步
                # 挑选action 计算reward
                # we force terminal action in case actual IoU is higher than 0.5, to train faster the agent
                # 当实际IoU高于0.5时,我们强制终端操作,以更快地训练Agent
                if cur_dst < self.cfg.DST_THR: # |当下坐标-真实坐标|<许用值(5)
                    action = 4 # 选择动作4
                # epsilon-greedy policy
                elif random.random() < self.cfg.EPSILON: # 否则就采用epsilon-greedy策略
                    action = np.random.randint(1, 5)
                else:
                    action = (np.argmax(qval)) + 1
                # terminal action
                if action == 4: # 动作4是终止动作
                    rew = reward.get_reward_trigger(cur_dst) # 奖励3
                # move action,performe the crop of the corresponding subregion
                # 移动动作,执行相应子区域的裁剪
                elif action == 1:
                    cur_point = -20
                    cur_dst = reward.get_dst(gt_point, cur_point)
                    rew = reward.getRewRm(cur_dst)
                    last_dst = cur_dst
                    last_point = cur_point
                elif action == 2:  # to left
                    cur_point = cur_point - 5
                    cur_dst = reward.get_dst(gt_point, cur_point)
                    rew = reward.getRewMov0427(cur_point, last_point, gt_point)
                    last_dst = cur_dst
                    last_point = cur_point
                elif action == 3:  # to right
                    cur_point = cur_point + 5
                    cur_dst = reward.get_dst(gt_point, cur_point)
                    rew = reward.getRewMov0427(cur_point, last_point, gt_point)
                    last_dst = cur_dst
                    last_point = cur_point
                if self.cfg.HIS_NUM != 0:
                    hist_vec = reward.update_history_vector(hist_vec, action) # 更新动作历史向量
                new_state = reward.get_state(cur_point, hist_vec) # [91.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.,  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.] 【当下点位置,动作历史向量】
                # 计算 用来训练的Q值
                if action == 4:
                    temp = rew # 3
                else:
                    temp = np.array(self.predict(landmark_fea, new_state))
                    temp = np.argmax(temp)
                    temp = rew + self.cfg.GAMMA * temp
                qval[action-1] = temp # 更新Q表 [-1.3242564 -0.7344132  2.346577   3.       ]
                # 将数据存入buffer
                if self.buffer.ready2train: # 1.跳转utils.py的ready2train方法 2.再跳转utils.py的isFull方法 发现没有准备好去训练
                    self.trainFlag = True
                    break
                else:
                    temp = buf(landmark_fea, state, qval) # landmark_fea:【20,100,3】 state:还是旧的[91.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.,  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.] qval:Q表[-1.3242564 -0.7344132  2.346577   3.       ]
                    self.buffer.append(temp) # buffer里存储了上述三个信息。【landmark_fea:图像特征(未编码)】【state:当下点位置+动作历史向量】【qval:Q表】
                if action == 4: 
                    status = 0 # 终止动作,该点的状态设为0
                state = new_state # 更新状态

    def predict(self, fea, sta):
        self.net.eval() # 好像是告诉网络在预测。https://blog.csdn.net/edward_zcl/article/details/101947941
        fea = np.transpose(fea, (2, 0, 1)) # 【3,20,100】
        fea = fea.astype(np.float32)
        fea = fea.reshape((1,fea.shape[0],fea.shape[1],fea.shape[2]))  #【1,3,20,100】
        fea = torch.from_numpy(fea).cuda() # ndarray 到 tensor。 同时放到GPU上
        sta = sta.astype(np.float32)
        sta = sta.reshape((1,-1)) # 【1, 33】
        sta = torch.from_numpy(sta).cuda() # ndarray 到 tensor。 同时放到GPU上

        x = self.net(fea, sta) # 都是一个tensor。4
        return np.squeeze(x.data.detach().cpu().numpy()) # https://blog.csdn.net/qq_39938666/article/details/90794240。返回数据类型:ndarray,(4,)

    def save(self,temp): # 保存模型
        temp = 'EP_'+str(self.epoch)+'_'+temp+'.pth' # 名称
        path = osp.join(self.cfg.EXP.PATH, temp) # 路径
        if (not self.cfg.TRAIN.USE_GPU) or (len(self.cfg.TRAIN.GPU_ID) == 1): # 不用GPU或者只用一块GPU
            to_saved_weight = self.net.state_dict()
        else:
            to_saved_weight = self.net.module.state_dict() # 多块GPU(可能要分开保存不同GPU上模型的参数吧) 这二者区别还不太懂
        toSave = {
            'weights': to_saved_weight,
            'epoch': self.epoch,
            'batch': self.batch,
            'bestacc': self.bestacc
        }
        torch.save(toSave, path) # 保存
        print('Model Saved!')


if __name__ == "__main__":
    utils.setup_seed(cfg.SEED) # cfg:{'EXP': {'ROOT': 'exp', 'NAME': '20-12-22-21-54_TuSimpleLane', 'PATH': 'exp/20-12-22-21-54_TuSimpleLane'}, 'DATA': {'NAME': 'TuSimpleLane', 'TRAIN_LIST': '/home/wqf/tusimple/train_DRL_list.json', 'VAL_LIST': '/home/wqf/tusimple/test_DRL_list.json', 'ROOT': '/home/wqf/tusimple/MyTuSimpleLane', 'MEAN_IMG_PATH': '/home/wqf/tusimple/meanImgTemp.npy', 'NUM_WORKS': 8, 'BS': 2048, 'SHUFFLE': True, 'TRAIN_IMGBS': 100, 'VAL_IMGBS': 1, 'IMGSHUFFLE': True}, 'TRAIN': {'LR': 0.0001, 'WEIGHT_DECAY': 0.0005, 'MAX_EPOCH': 100, 'INER_EPOCH': 10, 'USE_GPU': True, 'GPU_ID': [2, 3], 'RESUME': False, 'RESUME_PATH': '20-04-18-10-30_TuSimpleLane/EP_9_HitRat0.60316.pth'}, 'TEST': {'BS': 1}, 'DATAROOT': '/home/wqf/tusimple', 'MAX_STEP': 10, 'DST_THR': 5, 'ACT_NUM': 4, 'HIS_NUM': 8, 'EPSILON': 1, 'GAMMA': 0.9, 'LANDMARK_NUM': 5, 'reward_terminal_action': 3, 'reward_movement_action': 1, 'reward_invalid_movement_action': -5, 'reward_remove_action': 1, 'BUFFER_CAP': 102400, 'SEED': 666}
    MyTrainer = trainer(cfg) # 实例化
    MyTrainer.train() # 调用train方法

第一个epoch完成后在测试集上最后一张车道线的表现如下:
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第23张图片
第25个epoch
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第24张图片
第50个epoch
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第25张图片
第100个epoch
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第26张图片

utils.py

utils.py文件解读

# coding=utf-8
# ----tuzixini----
# WIN10 Python3.6.6
# tools/utils.py
'''
存放一些通用的工具
'''
import os
import cv2
import pdb
import time
import copy
import torch
import shutil
import random
import collections
import numpy as np


def setup_seed(seed): # 生成随机数种子
    torch.manual_seed(seed) # 为CPU设置
    torch.cuda.manual_seed_all(seed) # 为 GPU 设置
    np.random.seed(seed) 
    random.seed(seed)
    torch.backends.cudnn.deterministic = True # 提高效率?待学习 https://blog.csdn.net/yao1249736473/article/details/102584126


class Timer(object):
    """A simple timer."""

    def __init__(self):
        self.clean() # 初始化时先清零

    def tic(self):
        # using time.time instead of time.clock because time time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff

    def clean(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.


def copy_cur_env(work_dir, dst_dir, exception='exp'): # work_dir:'./' dst_dir:'exp/20-12-23-09-04_TuSimpleLane/code' exception:'exp'
    # 复制本次运行的工作环境,排除部分文件夹(exception='exp'不会被复制过去,其他均被复制)
    # 关于'.gitignore'文件:https://www.cnblogs.com/yulinlewis/p/10231035.html
    if not os.path.exists(dst_dir):
        os.mkdir(dst_dir)
    for filename in os.listdir(work_dir):
        file = os.path.join(work_dir, filename)
        dst_file = os.path.join(dst_dir, filename)
        if os.path.isdir(file) and exception not in filename: # os.path.isdir()用于判断对象是否为一个目录 file:'./README.md'
            shutil.copytree(file, dst_file) # shutil.copytree() 模块具体用法,可以递归copy多个目录到指定目录下。
        elif os.path.isfile(file): # os.path.isfile()用于判断对象是否为一个文件
            shutil.copyfile(file, dst_file) # shutil.copyfile(file1,file2)。file1为需要复制的源文件的文件路径,file2为目标文件的文件路径+文件名.


class ExBuffer(object): # DQN的buffer
    def __init__(self, capacity, flashRat=1):
        self.cap = capacity # 容量
        self.replaceInd = 0 
        self.replaceMax = int(capacity*flashRat)
        self.buffer = collections.deque(maxlen=capacity) # 队列

    def append(self, exp): # 向buffer内存贮
        self.buffer.append(exp) # 
        self.replaceInd += 1 # 每存储一条便+1

    @property # 使用@property装饰器来创建只读属性,@property装饰器会将方法转换为相同名称的只读属性,可以与所定义的属性配合使用,这样可以防止属性被修改。
    def isFull(self): # 判断buffer是否存满
        if len(self.buffer) < self.cap:
            return False
        else:
            return True

    @property
    def ready2train(self):# 当replaceInd=replaceMax时,可以训练
        if self.isFull:
            if self.replaceInd < self.replaceMax:
                return False
            else:
                self.replaceInd = 0
                return True
        else:
            return False
    

    def samlpe(self, BS): # 采样。这里可能要将BS改为batch_size,不然报错
        indices = np.random.choice(len(self.buffer), batch_size, replace=False) # 记录下标
        states, actions, rewards, dones, next_states = zip(
            *[self.buffer[idx] for idx in indices]) # 存储一个batch_size大小
        return np.array(states), np.array(actions), np.array(rewards, dtype=np.float32), \
            np.array(dones, dtype=np.uint8), np.array(next_states) # 返回所有

    def clean(self): # 清理
        self.buffer.clear()
        self.replaceInd = 0


def visOneLane(img, meanImg, gt, initX,xpoints): # (拿测试集中最后一张图片)画类似于文中图8所示的(定位过程可视化) img:【100,100,3】 meanImg:平均图像【100,10,3】 gt:真实点坐标 [ 6 25 47 69 92] initX:点的初始化位置 [11.0, 31.0, 51.0, 71.0, 91.0] xpoints:5个点的定位过程 {'4': [91.0], '3': [71.0, 76.0, 81.0, 86.0, 91.0, 96.0, 101.0, 106.0, 111.0, 116.0, 121.0], '2': [51.0, 56.0, 61.0, 66.0, 71.0, 76.0, 81.0, 86.0, 91.0, 96.0, 101.0], '1': [31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 61.0, 66.0, 71.0, 76.0, 81.0], '0': [11.0]}
    # xpoints{'5':[2,3,4],'4':[1,1,1]}
    margin = 6 # 间隔
    temp = np.ones((margin*2+100, margin*2+100, 3)) * 255 # shape:(112, 112, 3) 值全是255(白色)
    img = img + meanImg # 给img加上平均值(因为之前剪去了,这里要加回来)
    temp[margin:margin+100, margin:margin+100,:] = img # 赋值。把img放到【6:106,6:106,3】的位置
    initY = np.array([11, 31, 51, 71, 91]) # Y
    initY = initY + margin # 将y值做平移
    initX = np.array(initX) # 转换格式
    initX = initX + margin # 同理
    gt = gt+margin 
    for i in xpoints.keys(): # 同理
        xpoints[i] = np.array(xpoints[i]) + margin
    finalImg = [] # 最终图像
    # 画gt
    for i in range(5):
        if gt[i]>0:
            x = gt[i]
            y = initY[i]
            pt = (int(x),int(y))
            cv2.circle(temp, pt, 8, (0, 0, 255), 2) # 根据给定的圆心和半径等画圆。temp:输入的图片data pt:圆心位置 8:圆的半径 (0, 0, 255): 圆的颜色 2:圆形轮廓的粗细

    # 五行
    for k in np.arange(5, 0, -1): # 每一行初始化。
        # 画一行
        oneLine = []
        ttemp = copy.deepcopy(temp) # 深拷贝
        # 画当前行的初始图(除了当前点之外的所有点)
        for i in np.arange(5, k, -1):  # 控制行
            x = xpoints[str(i-1)][-1] # 一行的最后一个元素
            y = initY[i - 1]
            pt = (int(x),int(y))
            cv2.line(ttemp, pt, pt, (255, 0, 0), 4)
        for i in np.arange(k - 1, 0, -1): # 画一列 np.arange(三个参数):第一个参数为起点,第二个参数为终点(取不到),第三个参数为步长。
            # pdb.set_trace()
            x = xpoints[str(i - 1)][0] # 一行的第一个元素
            y = initY[i - 1]
            pt = (int(x),int(y)) # 【77,77】。。。【17,17】
            cv2.line(ttemp, pt, pt, (255, 0, 0), 4) # cv2.line:用于在图像中划线 ttemp:要划的线所在的图像 pt:直线起点 pt:直线终点 (255, 0, 0):直线颜色(绿) 4:线条粗细
        y = initY[k-1] # 97
        for i in range(len(xpoints[str(k-1)])): # np.arange(一个参数):参数值为终点,起点取默认值0,步长取默认值1。控制每一行中的列
            tttemp = copy.deepcopy(ttemp)
            x = xpoints[str(k-1)][i] # 97
            pt = (int(x),int(y))
            cv2.line(tttemp, pt, pt, (0, 255, 0), 4)
            # pdb.set_trace()
            oneLine.append(tttemp)
        finalImg.append(oneLine) # 一行
    return finalImg # 是一个列表,记录了所有点移动过程对应的图像

def catFinalImg(finalImg):
    maxlen = 0
    for line in finalImg: # 找出最长的那一行
        if len(line) > maxlen:
            maxlen = len(line)
    x = finalImg[0][0].shape[0] 
    w = x*maxlen
    h = x*5
    img = np.ones((h, w, 3)) * 255 # 画出一个白板
    for i in range(len(finalImg)): # 画图
        for j in range(len(finalImg[i])):
            img[x * i:x * (i + 1), x * j:x * (j + 1),:] = finalImg[i][j]
    return img

datasets.py

datasets.py文件解读

# coding=utf-8
# ----tuzixini----
# MACOS Python3.6.6
'''
载入 self_lane数据集
'''
import pdb
import collections
from torch.utils import data
from scipy import io as sio
from torch.utils.data import DataLoader
import os.path as osp
import json
import numpy as np
from PIL import Image


def getData(cfg):
    if cfg.DATA.NAME =='SelfLane': # 不运行
        trainset = SelfLane(cfg.DATA.TRAIN_LIST)
        valset = SelfLane(cfg.DATA.VAL_LIST)
        trainloader = DataLoader(trainset, 
                                num_workers=cfg.DATA.NUM_WORKS,
                                batch_size=cfg.DATA.TRAIN_IMGBS, 
                                shuffle=cfg.DATA.IMGSHUFFLE) # DataLoader有很多参数,但这里大都采用默认的,暂不深究
        valloader = DataLoader(valset, 
                            num_workers=cfg.DATA.NUM_WORKS,
                            batch_size=cfg.DATA.VAL_IMGBS, 
                            shuffle=cfg.DATA.IMGSHUFFLE)
        meanImg = sio.loadmat(cfg.DATA.MEAN_IMG_PATH)
        meanImg = meanImg['meanImg']
        return meanImg,trainloader, valloader
    if cfg.DATA.NAME == 'TuSimpleLane': # 从这里开始
        trainset = TuSimpleLane(cfg.DATA.ROOT,cfg.DATA.TRAIN_LIST,isTrain=True) # 实例化训练集
        valset =TuSimpleLane(cfg.DATA.ROOT,cfg.DATA.VAL_LIST,isTrain=False) # 实例化测试集
        trainloader =DataLoader(trainset,batch_size=cfg.DATA.TRAIN_IMGBS,shuffle=cfg.DATA.IMGSHUFFLE,num_workers=cfg.DATA.NUM_WORKS) # trainset:13704 batch_size:100 shuffle:True num_workers:8
        valloader =DataLoader(valset,batch_size=cfg.DATA.VAL_IMGBS,shuffle=cfg.DATA.IMGSHUFFLE,num_workers=cfg.DATA.NUM_WORKS) # valset:9695 batch_size batch_size:1 shuffle:True NUM_WORKS:8
        meanImg =np.load(cfg.DATA.MEAN_IMG_PATH) # 图像平均值
        return meanImg,trainloader,valloader # 图像平均值,训练文件,测试文件

class TuSimpleLane(data.Dataset):
    def __init__(self, dataroot, ListPath, isTrain=True,im_tf=None, gt_tf=None):  # dataroot:'/home/wqf/tusimple/MyTuSimpleLane' ListPath:'/home/wqf/tusimple/train_DRL_list.json'
        if isTrain:
            self.root = osp.join(dataroot, 'train') # '/home/wqf/tusimple/MyTuSimpleLane/train'
        else:
            self.root = osp.join(dataroot, 'test')
        self.root = osp.join(self.root, 'DRL', 'resize') # '/home/wqf/tusimple/MyTuSimpleLane/train/DRL/resize'
        with open(ListPath, 'r') as f:
            self.pathList= json.load(f) # {list:13704}
        self.im_tf = im_tf # None
        self.gt_tf = gt_tf # None

    def __getitem__(self, index): 
        # img
        temp = osp.join(self.root, self.pathList[index] + '.png') # 图片名称
        img = np.array(Image.open(temp))  # 载入图片
        temp = osp.join(self.root, self.pathList[index] + '.json') # 打开图片对应json文件
        with open(temp, 'r') as f: # 打开json文件
            data = json.load(f) # 读取
        img = np.array(img) # 将图片转为矩阵形式
        img = img.astype(np.float32) # 数据类型
        cla = np.array(data['class']) # 读取类别
        gt = np.array(data['gt']) # 读取真实点坐标
        return cla, img, gt # 返回类别,图片,真实点坐标

    def __len__(self): # 应该是默认调用的
        return len(self.pathList) # 返回长度


class SelfLane(data.Dataset):
    def __init__(self, pathList, im_tf=None, gt_tf=None):
        self.pathList = pathList
        self.im_tf = im_tf
        self.gt_tf = gt_tf

    def __getitem__(self, index):
        temp = sio.loadmat(self.pathList[index])
        img = temp['img']
        img = np.array(Image.fromarray(img).resize((100,100)))
        img =img.astype(np.float32)
        # fea = temp['fea']
        cl = np.array(int(temp['class_name'][0]))
        gt = np.array(temp['mark'][0])
        return cl, img, gt

    def __len__(self):
        return len(self.pathList)


class bufferLoader(data.Dataset):
    def __init__(self, buffer, tf=None):
        self.buffer = buffer
        self.tf = tf

    def __getitem__(self, index):
        fea, state, Q = self.buffer[index]
        fea = np.array(fea).astype(np.float32)
        state = np.array(state).astype(np.float32)
        Q = np.array(Q).astype(np.float32)
        if self.tf is not None:
            fea = self.tf(fea)
        return fea, state, Q

    def __len__(self):
        return len(self.buffer)

model.py

# coding=utf-8
# [email protected]
# WIN10 Python3.6.6
# 用途: DRL_Lane Pytorch 实现
# model.py

import torch.nn as nn
import pdb
import torch


class DRL_LANE(nn.Module):
    def __init__(self, cfg):
        super(DRL_LANE,self).__init__() # 调用父类的init函数(关于python,仍需学习)
        # 使用默认的strid和padding
        self.encoder = nn.Sequential( # 编码器(参数详见表2,和表2略有不同,这里能够运行通。好像是因为执行顺序的不同导致的。按照论文所述也没有问题)。
            nn.Conv2d(3, 48, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(48, 96, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(96, 128, kernel_size=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # https://blog.csdn.net/qq_40210472/article/details/87895626 https://blog.csdn.net/weixin_38481963/article/details/109962715 https://blog.csdn.net/jacke121/article/details/104020945
            nn.Conv2d(128, 192, kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(192, 256, kernel_size=2),
            nn.ReLU()) 
        self.fc1 = nn.Sequential(
            nn.Linear(5376, 512),
            nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(545, 256),
            nn.ReLU())
        self.fc3 = nn.Sequential(
            nn.Linear(256, 4))

    def forward(self, img, state): # 前向传播
        img = self.encoder(img) # 编码器。输出是# 1 x 21 x 256
        img = img.view(img.shape[0],-1) # 类似reshape。5376
        img = self.fc1(img) # 512
        img = torch.cat((img, state),1) # 连接操作。512 + 33 = 545
        img = self.fc2(img) # 256
        img = self.fc3(img) # 4
        return img


def getModel(cfg):
    net = DRL_LANE(cfg) # 创建一个网络
    return net

reward.py

import numpy as np
import cv2
import pdb
from config import cfg


dst_threshold = cfg.DST_THR
reward_terminal_action = cfg.reward_terminal_action
reward_movement_action = cfg.reward_movement_action
reward_invalid_movement_action = cfg.reward_invalid_movement_action
reward_remove_action = cfg.reward_remove_action
# Different actions that the agent can do
num_of_actions = cfg.ACT_NUM
# Actions captures in the history vector
num_of_history = cfg.HIS_NUM


def get_dst(gt_point, cur_point):
    dst = abs(gt_point-cur_point) # 记录真实点和当下点坐标的差值
    return dst


def get_reward_trigger(cur_dst):
    if cur_dst < dst_threshold:
        reward = reward_terminal_action
    else:
        reward = - reward_terminal_action
    return reward


def get_reward_movement(cur_point, last_point, gt_point):
    if gt_point == -100:
        if cur_point == -100:
            reward = reward_movement_action
        else:
            reward = - reward_movement_action
    else:
        cur_dst = get_dst(gt_point, cur_point)
        last_dst = get_dst(gt_point, last_point)
        if cur_dst < last_dst:
            reward = reward_movement_action
        else:
            reward = - reward_movement_action
    return reward


def getRewMov0427(cur_point, last_point, gt_point):
    if gt_point == -20:  # should be removed, but the action is 2 or 3
        reward = - reward_remove_action
    else:  # should be moved, the action is 2 or 3
        cur_dst = get_dst(gt_point, cur_point)
        last_dst = get_dst(gt_point, last_point)
        if cur_dst < last_dst:
            reward = reward_movement_action
        else:
            if cur_point < 0 or cur_point >= 100:  # moved out of the image, the reward is change to -5
                reward = reward_invalid_movement_action
            else:
                reward = - reward_movement_action
    return reward


def getRewRm(cur_dst):
    if cur_dst == 0:
        reward = reward_remove_action
    else:
        reward = - reward_remove_action
    return reward


def update_history_vector(history_vector, action):
    action_vector = np.zeros(num_of_actions) #【0,0,0,0】
    action_vector[action-1] = 1 # 【0,0,0,1】
    # number of real history in the current history vector
    num_real_cur_history = np.size(np.nonzero(history_vector)) # np.nonzero:返回数组中不为0的元素的下标 np.size:返回矩阵的元素个数 0
    updated_history_vector = np.zeros(num_of_actions*num_of_history) # 32个0
    if num_real_cur_history < num_of_history: # 0<8
        aux2 = 0
        for l in range(num_of_actions*num_real_cur_history, num_of_actions*num_real_cur_history+num_of_actions):
            history_vector[l] = action_vector[aux2] # 在对应位置更新history_vector
            aux2 += 1
        return history_vector # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0., 0. 0. 0. 0. 0. 0. 0. 0.]
    else:
        for j in range(0, num_of_actions*(num_of_history-1)):
            updated_history_vector[j] = history_vector[j+num_of_actions]
        aux = 0
        for k in range(num_of_actions*(num_of_history-1), num_of_actions*num_of_history):
            updated_history_vector[k] = action_vector[aux]
            aux += 1
        return updated_history_vector


def get_state(cur_point, hist_vec):
    history_vector = np.reshape(hist_vec, (num_of_actions*num_of_history, 1)) # reshape
    state = np.vstack((cur_point, history_vector)) # 堆叠
    state = np.squeeze(state) # 去除无用维度
    return state 

验证/测试和可视化

这个文件和上面的train.py大体一致,只不过这里可视化的是验证集中所有图片。据我估计应该是用来断点续训的,如果想要实现续训,还要修改config.py文件中第30行__C.TRAIN.RESUME = True

修改config.py文件

  • 修改cfg.TRAIN.RESUME_PATH变量(visAllVal.py文件,335行)为你真实模型的check point(这里是r'/opt/disk/zzy/project/DRL_lane/DRL_Code_TuSimple/DRL_Lane_Pytorch/exp/20-04-19-23-36_TuSimpleLane/EP_62_HitRat0.86465.pth'
cfg.TRAIN.RESUME_PATH = '/opt/disk/zzy/project/DRL_lane/DRL_Code_TuSimple/DRL_Lane_Pytorch/exp/20-04-19-23-36_TuSimpleLane/EP_62_HitRat0.86465.pth'
cd $CODEPATH
python visAllVal.py

可视化文件将会被保存在cfg.EXP.PATH/_VAL_VIS

# coding=utf-8
# [email protected]
# WIN10 Python3.6.6
# 用途: DRL_Lane Pytorch 实现
# train.py
import os
import pdb
import torch
import scipy
import random
import collections
import numpy as np
import os.path as osp
from tqdm import tqdm
from torchvision import transforms
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from PIL import Image

from config import cfg
import utils
import datasets
import model
import reward
from utils import Timer

class trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        os.makedirs(self.cfg.EXP.PATH, exist_ok=True)
        os.makedirs(self.cfg.EXP.PATH+'/valimg', exist_ok=True)
        # logger
        self.writer = SummaryWriter(self.cfg.EXP.PATH)
        # 计时器
        self.t = {'iter': Timer(), 'train': Timer(), 'val': Timer()}
        # 保存实验环境 # TODO: 启用
        temp = os.path.join(self.cfg.EXP.PATH, 'code')
        utils.copy_cur_env('./', temp, exception='exp')
        # 读取数据集
        self.meanImg, self.trainloader, self.valloader = datasets.getData(self.cfg)
        # 定义网络
        self.net = model.getModel(cfg)
        # 损失函数
        self.criterion = torch.nn.MSELoss()
        # 优化器
        self.optimizer = torch.optim.Adam(
            self.net.parameters(),
            lr=self.cfg.TRAIN.LR,
            weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
        # 初始化一些变量
        self.beginEpoch = 1
        self.batch = 1
        self.bestacc = 0
        # 载入预训练模型
        if self.cfg.TRAIN.RESUME:
            print('Loading Model..........')
            saved_state = torch.load(self.cfg.TRAIN.RESUME_PATH) # pth文件
            self.net.load_state_dict(saved_state['weights'])
            self.beginEpoch = saved_state['epoch']
            self.batch = saved_state['batch']
            self.bestacc = saved_state['bestacc']
        # GPU设定
        self.gpu = torch.cuda.is_available() and self.cfg.TRAIN.USE_GPU
        self.device = 'cuda' if self.gpu else 'cpu'
        if self.gpu:
            torch.cuda.set_device(self.cfg.TRAIN.GPU_ID[0])
            self.criterion.cuda()
            if len(self.cfg.TRAIN.GPU_ID) > 1:
                self.net = torch.nn.DataParallel(
                    self.net, device_ids=self.cfg.TRAIN.GPU_ID)
            self.net = self.net.cuda()
        else:
            self.net = self.net.cpu()
            self.criterion.cpu()

    def train(self):
        for self.epoch in range(self.beginEpoch, self.cfg.TRAIN.MAX_EPOCH):
            # 训练一个Epoch
            self.t['train'].tic()
            self.trainEpoch()
            temp = self.t['train'].toc(average=False)
            print('Train time of Epoch {} is : {:.2f}s'.format(self.epoch, temp))
            # 在验证集上测试
            self.t['val'].tic()
            acc = self.val()
            temp = self.t['val'].toc(average=False)
            print('Val time of/after Epoch {} is : {:.2f}s'.format(self.epoch, temp))
            print('Acc for Epoch {} is : {:.4f}'.format(self.epoch, acc))
            print('BestAcc is:{:.4f}'.format(self.bestacc))
            self.writer.add_scalar('ValHitRate_PerEpoch', acc, self.epoch)
            # 保存模型
            if acc > self.bestacc:
                self.bestacc = acc
                temp = "HitRat{:.5f}".format(acc)
                self.save(temp)

    def trainEpoch(self):
        self.buffer = utils.ExBuffer(self.cfg.BUFFER_CAP)
        self.buffer.clean()
        print('Build Buffer.........')
        for batch_index, (clas, imgs, gts) in tqdm(enumerate(self.trainloader)):
            clas = clas.numpy()
            imgs = imgs.numpy()
            gts = gts.numpy()
            for j in range(len(imgs)):
                self.img = imgs[j]-self.meanImg
                self.cl = clas[j]
                self.gt = gts[j]
                if self.cl == 1:
                    self.initMarkX = [91.0, 71.0, 51.0, 31.0, 11.0]
                else:
                    self.initMarkX = [11.0, 31.0, 51.0, 71.0, 91.0]
                self.updateBuffer()
                if self.trainFlag:
                    self.trainBuffer()
                    print('Build Buffer.........')
        self.trainBuffer()

    def trainBuffer(self):
        print('Training..........')
        self.net.train()
        tf = transforms.ToTensor()
        dataset = datasets.bufferLoader(self.buffer.buffer,tf=tf)
        loader = DataLoader(dataset, num_workers=self.cfg.DATA.NUM_WORKS, batch_size=self.cfg.DATA.BS, shuffle=self.cfg.DATA.SHUFFLE)
        for epoch in tqdm(range(self.cfg.TRAIN.INER_EPOCH)):
            for fea, state, Q in loader:
                fea, state, Q = fea.to(self.device), state.to(self.device), Q.to(self.device)
                self.optimizer.zero_grad()
                output = self.net(fea, state)
                loss = self.criterion(output, Q)
                loss.backward()
                self.optimizer.step()
                self.writer.add_scalar('trian_loss', loss.item(), self.batch)
                self.batch += 1

    def val(self):
        self.net.eval()
        hit_cnt = 0
        detect_hit_cnt = 0
        test_cnt = 0
        sup_cnt = 0
        steps_cnt = 0
        for valIndex, (cl, img, gt) in tqdm(enumerate(self.valloader)):
            img = np.squeeze(img.numpy())
            pdb.set_trace() # pdb是python用于调试代码的常用库。程序运行到这里就会暂停
            cl = np.squeeze(cl.numpy())
            gt = np.squeeze(gt.numpy())
            img = img - self.meanImg
            if cl == 1:
                initMarkX = [91.0, 71.0, 51.0, 31.0, 11.0]
            else:
                initMarkX = [11.0, 31.0, 51.0, 71.0, 91.0]
            # 循环处理五个landmark point
            xpoints = dict()
            for k in np.arange(self.cfg.LANDMARK_NUM, 0, -1):  # 5.4.3.2.1
                cur_x = []
                step = 0
                allActList = np.zeros(self.cfg.MAX_STEP)
                status = 1
                if gt[k - 1] == -1:
                    gt[k - 1] = -20
                gt_point = gt[k - 1]
                fea_t = np.array(img[(k - 1) * 20:k * 20,:,:])
                fea_t = np.transpose(fea_t, (2, 0, 1))
                fea_t = fea_t.astype(np.float32)
                fea_t = fea_t.reshape((1,fea_t.shape[0],fea_t.shape[1],fea_t.shape[2]))
                fea_t = torch.from_numpy(fea_t).cuda()
                cur_point = initMarkX[k - 1]
                cur_x.append(cur_point)
                if self.cfg.HIS_NUM == 0:
                    hist_vec = []
                else:
                    hist_vec = np.zeros([self.cfg.ACT_NUM * self.cfg.HIS_NUM])
                state = reward.get_state(cur_point, hist_vec)
                while (status == 1) & (step < self.cfg.MAX_STEP):
                    step += 1
                    state= state.astype(np.float32).reshape((1,-1))
                    state = torch.from_numpy(state).cuda()
                    qval = np.squeeze(self.net(fea_t, state).detach().cpu().numpy())
                    action = (np.argmax(qval)) + 1
                    allActList[step - 1] = action
                    if action != 4:
                        if action == 1:
                            cur_point = -20
                        elif action == 2:
                            cur_point -= 5
                        elif action == 3:
                            cur_point += 5
                        cur_x.append(cur_point)
                    else:
                        status = 0
                    if self.cfg.HIS_NUM != 0:
                        hist_vec = reward.update_history_vector(
                            hist_vec, action)
                    state = reward.get_state(cur_point, hist_vec)
                steps_cnt += step
                finalPoint = cur_point
                finalDist = abs(finalPoint - gt_point)
                det_dst = abs(initMarkX[k-1]-gt_point)
                if det_dst < self.cfg.DST_THR:
                    detect_hit_cnt += 1
                test_cnt += 1
                if finalDist <= self.cfg.DST_THR:
                    hit_cnt += 1
                if finalDist <= det_dst:
                    sup_cnt += 1
                xpoints[str(k-1)] = cur_x
            finImg = utils.visOneLane(img, self.meanImg, gt, initMarkX, xpoints)
            finImg = utils.catFinalImg(finImg)
            tempPath = osp.join(self.cfg.EXP.PATH,'valimg','val_vis_'+str(valIndex)+'.png')
            Image.fromarray(finImg.astype('uint8')).save(tempPath)
        finImg=np.transpose(finImg, (2,0,1))
        self.writer.add_image('Val_Vis',finImg,self.epoch)
        self.writer.add_scalar('Val_RL_HR', float(hit_cnt) / test_cnt, self.epoch)
        self.writer.add_scalar('Val_Hit_Cnt',hit_cnt,self.epoch)
        self.writer.add_scalar('Val_Det_HR', float(detect_hit_cnt) / test_cnt, self.epoch)
        self.writer.add_scalar('Val_Det_Hit_Cnt',detect_hit_cnt,self.epoch)
        self.writer.add_scalar('Val_RLsupDet_HR', float(sup_cnt)/test_cnt, self.epoch)
        self.writer.add_scalar('Val_Average_Step', float(steps_cnt) / ((valIndex + 1) * 5), self.epoch)
        return float(hit_cnt) / test_cnt

    def updateBuffer(self):
        self.trainFlag = False
        buf = collections.namedtuple('buf', field_names=['fea', 'state', 'Q'])
        # generateExpReplay
        for k in np.arange(cfg.LANDMARK_NUM, 0, -1):  # [5,4,3,2,1]
            if self.gt[k - 1] == -1:
                self.gt[k - 1] = -20
            gt_point = self.gt[k - 1]
            # generate actions
            # status indicates whether the agent is still alive and has not triggered the terminal action
            status = 1
            step = 0
            cur_point = self.initMarkX[k - 1]
            landmark_fea = np.array(self.img[(k - 1) * 20:k * 20, :, :])
            landmark_fea_trans = np.reshape(landmark_fea, (1, 20, 100, 3))
            if self.cfg.HIS_NUM == 0:
                hist_vec = []
            else:
                hist_vec = np.zeros([self.cfg.HIS_NUM*self.cfg.ACT_NUM])
            state = reward.get_state(cur_point, hist_vec)
            cur_dst = reward.get_dst(gt_point, cur_point)
            last_point = cur_point
            last_dst = cur_dst
            while (status == 1) & (step < self.cfg.MAX_STEP):
                rew = []
                qval = np.array(self.predict(landmark_fea, state))
                step += 1
                # 挑选action 计算reward
                # we force terminal action in case actual IoU is higher than 0.5, to train faster the agent
                if cur_dst < self.cfg.DST_THR:
                    action = 4
                # epsilon-greedy policy
                elif random.random() < self.cfg.EPSILON:
                    action = np.random.randint(1, 5)
                else:
                    action = (np.argmax(qval)) + 1
                # terminal action
                if action == 4:
                    rew = reward.get_reward_trigger(cur_dst)
                # move action,performe the crop of the corresponding subregion
                elif action == 1:
                    cur_point = -20
                    cur_dst = reward.get_dst(gt_point, cur_point)
                    rew = reward.getRewRm(cur_dst)
                    last_dst = cur_dst
                    last_point = cur_point
                elif action == 2:  # to left
                    cur_point = cur_point - 5
                    cur_dst = reward.get_dst(gt_point, cur_point)
                    rew = reward.getRewMov0427(cur_point, last_point, gt_point)
                    last_dst = cur_dst
                    last_point = cur_point
                elif action == 3:  # to right
                    cur_point = cur_point + 5
                    cur_dst = reward.get_dst(gt_point, cur_point)
                    rew = reward.getRewMov0427(cur_point, last_point, gt_point)
                    last_dst = cur_dst
                    last_point = cur_point
                if self.cfg.HIS_NUM != 0:
                    hist_vec = reward.update_history_vector(hist_vec, action)
                new_state = reward.get_state(cur_point, hist_vec)
                # 计算 用来训练的Q值
                if action == 4:
                    temp = rew
                else:
                    temp = np.array(self.predict(landmark_fea, new_state))
                    temp = np.argmax(temp)
                    temp = rew + self.cfg.GAMMA * temp
                qval[action-1] = temp
                # 将数据存入buffer
                if self.buffer.ready2train:
                    self.trainFlag = True
                    break
                else:
                    temp = buf(landmark_fea, state, qval)
                    self.buffer.append(temp)
                if action == 4:
                    status = 0
                state = new_state

    def predict(self, fea, sta):
        self.net.eval()
        fea = np.transpose(fea, (2, 0, 1))
        fea = fea.astype(np.float32)
        fea = fea.reshape((1,fea.shape[0],fea.shape[1],fea.shape[2]))
        fea = torch.from_numpy(fea).cuda()
        sta = sta.astype(np.float32)
        sta = sta.reshape((1,-1))
        sta = torch.from_numpy(sta).cuda()

        x = self.net(fea, sta)
        return np.squeeze(x.data.detach().cpu().numpy())

    def save(self,temp):
        temp = 'EP_'+str(self.epoch)+'_'+temp+'.pth'
        path = osp.join(self.cfg.EXP.PATH, temp)
        if (not self.cfg.TRAIN.USE_GPU) or (len(self.cfg.TRAIN.GPU_ID) == 1):
            to_saved_weight = self.net.state_dict()
        else:
            to_saved_weight = self.net.module.state_dict()
        toSave = {
            'weights': to_saved_weight,
            'epoch': self.epoch,
            'batch': self.batch,
            'bestacc': self.bestacc
        }
        torch.save(toSave, path)
        print('Model Saved!')


if __name__ == "__main__":
    utils.setup_seed(cfg.SEED)
    #cfg.TRAIN.RESUME = True
    cfg.TRAIN.RESUME_PATH = '/opt/disk/zzy/project/DRL_lane/DRL_Code_TuSimple/DRL_Lane_Pytorch/exp/20-04-19-23-36_TuSimpleLane/EP_62_HitRat0.86465.pth'
    cfg.EXP.PATH = cfg.EXP.PATH+'_VAL_VIS'
    MyTrainer = trainer(cfg)
    acc = MyTrainer.val()
    print()

关于namedtuple
基于深度强化学习的车道线检测和定位(Deep reinforcement learning based lane detection and localization) 论文解读+代码复现_第27张图片
完结。

你可能感兴趣的:(车道线检测/道路边缘检测,python,计算机视觉,机器学习)