PolyLaneNet: Lane Estimation via Deep Polynomial Regression

Official source code: https://github.com/lucastabelini/PolyLaneNet

Abstract

Lane detection must run in real time (>= 30 FPS).

Introduction

Lane detection is important, blah blah...
Lane detection methods split into traditional ones (hand-crafted feature extraction + curve fitting) and deep-learning ones. Deep-learning methods are more robust, but several problems remain:

  1. Many methods treat lane detection as a two-step process: feature extraction and curve fitting.

(But most methods extract features with segmentation-based models, which are not time-efficient. A segmentation step alone is also not enough to provide lane-marking estimates, since the segmentation map must be post-processed to output traffic lines.)
(Such methods may also ignore global information, which is crucial under occlusion and shadow.)

  2. Many works do not release their methods, source code, or datasets.
  3. The evaluation protocols leave room for improvement.

(Only datasets from the US are used; the evaluation metrics are too lenient.)

This paper therefore improves on the problems above:

  1. One-step (end-to-end) estimation
  2. Tests on roads outside the US
  3. Stricter evaluation criteria
  4. Released source code and trained models

Related Work

Nothing much here, except perhaps that not releasing source code and datasets amounts to "playing dirty".

POLYLANENET

Model Definition

Input
An image captured by a monocular camera.

Output
M_max candidate lane markings (expressed as polynomials)
h: a vertical position bounding the lane markings at the top, shared by all lanes
a vertical offset s_j for each lane marking j
a prediction confidence c_j ∈ [0, 1]

Network composition
backbone network: extracts features
fully connected layer: M_max + 1 outputs
(outputs 1, ..., M_max: used for lane-marking prediction)
(output M_max + 1: used to output h)
Since PolyLaneNet represents each lane marking with a polynomial instead of a sequence of points, for each output j, j = 1, ..., M_max, the model estimates the coefficients of a polynomial.

Coefficients: $P_j = \{a_{k,j}\}_{k=0}^{K}$
Lane polynomial: $p_j(y) = \sum_{k=0}^{K} a_{k,j}\, y^k \qquad (1)$
K: the degree of the polynomial

The model can be expressed as
$\big(\{(P_j, s_j, c_j)\}_{j=1}^{M_{max}},\ h\big) = f(I;\, \theta),$
where I is the input image and θ are the model parameters; a lane is considered detected only when its confidence exceeds a given threshold.
(Figure: overview of the PolyLaneNet architecture.)
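To make the output format concrete, here is a minimal sketch (mine, not from the paper) of how the 7 raw values the network emits per lane — confidence score, the two vertical limits, and four polynomial coefficients, following the indexing used in the official decode()/loss code shown later in this post — can be turned into image points:

import numpy as np

def decode_lane(raw, ys, conf_threshold=0.5):
    """raw: the 7 values for one lane; ys: candidate y-coordinates (normalized)."""
    score, lower, upper, a3, a2, a1, a0 = raw
    conf = 1.0 / (1.0 + np.exp(-score))           # sigmoid over the confidence logit
    if conf < conf_threshold:
        return None                                # below the threshold: no lane detected
    lo, hi = sorted((lower, upper))
    ys = ys[(ys >= lo) & (ys <= hi)]               # keep rows inside the lane's vertical span
    xs = a3 * ys**3 + a2 * ys**2 + a1 * ys + a0    # evaluate the cubic p_j(y) of Eq. (1)
    return np.stack([xs, ys], axis=1)              # (x, y) pairs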

Model Training

For a given image, M is the number of lanes annotated in it. In general, M <= 4 satisfies most traffic scenes. Output unit j corresponds to lane j (j = 1, ..., M); the outputs M+1 through M_max should therefore be disregarded in the loss function.
An annotated lane marking j is represented by a sequence of points:
$L^*_j = \{(x^*_{i,j},\, y^*_{i,j})\}_{i=1}^{N},$
where
$y^*_{i,j} < y^*_{i+1,j}$
for every i = 1, ..., N-1.
As a rule of thumb, the higher N is, the richer the structures that can be expressed.
We assume that the lane markings $\{L^*_j\}$ are ordered left-to-right by the x-coordinate of each lane's point closest to the bottom of the image, i.e.,
$x^*_{\text{bottom},\,j} < x^*_{\text{bottom},\,j+1}.$
For each lane marking j, the vertical offset $s^*_j$ is set to the y-coordinate at which the annotated lane marking begins.
The confidence targets are defined as
$c^*_j = \begin{cases} 1, & \text{if } j \le M \\ 0, & \text{otherwise,} \end{cases}$
i.e., confidence 1 for outputs corresponding to annotated lanes (j ≤ M) and 0 for the rest.
For a single image, the multi-task loss function is defined as
$L = W_p\, L_p(\{P_j\}, \{L^*_j\}) + W_s\, L_{reg}(\{s_j\}, \{s^*_j\}) + W_c\, L_{cls}(\{c_j\}, \{c^*_j\}) + W_h\, L_{reg}(h, h^*),$
where:
L_cls is a binary classification loss indicating whether each lane is detected;
the two L_reg terms are MSE losses on the distance between the predicted lower/upper vertical limits and the true ones (the lower offsets are per-lane, the upper limit is shared);
L_p({P_j}, {L^*_j}) measures how well the polynomials p_j (Eq. 1) represent the annotated lane points, with {L^*_j} the ground-truth point sets and {P_j} the predicted polynomials;
the coefficients W_p, W_s, W_c, W_h balance the terms, and L_reg and L_cls are the MSE and BCE functions, respectively.
Considering the annotated (ground-truth) x-coordinates $x^*_{i,j}$, the predictions entering the loss are
$\hat{x}_{i,j} = \begin{cases} x^*_{i,j}, & \text{if } \big|\,p_j(y^*_{i,j}) - x^*_{i,j}\,\big| < \tau_{loss} \\ p_j(y^*_{i,j}), & \text{otherwise.} \end{cases}$
(If a predicted point is already within $\tau_{loss}$ of the ground truth, its residual $\hat{x}_{i,j} - x^*_{i,j}$ becomes 0; otherwise $\hat{x}_{i,j} = p_j(y^*_{i,j})$. So it is the point's error, not its coordinate, that is zeroed, and the MSE below simply receives no contribution from well-aligned points; the official loss code implements exactly this by applying nn.Threshold to the squared errors.)

Here $\tau_{loss}$ is an empirically chosen threshold that reduces how much the loss attends to points that are already well aligned.

This effect matters because a lane marking comprises points with different sampling densities
(e.g., points close to the camera are denser than points far from it).

L_p is then defined as the mean squared error over these values:
$L_p(\{P_j\}, \{L^*_j\}) = \mathrm{MSE}\big(\{\hat{x}_{i,j}\},\, \{x^*_{i,j}\}\big).$
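A minimal numeric sketch of this thresholding (mirroring the nn.Threshold trick in the official loss code reproduced later in this post; the τ value here is an assumed normalized scale):

import torch
import torch.nn as nn

tau = 20 / 1280.                              # tau_loss in normalized x-coordinates (assumption)
zero_small = nn.Threshold(tau**2, 0.)         # squared errors <= tau^2 are replaced by 0

pred_xs = torch.tensor([0.40, 0.52, 0.70])    # toy p_j(y_i) predictions
target_xs = torch.tensor([0.41, 0.52, 0.60])  # toy ground-truth x's
sq_err = (pred_xs - target_xs) ** 2
print(zero_small(sq_err))                     # tensor([0.0000, 0.0000, 0.0100]): only the badly-fit point counts
L_p = zero_small(sq_err).mean()               # well-aligned points contribute nothing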

Experimental Methodology

A. Datasets

Three of them: TuSimple, LLAMAS, and ELAS.

The TuSimple dataset is used for the quantitative analysis (6,408 annotated images at 1280x720:
3,268 for training, 358 for validation, and 2,782 for testing).

LLAMAS and ELAS are used for qualitative analysis. LLAMAS has 58,269 images for training, 20,844 for validation, and 20,929 for testing, at 1280x717.

TuSimple and LLAMAS were both collected in the US; since the LLAMAS benchmark and test-set annotations are not yet available, only qualitative results are reported on it.

ELAS was collected in different cities in Brazil: 640x480 resolution, 16,993 images. Since the dataset was originally proposed for non-learning-based methods, it provides no training/testing split, so the authors split it into 11,036 images for training and 5,957 for testing.
Unlike the other two, ELAS annotates only the ego lane. The dataset also provides other information useful for lane detection, such as lane type (e.g., solid or dashed, white or yellow), which this paper does not use.

B. Implementation Details

Backbone network: EfficientNet-b0
For training on TuSimple, data augmentation is applied with probability 10/11, as follows:
rotation by a random angle in [-10°, +10°];
horizontal flip with probability 0.5;
random crop to 1152x648 pixels.

After augmentation, the following transforms are applied:

  1. resize to 640x360
  2. normalization (with ImageNet's mean and standard deviation)

Training then uses the Adam optimizer with a 3e-4 initial learning rate and a cosine annealing learning-rate scheduler with a 770-epoch period (a minimal sketch of this setup follows the list).
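A minimal PyTorch sketch of this optimizer/scheduler setup (the values come from the YAML config reproduced later; `model` is assumed to exist):

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# T_max=385 is the half-period: the LR decays to ~0 over 385 epochs and climbs back,
# giving the 770-epoch cosine period mentioned above.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=385)

for epoch in range(2695):
    ...  # one training epoch
    scheduler.step()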

Training ran for 2,695 epochs and took 35 hours on a Titan V.
batch size = 16
Training starts from an ImageNet-pretrained backbone.
Third-degree polynomials are the default.
The hyperparameters were chosen as:
W_s = W_c = W_h = 1; W_p = 300 (found by tuning)
τ_loss = 20 pixels

At test time, lane markings with confidence c_j < 0.5 are ignored.

C. Evaluation Metrics

Taken from TuSimple's benchmark:
accuracy (Acc), false positive (FP), and false negative (FN) rates.

A predicted lane marking is considered a true positive (i.e., correct) if its point-wise accuracy satisfies
$Acc_j \ge \varepsilon,$
where $Acc_j$ is the fraction of the lane's points that lie within the point-wise distance threshold (20 pixels) of the ground truth, and ε = 0.85.

The three metrics above are computed per lane, then averaged per image, then averaged over all images (a small sketch of the per-lane check follows).
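A small sketch of this per-lane check (my own function names; it assumes prediction and ground truth are sampled at the same image rows):

import numpy as np

def is_true_positive(pred_xs, gt_xs, tau=20, eps=0.85):
    # pred_xs / gt_xs: x-coordinates (in pixels) of one lane at the same y rows
    hits = np.abs(np.asarray(pred_xs) - np.asarray(gt_xs)) < tau
    return hits.mean() >= eps  # TP if at least 85% of the points are within 20 pixels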

But the three metrics above are not strict enough, so the paper additionally adopts:
Lane Position Deviation (LPD): better captures the model's accuracy at both near and far depths from the ego vehicle. It is the error between the prediction for the ego lane and the ground truth.

ego-lane: the lane markings closest to the bottom-center of the image compose the ego lane, i.e., one marking on its left and one on its right.

There are also two speed-related metrics: FPS (frames per second) and MACs (which allows comparing methods that may run on different frameworks and setups).

D. Quantitative Evaluation

Comparison with state-of-the-art methods.

Comparison of polynomial degrees.

Ablation studies.

Qualitative evaluation
For the qualitative results, an extensive evaluation was carried out. Using the TuSimple-trained model as pretraining, three models were trained: two on ELAS (one with and one without lane-marking type classification) and one on LLAMAS.

On ELAS, the model was trained for 385 additional epochs (half the period of the chosen learning-rate scheduler, where the learning rate reaches its minimum).

On LLAMAS, the model was trained for 75 additional epochs, approximately matching the number of iterations used on ELAS, since the LLAMAS training set is roughly five times larger than ELAS's.

The lane-marking type classification experiment is a direct extension of PolyLaneNet, predicting one category per lane, which shows how simple it is to extend the model.

Results

Comparison with state-of-the-art methods
(Tables: comparison with state-of-the-art methods on TuSimple.)
Since most lane markings in the dataset can be well represented by a first-degree polynomial (i.e., a line), the neural network is biased toward predicting lines and thus performs worse on lane markings with pronounced curvature.

Comparison of polynomial degrees
Regarding the degree of the polynomial used to represent lane markings, the small accuracy differences at lower degrees show how unbalanced the dataset is: using first-degree polynomials (i.e., lines) decreases accuracy by only 0.35 p.p.
Another important factor is the metric the benchmark uses to evaluate model performance. The LPD metric, by contrast, better captures the difference between the model trained with first-degree polynomials and the others.
(Table: results for different polynomial degrees.)
As this shows, TuSimple's metric does not penalize predictions that are accurate only on the part of the lane closer to the vehicle, where the lane looks almost straight in the image (i.e., is well represented by a first-degree polynomial), because the metric's threshold can hide such errors.

Ablation study
(Tables: ablation results for different backbones and settings.)
Although EfficientNet-b1 achieved the highest accuracy, it was not used in the other experiments, because in our experiments the accuracy gains were neither significant nor consistent.

Qualitative evaluation

Since images in those cases have a very different structure (e.g., the car not traveling in the road's direction), a small number of samples may not be enough for the model to learn it.

Conclusion

As future work, one could investigate metrics that can also be used with other lane-detection approaches (e.g., segmentation-based ones) and that better highlight the flaws of lane-detection methods.

(Supplement on lane-detection evaluation metrics)

The paper uses six metrics in total. The four common ones, Acc, TP, FP, and FPS, should need no explanation.
The remaining metrics are described below, following this paper cited by the work:

R. K. Satzoda and M. M. Trivedi, “On Performance Evaluation Metrics for Lane Estimation,” 2014 22nd International Conference on Pattern Recognition, Stockholm, Sweden, 2014, pp. 2625-2630, doi: 10.1109/ICPR.2014.453.

Lane Position Deviation (LPD)

Used to determine the accuracy of the estimated lane in the road scene, especially from the ego-vehicle's perspective. Referring to Fig. 2 of that paper, the lane features ((b) in Fig. 2) are used to estimate the lane, denoted by the dashed line (d). The lane position deviation ((e) in Fig. 2) measures the deviation of the detected lane (d) from the actual lane, which is obtained by connecting the actual lane markings ((h) in Fig. 2).
(Fig. 2 of the cited paper.) As the figure shows, LPD captures the accuracy of the lane-estimation process at both near and far depths from the ego vehicle. The metric also evaluates the accuracy of the road model used to determine lane curvature. For a given lane-estimation algorithm, different parameters can be varied to determine their influence on LPD.

Referring to Fig. 3 of that paper, which shows a possible input-image scene, let the solid lines be the actual (ground-truth) lanes in the input image, and consider the left lane for the formula below. Let the dashed line denote the left-lane position determined by the lane-estimation method. The LPD metric determines the average deviation δ_LPD, in the x direction, between the actual and detected lane positions, between the y_max and y_min rows of the image scene:
$\delta_{LPD} = \frac{1}{y_{max} - y_{min}} \sum_{y = y_{min}}^{y_{max}} \big|\, x_{actual}(y) - x_{detected}(y) \,\big|,$
where $y_{min} \le y \le y_{max}$ is the selected region of the road scene over which lane deviation is measured. The same metric can be scaled by an appropriate factor (based on camera calibration) to determine the lane position deviation in world coordinates.
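A minimal sketch of this computation, assuming both lanes are sampled at the same rows y_min..y_max:

import numpy as np

def lane_position_deviation(x_actual, x_detected):
    # x_actual / x_detected: x positions (pixels) of the same lane boundary at rows y_min..y_max
    return np.mean(np.abs(np.asarray(x_actual) - np.asarray(x_detected)))  # mean |deviation| in x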

MACs

First, a few related concepts:
FLOPS, in all caps, stands for FLoating-point Operations Per Second, i.e., the number of floating-point operations executed per second, used to measure hardware performance.

FLOPs stands for FLoating-point OPerations, i.e., the number of floating-point operations, which can be used to measure algorithm/model complexity.

For further details, see:

https://blog.csdn.net/weixin_39833897/article/details/105807172


Besides FLOPs, computation is also measured in MACs (Multiply-ACcumulate operations): one multiplication followed by adding that product to an accumulator counts as a single MAC, so the two quantities differ essentially by a factor of 2.

MACS can also abbreviate the number of fixed-point multiply-accumulate operations executed per second, a measure of a computer's fixed-point processing capability, often used in scientific workloads dominated by fixed-point multiply-accumulates.

For further details, see:

https://www.mobibrw.com/2019/17864
https://www.wikiwand.com/zh-cn/%E4%B9%98%E7%A9%8D%E7%B4%AF%E5%8A%A0%E9%81%8B%E7%AE%97
http://blog.sina.com.cn/s/blog_ebbe6d790102uzcy.html
https://blog.csdn.net/m0_38065572/article/details/104610524

Opinions online vary; I take the MACs in this paper to be a measure of algorithm/model complexity that differs from FLOPs by a factor of two.

Roughly speaking, one multiply-accumulate (MAC) corresponds to 2 FLOPs. MACs are computed with the following library:

https://github.com/mit-han-lab/torchprofile
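A usage sketch of that library's profile_macs entry point (the input size matches this paper's 360x640 network input; `model` is assumed to be the loaded PolyLaneNet):

import torch
from torchprofile import profile_macs  # pip install torchprofile

model.eval()
dummy = torch.randn(1, 3, 360, 640)   # one RGB frame at PolyLaneNet's input size
macs = profile_macs(model, dummy)     # multiply-accumulates for a single forward pass
print('{:.2f} GMACs (~{:.2f} GFLOPs)'.format(macs / 1e9, 2 * macs / 1e9))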

Code Walkthrough

The paper was accepted at ICPR 2020.

Installation

The official code targets Python 3.5.2, but newer versions should in principle be compatible.
Install the dependencies:

pip install -r requirements.txt

Caveats

Mind the correspondence between Python and PyTorch versions here. The versions in requirements.txt match Python 3.5.2; look for the package versions matching your own Python, and watch for API changes between package versions. (Currently I find it least painful to just rebuild a fresh environment.)
Otherwise PyTorch may fail to use the GPU, and I had to go back and reconfigure things.

At this point my machine threw an error; I don't know whether everyone hits it, but I am recording it here. It seemed network-related, and the fix was simple: following the tutorial below, the two commands underneath solved it.

https://www.jianshu.com/p/547d0b447490

export all_proxy="socks5://127.0.0.1:1080"
unset all_proxy && unset ALL_PROXY

Usage

Training

Every training setting is configured through a YAML configuration file. So, to train a model, you must set up a configuration file; an example follows:

# Training settings
exps_dir: 'experiments' # Root path of the experiments directory (not only the one you will run)
iter_log_interval: 1 # Log training iterations every N iterations
iter_time_window: 100 # Moving-average window (in iterations) used when printing loss metrics
model_save_interval: 1 # Save the model every N epochs
seed: 0 # Random seed
backup: drive:polylanenet-experiments # The experiment directory will be automatically uploaded with rclone after training. Leave empty if you don't want this.
model:
  name: PolyRegression
  parameters:
    num_outputs: 35 # (5 lanes) * (1 conf + 2 (upper & lower) + 4 poly coeffs)
    pretrained: true
    backbone: 'efficientnet-b0'
    pred_category: false
loss_parameters:
  conf_weight: 1
  lower_weight: 1
  upper_weight: 1
  cls_weight: 0
  poly_weight: 300
batch_size: 16
epochs: 2695
optimizer:
  name: Adam
  parameters:
    lr: 3.0e-4
lr_scheduler:
  name: CosineAnnealingLR
  parameters:
    T_max: 385

# Testing settings
test_parameters:
  conf_threshold: 0.5 # Set predictions with confidence below this value to 0 (i.e., make them invalid for the metrics)

# Dataset settings
datasets:
  train:
    type: PointsDataset
    parameters:
      dataset: tusimple
      split: train
      img_size: [360, 640]
      normalize: true
      aug_chance: 0.9090909090909091 # 10/11
      augmentations: # ImgAug augmentations
       - name: Affine
         parameters:
           rotate: !!python/tuple [-10, 10]
       - name: HorizontalFlip
         parameters:
           p: 0.5
       - name: CropToFixedSize
         parameters:
           width: 1152
           height: 648
      root: "datasets/tusimple" # Dataset root

  test: &test
    type: PointsDataset
    parameters:
      dataset: tusimple
      split: val
      img_size: [360, 640]
      root: "datasets/tusimple"
      normalize: true
      augmentations: []

  # val = test
  val:
    <<: *test

After creating the configuration file, run the training script:

python train.py --exp_name tusimple --cfg config.yaml

The script options are:

  --exp_name            Experiment name
  --cfg                 Config file for the training (.yaml)
  --resume              Resume training. If the session was interrupted, run it again with the same parameters plus this option to resume from the last checkpoint.
  --validate            Whether to validate during the training session. Was not used in our experiments, which means it has not been thoroughly tested.
  --deterministic       set cudnn.deterministic = True and cudnn.benchmark = False

Testing

Once training is done, run the test.py script to get the evaluation metrics:

python test.py --exp_name tusimple --cfg config.yaml --epoch 2695

The script options are:

  --exp_name            Experiment name.
  --cfg                 Config file for the test (.yaml). (probably the same one used in the training)
  --epoch EPOCH         Epoch to test the model on
  --batch_size          Number of images per batch
  --view                Show predictions. Will draw the predictions in an image and then show it (cv.imshow)

Reproducing the paper's results

Models

All trained models:

https://drive.google.com/drive/folders/1oyZncVnUB1GRJl5L4oXz50RkcNFM_FFC

Datasets

https://github.com/TuSimple/tusimple-benchmark

https://github.com/rodrigoberriel/ego-lane-analysis-system/tree/master/datasets
https://unsupervised-llamas.com/llamas/

Steps

To reproduce the results, you can either retrain the models with the same settings (the results should be very close to the paper's) or just test the released models. If you want to retrain, you only need to modify the corresponding YAML settings file, which you can find in the cfgs directory. If you just want to reproduce the paper's exact metrics by testing the models, you must:

  1. Download the experiment directory. You don't need to download every model checkpoint, only the last one (model_2695.pt, except for the experiments on ELAS and LLAMAS).
  2. Modify all path-related fields in the config.yaml file inside the experiment directory (i.e., the dataset paths and exps_dir).
  3. Move the downloaded experiment directory into your exps_dir folder.

Then, run:

python test.py --exp_name $exp_name --cfg $exps_dir/$exp_name/config.yaml --epoch 2695

Replace $exp_name with the name of the directory you downloaded (the experiment's name) and $exps_dir with the exps_dir value you defined in the config.yaml file. The script will look for a directory named $exps_dir/$exp_name/models to load the model from.

Code Walkthrough

The official test.py code:

import os
import sys
import random
import logging
import argparse # argparse usage: https://docs.python.org/zh-cn/dev/library/argparse.html
import subprocess
from time import time

import cv2
import numpy as np
import torch

from lib.config import Config
from utils.evaluator import Evaluator


def test(model, test_loader, evaluator, exp_root, cfg, view, epoch, max_batches=None, verbose=True):
    if verbose:
        logging.info("Starting testing.")

    # Test the model
    if epoch > 0:
        model.load_state_dict(torch.load(os.path.join(exp_root, "models", "model_{:03d}.pt".format(epoch)))['model'])

    model.eval()
    criterion_parameters = cfg.get_loss_parameters()
    test_parameters = cfg.get_test_parameters()
    criterion = model.loss
    loss = 0
    total_iters = 0
    test_t0 = time()
    loss_dict = {}
    with torch.no_grad():
        for idx, (images, labels, img_idxs) in enumerate(test_loader):
            if max_batches is not None and idx >= max_batches:
                break
            if idx % 1 == 0 and verbose:
                logging.info("Testing iteration: {}/{}".format(idx + 1, len(test_loader)))
            images = images.to(device)
            labels = labels.to(device)

            t0 = time()
            outputs = model(images)
            t = time() - t0
            loss_i, loss_dict_i = criterion(outputs, labels, **criterion_parameters)
            loss += loss_i.item()
            total_iters += 1
            for key in loss_dict_i:
                if key not in loss_dict:
                    loss_dict[key] = 0
                loss_dict[key] += loss_dict_i[key]

            outputs = model.decode(outputs, labels, **test_parameters)

            if evaluator is not None:
                lane_outputs, _ = outputs
                evaluator.add_prediction(img_idxs, lane_outputs.cpu().numpy(), t / images.shape[0])
            if view:
                outputs, extra_outputs = outputs
                preds = test_loader.dataset.draw_annotation(
                    idx,
                    pred=outputs[0].cpu().numpy(),
                    cls_pred=extra_outputs[0].cpu().numpy() if extra_outputs is not None else None)
                cv2.imshow('pred', preds)
                cv2.waitKey(0)

    if verbose:
        logging.info("Testing time: {:.4f}".format(time() - test_t0))
    out_line = []
    for key in loss_dict:
        loss_dict[key] /= total_iters
        out_line.append('{}: {:.4f}'.format(key, loss_dict[key]))
    if verbose:
        logging.info(', '.join(out_line))

    return evaluator, loss / total_iters


def parse_args():
    parser = argparse.ArgumentParser(description="Lane regression")
    parser.add_argument("--exp_name", default="default", help="Experiment name", required=True)
    parser.add_argument("--cfg", default="config.yaml", help="Config file", required=True)
    parser.add_argument("--epoch", type=int, default=None, help="Epoch to test the model on")
    parser.add_argument("--batch_size", type=int, help="Number of images per batch")
    parser.add_argument("--view", action="store_true", help="Show predictions")

    return parser.parse_args()


def get_code_state():
    state = "Git hash: {}".format(
        subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
    state += '\n*************\nGit diff:\n*************\n'
    state += subprocess.run(['git', 'diff'], stdout=subprocess.PIPE).stdout.decode('utf-8')

    return state


def log_on_exception(exc_type, exc_value, exc_traceback):
    logging.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))


if __name__ == "__main__":
    args = parse_args()
    cfg = Config(args.cfg)

    # Set up seeds
    torch.manual_seed(cfg['seed'])
    np.random.seed(cfg['seed'])
    random.seed(cfg['seed'])

    # Set up logging
    exp_root = os.path.join(cfg['exps_dir'], os.path.basename(os.path.normpath(args.exp_name)))
    logging.basicConfig(
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )

    sys.excepthook = log_on_exception

    logging.info("Experiment name: {}".format(args.exp_name))
    logging.info("Config:\n" + str(cfg))
    logging.info("Args:\n" + str(args))

    # Device configuration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Hyper parameters
    num_epochs = cfg["epochs"]
    batch_size = cfg["batch_size"] if args.batch_size is None else args.batch_size

    # Model
    model = cfg.get_model().to(device)
    test_epoch = args.epoch

    # Get data set
    test_dataset = cfg.get_dataset("test")

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size if args.view is False else 1,
                                              shuffle=False,
                                              num_workers=8)
    # Eval results
    evaluator = Evaluator(test_loader.dataset, exp_root)

    logging.basicConfig(
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )
    logging.info('Code state:\n {}'.format(get_code_state()))
    _, mean_loss = test(model, test_loader, evaluator, exp_root, cfg, epoch=test_epoch, view=args.view)
    logging.info("Mean test loss: {:.4f}".format(mean_loss))

    evaluator.exp_name = args.exp_name

    eval_str, _ = evaluator.eval(label='{}_{}'.format(os.path.basename(args.exp_name), test_epoch))

    logging.info(eval_str)

To better understand and debug test.py, I annotated it as follows:

import os
import sys
import random
import logging
import argparse
import subprocess
from time import time

import cv2
import numpy as np
import torch

from lib.config import Config
from utils.evaluator import Evaluator


def test(model, test_loader, evaluator, exp_root, cfg, view, epoch, max_batches=None, verbose=True):
# model: the network
# test_loader: {DataLoader:23}
# evaluator:
# exp_root: 'experiments/default'
# cfg: the YAML config file
# view: False
# epoch: 2695
# max_batches: None
# verbose: True

    if verbose:
        logging.info("Starting testing.")

    # Test the model
    if epoch > 0:
        model.load_state_dict(torch.load(os.path.join(exp_root, "models", "model_{:03d}.pt".format(epoch)))['model']) # load the checkpoint; adjust according to the error message if it fails

    model.eval() # tell the model it is doing inference
    criterion_parameters = cfg.get_loss_parameters() # {'conf_weight': 1, 'lower_weight': 1, 'upper_weight': 1, 'cls_weight': 0, 'poly_weight': 300}
    test_parameters = cfg.get_test_parameters() # {'conf_threshold': 0.5}
    criterion = model.loss # the model's loss function, used for the loss computation below
    loss = 0
    total_iters = 0
    test_t0 = time() # record when testing starts
    loss_dict = {}
    with torch.no_grad(): # the wrapped code runs without gradient tracking: the computation still happens, but no gradients are recorded (DL basics I should brush up on)
    # https://blog.csdn.net/weixin_46559271/article/details/105658654
    # https://www.jianshu.com/p/1cea017f5d11
    # https://blog.csdn.net/weixin_44134757/article/details/105775027
        for idx, (images, labels, img_idxs) in enumerate(test_loader): # about enumerate: https://www.runoob.com/python/python-func-enumerate.html
        # test_loader:
        # test_loader.dataset:
        # observed call order: test:34 lane_dataset:222 tusimple:137 lane_dataset:203 (why this order? I think enumerate first needs test_loader's length, and then fetches items with [])
        # enumerate runs first and returns an enumerate object.
        # idx is the index of a batch (batch size is 16, so ceil(358 / 16) = 23 batches)
        # images: {Tensor:16}
        # labels: {Tensor:16}
        # img_idxs are the indices of the images.
            if max_batches is not None and idx >= max_batches:
                break
            if idx % 1 == 0 and verbose:
                logging.info("Testing iteration: {}/{}".format(idx + 1, len(test_loader)))
            images = images.to(device)
            labels = labels.to(device) # move to the GPU

            t0 = time()
            outputs = model(images) # feed the images to the model and record the output: a {tuple:2} whose first element is a 16x35 tensor (16 images, 35 outputs each) and whose second element is None
            # Abbreviated example:
            # (tensor([[ 1.9717e+01,  3.5665e-01,  5.9289e-01,  6.0696e-01, -1.2504e+00, ...],
            #          ...,
            #          [ 2.0449e+01,  3.6672e-01,  9.9850e-01,  2.1467e-01, -3.9144e-01, ...]],
            #        device='cuda:0'), None)
            t = time() - t0
            loss_i, loss_dict_i = criterion(outputs, labels, **criterion_parameters) # outputs: {16, 35}; labels: {16, 5, 115}
            # **criterion_parameters: {'conf_weight': 1, 'lower_weight': 1, 'upper_weight': 1, 'cls_weight': 0, 'poly_weight': 300}
            # loss_i: tensor(3.9538, device='cuda:0')
            # loss_dict_i: {'conf': tensor(0.2224, ...), 'lower': tensor(0.0003, ...), 'upper': tensor(0.0023, ...), 'poly': tensor(3.7288, ...), 'cls_loss': 0}
            loss += loss_i.item() # accumulate the loss
            total_iters += 1 # iteration count
            for key in loss_dict_i:
                if key not in loss_dict:
                    loss_dict[key] = 0
                loss_dict[key] += loss_dict_i[key] # accumulate the per-term losses

            outputs = model.decode(outputs, labels, **test_parameters) # decode the raw outputs into lanes

            if evaluator is not None:
                lane_outputs, _ = outputs # torch.Size([16, 5, 7])
                evaluator.add_prediction(img_idxs, lane_outputs.cpu().numpy(), t / images.shape[0]) # 图片的索引。img_idxs:tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
                # t:模型将图片作为输入后,预测输出所需时间 0.7983336448669434
                # images:torch.Size([16, 3, 360, 640])
            if view:
                outputs, extra_outputs = outputs
                preds = test_loader.dataset.draw_annotation(
                    idx,
                    pred=outputs[0].cpu().numpy(),
                    cls_pred=extra_outputs[0].cpu().numpy() if extra_outputs is not None else None)
                cv2.imshow('pred', preds)
                cv2.waitKey(0)

    if verbose:
        logging.info("Testing time: {:.4f}".format(time() - test_t0))
    out_line = []
    for key in loss_dict:
        loss_dict[key] /= total_iters
        out_line.append('{}: {:.4f}'.format(key, loss_dict[key]))
    if verbose:
        logging.info(', '.join(out_line))

    return evaluator, loss / total_iters


def parse_args():
    parser = argparse.ArgumentParser(description="Lane regression") # 1. create a parser
    parser.add_argument("--exp_name", default="default", help="Experiment name") # 2. this and the following lines add arguments (I studied argparse separately; not repeated here)
    parser.add_argument("--cfg", default="tusimple.yaml", help="Config file") # config file; default changed here
    parser.add_argument("--epoch", type=int, default=2695, help="Epoch to test the model on") # epoch; default changed as well
    parser.add_argument("--batch_size", type=int, help="Number of images per batch")
    parser.add_argument("--view", action="store_true", help="Show predictions")

    return parser.parse_args()


def get_code_state():
    state = "Git hash: {}".format(
        subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
    state += '\n*************\nGit diff:\n*************\n'
    state += subprocess.run(['git', 'diff'], stdout=subprocess.PIPE).stdout.decode('utf-8')

    return state


if __name__ == "__main__": # start debugging from the entry point
    args = parse_args() # args receives the return value of parse_args
    cfg = Config(args.cfg) # args.cfg='tusimple.yaml', passed into Config

    # Set up seeds
    torch.manual_seed(cfg['seed']) # these set the random seeds
    np.random.seed(cfg['seed']) # Config defines the __getitem__ magic method, so it can be indexed like this
    random.seed(cfg['seed'])

    # Set up logging # configure the log files
    exp_root = os.path.join(cfg['exps_dir'], os.path.basename(os.path.normpath(args.exp_name))) # exp_root='experiments/default'
    logging.basicConfig( # the simplest use of the logging module; I have a separate post on it
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )

    sys.excepthook = log_on_exception # route uncaught exceptions through log_on_exception (a module I've never fully understood; some other time)

    logging.info("Experiment name: {}".format(args.exp_name)) # write the log messages
    logging.info("Config:\n" + str(cfg))
    logging.info("Args:\n" + str(args))

    # Device configuration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # configure the GPU. I need to sort this out again: it used to work, but some software versions may no longer be compatible.

    # Hyper parameters
    num_epochs = cfg["epochs"] # 2695
    batch_size = cfg["batch_size"] if args.batch_size is None else args.batch_size # 16. Both hyperparameters are configured in the YAML file.
    
    # Model
    model = cfg.get_model().to(device) # call the get_model() method of the Config class
    test_epoch = args.epoch # 2695

    # Get data set
    test_dataset = cfg.get_dataset("test") # {LaneDataset:358}

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size if args.view is False else 1,
                                              shuffle=False,
                                              num_workers=8) # {DataLoader:23}
    # Eval results
    evaluator = Evaluator(test_loader.dataset, exp_root) # exp_root='experiments/default'; instantiate the evaluator class
    
    logging.basicConfig( # logging module usage
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )
    logging.info('Code state:\n {}'.format(get_code_state())) # also logged. Example output:
# Code state:
# Git hash: 
# *************
# Git diff:
# *************
    _, mean_loss = test(model, test_loader, evaluator, exp_root, cfg, epoch=test_epoch, view=args.view)
    logging.info("Mean test loss: {:.4f}".format(mean_loss)) # 5.2436

    evaluator.exp_name = args.exp_name

    eval_str, _ = evaluator.eval(label='{}_{}'.format(os.path.basename(args.exp_name), test_epoch)) # https://www.cnblogs.com/baxianhua/p/10214263.html
    

    logging.info(eval_str)

The tusimple.yaml configured for the file above:

# Training settings
exps_dir: 'experiments'
iter_log_interval: 1
iter_time_window: 100
model_save_interval: 1
seed: 1
backup:
model:
  name: PolyRegression
  parameters:
    num_outputs: 35 # (5 lanes) * (1 conf + 2 (upper & lower) + 4 poly coeffs)
    pretrained: true
    backbone: 'efficientnet-b0'
    pred_category: false
    curriculum_steps: [0, 0, 0, 0]
loss_parameters:
  conf_weight: 1 # Wc
  lower_weight: 1 # Ws
  upper_weight: 1 # Wh
  cls_weight: 0
  poly_weight: 300 # Wp
batch_size: 16
epochs: 2695
optimizer:
  name: Adam
  parameters:
    lr: 3.0e-4
lr_scheduler:
  name: CosineAnnealingLR
  parameters:
    T_max: 385

# Testing settings
test_parameters:
  conf_threshold: 0.5

# Dataset settings
datasets:
  train:
    type: LaneDataset
    parameters:
      dataset: tusimple
      split: train
      img_size: [360, 640]
      normalize: true
      aug_chance: 0.9090909090909091 # 10/11
      augmentations:
       - name: Affine
         parameters:
           rotate: !!python/tuple [-10, 10]
       - name: HorizontalFlip
         parameters:
           p: 0.5
       - name: CropToFixedSize
         parameters:
           width: 1152
           height: 648
      root: "/home/wqf/PolyLaneNet/PolyLaneNet-master"

  test: &test
    type: LaneDataset
    parameters:
      dataset: tusimple
      split: val
      max_lanes: 5
      img_size: [360, 640]
      root: "/home/wqf/PolyLaneNet/PolyLaneNet-master"
      normalize: true
      augmentations: []

  # val = test
  val:
    <<: *test

An annotated read-through of the config.py file imported above:

import yaml # PyYAML: you install it as PyYAML, but import it simply as yaml
import torch

import lib.models as models
import lib.datasets as datasets


class Config(object):
    def __init__(self, config_path): # config_path='tusimple.yaml'
        self.config = {}
        self.load(config_path) # __init__ calls the load function

    def load(self, path):
        with open(path, 'r') as file: # with...as context manager
            self.config_str = file.read() # read the file
        self.config = yaml.load(self.config_str, Loader=yaml.FullLoader) # load the YAML. About the Loader parameter: https://blog.csdn.net/ly021499/article/details/89026860 (basically a safer way to load)

    def __repr__(self):
        return self.config_str

    def get_dataset(self, split):
        return getattr(datasets,
                       self.config['datasets'][split]['type'])(**self.config['datasets'][split]['parameters']) # returns a lib.datasets.lane_dataset.LaneDataset object
                       # equivalent to calling datasets.<config['datasets'][split]['type']>, i.e., the datasets.LaneDataset class
                       # (**self.config['datasets'][split]['parameters']): {'dataset': 'tusimple', 'split': 'val', 'max_lanes': 5, 'img_size': [360, 640], 'root': '/home/wqf/PolyLaneNet/PolyLaneNet-master', 'normalize': True, 'augmentations': []}
                       # As for why split is 'val' here: I don't know either; that's how the source is written. Reading on.

    def get_model(self):
        name = self.config['model']['name'] # 'PolyRegression'
        parameters = self.config['model']['parameters'] # {'num_outputs': 35, 'pretrained': True, 'backbone': 'efficientnet-b0', 'pred_category': False, 'curriculum_steps': [0, 0, 0, 0]}
        return getattr(models, name)(**parameters) # getattr() returns an object's attribute by name; this calls models.PolyRegression.
        # getattr(object, name[, default]) -> value
        # Get a named attribute from an object; getattr(x, 'y') is equivalent to x.y.
        # When a default argument is given, it is returned when the attribute doesn't exist; without it, an exception is raised in that case.
        # (**parameters) passes the remaining arguments through.

    def get_optimizer(self, model_parameters):
        return getattr(torch.optim, self.config['optimizer']['name'])(model_parameters,
                                                                      **self.config['optimizer']['parameters'])

    def get_lr_scheduler(self, optimizer):
        return getattr(torch.optim.lr_scheduler,
                       self.config['lr_scheduler']['name'])(optimizer, **self.config['lr_scheduler']['parameters'])

    def get_loss_parameters(self):
        return self.config['loss_parameters']

    def get_test_parameters(self):
        return self.config['test_parameters']

    def __getitem__(self, item):
        return self.config[item]
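To see the getattr-based factory pattern used by these methods in isolation, here is a tiny self-contained sketch (a toy example, not the repository's code):

import torch

# Config resolves a class by name inside a module and instantiates it with the
# YAML 'parameters' dict; get_optimizer, for example, boils down to:
name, params = 'Adam', {'lr': 3.0e-4}
opt_cls = getattr(torch.optim, name)                              # -> torch.optim.Adam
optimizer = opt_cls(torch.nn.Linear(2, 2).parameters(), **params)
print(type(optimizer).__name__)                                   # -> Adam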

The lib.models module used above, annotated:

import torch
import torch.nn as nn
from torchvision.models import resnet34, resnet50, resnet101
from efficientnet_pytorch import EfficientNet # a ready-made EfficientNet implementation for PyTorch


class OutputLayer(nn.Module): # this class adds one more layer after EfficientNet's fc layer when there are extra outputs
    def __init__(self, fc, num_extra):
        super(OutputLayer, self).__init__() # call the parent initializer
        self.regular_outputs_layer = fc # Linear(in_features=1280, out_features=35, bias=True)
        self.num_extra = num_extra # 0
        if num_extra > 0:
            self.extra_outputs_layer = nn.Linear(fc.in_features, num_extra)

    def forward(self, x):
        regular_outputs = self.regular_outputs_layer(x)
        if self.num_extra > 0:
            extra_outputs = self.extra_outputs_layer(x)
        else:
            extra_outputs = None

        return regular_outputs, extra_outputs


class PolyRegression(nn.Module):
    def __init__(self,
                 num_outputs, # 35
                 backbone, # 'efficientnet-b0'
                 pretrained, # True
                 curriculum_steps=None, # [0, 0, 0, 0]
                 extra_outputs=0, # 0
                 share_top_y=True, # True
                 pred_category=False): # False
        super(PolyRegression, self).__init__() # call the parent initializer
        if 'efficientnet' in backbone: 
            if pretrained:
                self.model = EfficientNet.from_pretrained(backbone, num_classes=num_outputs) # downloads the pretrained model on first use; loads a pretrained EfficientNet
            else:
                self.model = EfficientNet.from_name(backbone, override_params={'num_classes': num_outputs}) # loads the EfficientNet architecture only (no pretrained weights)
            self.model._fc = OutputLayer(self.model._fc, extra_outputs) # replace EfficientNet's fully connected layer via OutputLayer. (_fc): Linear(in_features=1280, out_features=35, bias=True)
        elif backbone == 'resnet34':
            self.model = resnet34(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_outputs)
            self.model.fc = OutputLayer(self.model.fc, extra_outputs)
        elif backbone == 'resnet50':
            self.model = resnet50(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_outputs)
            self.model.fc = OutputLayer(self.model.fc, extra_outputs)
        elif backbone == 'resnet101':
            self.model = resnet101(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_outputs)
            self.model.fc = OutputLayer(self.model.fc, extra_outputs)
        else:
            raise NotImplementedError()

        self.curriculum_steps = [0, 0, 0, 0] if curriculum_steps is None else curriculum_steps # [0, 0, 0, 0]; not yet sure what this is for
        self.share_top_y = share_top_y # True
        self.extra_outputs = extra_outputs # 0
        self.pred_category = pred_category # False
        self.sigmoid = nn.Sigmoid() # these assignments just store the constructor arguments

    def forward(self, x, epoch=None, **kwargs):
        output, extra_outputs = self.model(x, **kwargs)
        for i in range(len(self.curriculum_steps)):
            if epoch is not None and epoch < self.curriculum_steps[i]:
                output[-len(self.curriculum_steps) + i] = 0
        return output, extra_outputs

    def decode(self, all_outputs, labels, conf_threshold=0.5):
        outputs, extra_outputs = all_outputs # outputs:torch.Size([16, 35])
        if extra_outputs is not None: # extra_outputs=None
            extra_outputs = extra_outputs.reshape(labels.shape[0], 5, -1)
            extra_outputs = extra_outputs.argmax(dim=2)
        outputs = outputs.reshape(len(outputs), -1, 7)  # score + upper + lower + 4 coeffs = 7 -> torch.Size([16, 5, 7])
        outputs[:, :, 0] = self.sigmoid(outputs[:, :, 0]) # the score passes through a sigmoid
        outputs[outputs[:, :, 0] < conf_threshold] = 0 # conf_threshold=0.5

        if False and self.share_top_y:
            outputs[:, :, 0] = outputs[:, 0, 0].expand(outputs.shape[0], outputs.shape[1])

        return outputs, extra_outputs

    def loss(self,
             outputs, # {16 35}
             target, # {16 5 115}
             conf_weight=1, # 1
             lower_weight=1, # 1
             upper_weight=1, # 1
             cls_weight=1, # 0 here; not sure why it is set to 0
             poly_weight=300, # 300
             threshold=15 / 720.): # 0.020833...; which threshold is this? According to the author, it corresponds to tau_loss in Eq. (5).
        pred, extra_outputs = outputs # extra_outputs=None; pred: {16 35}
        bce = nn.BCELoss() # instantiate the loss functions
        mse = nn.MSELoss()
        s = nn.Sigmoid() # activation function
        threshold = nn.Threshold(threshold**2, 0.) # nn.Threshold: https://samuel92.blog.csdn.net/article/details/105888269
        # Threshold(threshold=0.00043402777777777775, value=0.0)
        # values above the threshold pass through unchanged; anything else is replaced by value
        pred = pred.reshape(-1, target.shape[1], 1 + 2 + 4) # torch.Size([16, 5, 7]):
        # 16 images, 5 lanes per image, 7 values predicted per lane
        # (just discovered that the Debug window also shows a tensor's shape attribute)
        target_categories, pred_confs = target[:, :, 0].reshape((-1, 1)), s(pred[:, :, 0]).reshape((-1, 1)) # take index [0] of the third dim: whether a lane is present there (1 if a lane was annotated, 0 otherwise, set according to the number of annotated lanes)
        # target_categories: torch.Size([80, 1]) -- ground-truth presence
        # pred_confs: torch.Size([80, 1]) -- predicted confidences; the raw outputs go through a sigmoid (the paper doesn't spell this out, but it must be so)
        target_uppers, pred_uppers = target[:, :, 2].reshape((-1, 1)), pred[:, :, 2].reshape((-1, 1)) # upper limits
        # target_uppers: torch.Size([80, 1]); pred_uppers: torch.Size([80, 1])
        target_points, pred_polys = target[:, :, 3:].reshape((-1, target.shape[2] - 3)), pred[:, :, 3:].reshape(-1, 4) # ground-truth lane point coordinates and predicted polynomial coefficients
        # target_points: torch.Size([80, 112]); pred_polys: torch.Size([80, 4])
        target_lowers, pred_lowers = target[:, :, 1], pred[:, :, 1] # lower limits
        # target_lowers: torch.Size([16, 5]); pred_lowers: torch.Size([16, 5])

        if self.share_top_y:
            # inexistent lanes have -1e-5 as lower
            # i'm just setting it to a high value here so that the .min below works fine
            target_lowers[target_lowers < 0] = 1 # the English comment above explains this
            target_lowers[...] = target_lowers.min(dim=1, keepdim=True)[0]
            # This takes each image's (ground-truth) minimum lane lower limit and copies it to
            # the image's other lanes: e.g. a row [0.3333, 0.3472, 0.3333, 0.3333, 1.0000]
            # becomes [0.3333, 0.3333, 0.3333, 0.3333, 0.3333].
            pred_lowers[...] = pred_lowers[:, 0].reshape(-1, 1).expand(pred.shape[0], pred.shape[1])
            # Likewise, each image's first predicted lower limit is broadcast to its other lanes:
            # e.g. [0.3535, 0.0249, 0.0093, -0.0395, 0.0213] becomes [0.3535, 0.3535, 0.3535, 0.3535, 0.3535].

        target_lowers = target_lowers.reshape((-1, 1)) # torch.Size([80, 1])
        pred_lowers = pred_lowers.reshape((-1, 1)) # torch.Size([80, 1])

        target_confs = (target_categories > 0).float()
        valid_lanes_idx = target_confs == 1 # positions of the valid (existing) lanes
        valid_lanes_idx_flat = valid_lanes_idx.reshape(-1) 
        lower_loss = mse(target_lowers[valid_lanes_idx], pred_lowers[valid_lanes_idx]) # MSE on the lower limits
        upper_loss = mse(target_uppers[valid_lanes_idx], pred_uppers[valid_lanes_idx]) # and on the upper limits

        # classification loss
        if self.pred_category and self.extra_outputs > 0: # only entered when predicting categories with extra outputs (clearly not the case here)
            ce = nn.CrossEntropyLoss()
            pred_categories = extra_outputs.reshape(target.shape[0] * target.shape[1], -1)
            target_categories = target_categories.reshape(pred_categories.shape[:-1]).long()
            pred_categories = pred_categories[target_categories > 0]
            target_categories = target_categories[target_categories > 0]
            cls_loss = ce(pred_categories, target_categories - 1)
        else:
            cls_loss = 0 # classification loss is 0

        # poly loss calc
        target_xs = target_points[valid_lanes_idx_flat, :target_points.shape[1] // 2]
        # target_points: torch.Size([80, 112]); rows look like
        # [3.8984e-01, 3.7813e-01, 3.6719e-01, ..., -1.0000e+05, -1.0000e+05],
        # with -1.0e+05 padding wherever no point is annotated.
        # valid_lanes_idx_flat: [True, True, True, ..., True, False, False]
        # target_xs: torch.Size([56, 56]) -- the first 56 keeps the rows where valid_lanes_idx_flat
        # is True; the second 56 takes the first half of each target_points row, i.e. the
        # x-coordinates (the second half of the row holds the corresponding y-coordinates).

        ys = target_points[valid_lanes_idx_flat, target_points.shape[1] // 2:].t()
        # the transpose of the y-coordinates (for the computation below), again with -1.0e+05 padding
        valid_xs = target_xs >= 0 # the valid x points
        pred_polys = pred_polys[valid_lanes_idx_flat] # the polynomials of the valid lanes
        pred_xs = pred_polys[:, 0] * ys**3 + pred_polys[:, 1] * ys**2 + pred_polys[:, 2] * ys + pred_polys[:, 3] # compute the predicted x-coordinates
        # (on the -1.0e+05 padding entries the cubic blows up to huge values like 2.17e+14,
        # but those entries are masked out by valid_xs below)
        
        pred_xs.t_() # transpose in place
        weights = (torch.sum(valid_xs, dtype=torch.float32) / torch.sum(valid_xs, dim=1, dtype=torch.float32))**0.5 # https://blog.csdn.net/qq_39463274/article/details/105145029
        # dim=1 sums per row; the English comment below explains the weighting
        pred_xs = (pred_xs.t_() *
                   weights).t()  # without this, lanes with more points would have more weight on the cost function
        target_xs = (target_xs.t_() * weights).t() # the same weighting applied to the targets
        poly_loss = mse(pred_xs[valid_xs], target_xs[valid_xs]) / valid_lanes_idx.sum() # e.g. tensor(0.0124, device='cuda:0'); immediately overwritten by the thresholded version below
        poly_loss = threshold(
            (pred_xs[valid_xs] - target_xs[valid_xs])**2).sum() / (valid_lanes_idx.sum() * valid_xs.sum())

        # applying weights to partial losses
        poly_loss = poly_loss * poly_weight
        lower_loss = lower_loss * lower_weight
        upper_loss = upper_loss * upper_weight
        cls_loss = cls_loss * cls_weight
        conf_loss = bce(pred_confs, target_confs) * conf_weight

        loss = conf_loss + lower_loss + upper_loss + poly_loss + cls_loss

        return loss, {
            'conf': conf_loss,
            'lower': lower_loss,
            'upper': upper_loss,
            'poly': poly_loss,
            'cls_loss': cls_loss
        } # compute the loss terms described in the paper

Appendix: lane_dataset.py, the module imported by the code above

import cv2
import numpy as np
import imgaug.augmenters as iaa # data-augmentation library
from imgaug.augmenters import Resize
from torchvision.transforms import ToTensor
from torch.utils.data.dataset import Dataset
from imgaug.augmentables.lines import LineString, LineStringsOnImage

from .elas import ELAS
from .llamas import LLAMAS
from .tusimple import TuSimple
from .nolabel_dataset import NoLabelDataset

GT_COLOR = (255, 0, 0)
PRED_HIT_COLOR = (0, 255, 0)
PRED_MISS_COLOR = (0, 0, 255)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


class LaneDataset(Dataset):
    def __init__(self,
                 dataset='tusimple', # 'tusimple'
                 augmentations=None, # []
                 normalize=False, # True
                 split='train', # 'val'
                 img_size=(360, 640), # [360, 640]
                 aug_chance=1., # 1.0
                 **kwargs): # {'max_lanes': 5, 'root': '/home/wqf/PolyLaneNet/PolyLaneNet-master'}
        super(LaneDataset, self).__init__() # initialize the parent Dataset class
        if dataset == 'tusimple':
            self.dataset = TuSimple(split=split, **kwargs)
        elif dataset == 'llamas':
            self.dataset = LLAMAS(split=split, **kwargs)
        elif dataset == 'elas':
            self.dataset = ELAS(split=split, **kwargs)
        elif dataset == 'nolabel_dataset':
            self.dataset = NoLabelDataset(**kwargs)
        else:
            raise NotImplementedError()

        self.transform_annotations() # convert every annotation via transform_annotation below
        self.img_h, self.img_w = img_size # [360, 640]

        if augmentations is not None: # here augmentations=[], so the branch is entered but the list comprehension below produces an empty list
            # add augmentations
            augmentations = [getattr(iaa, aug['name'])(**aug['parameters']) 
                             for aug in augmentations]  # add augmentation

        self.normalize = normalize # True
        transformations = iaa.Sequential([Resize({'height': self.img_h, 'width': self.img_w})]) # an augmenter that resizes to the target size
        self.to_tensor = ToTensor() # instantiate the ToTensor transform
        self.transform = iaa.Sequential([iaa.Sometimes(then_list=augmentations, p=aug_chance), transformations]) # iaa.Sometimes applies the augmentation list with probability aug_chance (see https://blog.csdn.net/u012897374/article/details/80142744); with augmentations=[] and aug_chance=1 this amounts to just the resize
        self.max_lanes = self.dataset.max_lanes

    def transform_annotation(self, anno, img_wh=None):  # converts one annotation into the internal label format, rescaling point coordinates to the (possibly resized) image
    # An example anno has the same structure as the dicts built by TuSimple.load_annotations below:
    # keys 'path', 'org_path', 'org_lanes' (raw x per h_sample, -2 = no point), 'lanes'
    # (lists of (x, y) points per lane), 'aug', and 'y_samples' ([160, 170, ..., 710]).
        if img_wh is None:
            img_h = self.dataset.get_img_heigth(anno['path']) # 720
            img_w = self.dataset.get_img_width(anno['path']) # 1280
        else:
            img_w, img_h = img_wh

        old_lanes = anno['lanes'] # valid lane point coordinates
        categories = anno['categories'] if 'categories' in anno else [1] * len(old_lanes) # per-lane class labels; default to 1 when absent
        old_lanes = zip(old_lanes, categories) # a zip object pairing each lane with its category
        old_lanes = filter(lambda x: len(x[0]) > 0, old_lanes) # keep only lanes that have at least one point; filter returns an iterator (https://www.runoob.com/python3/python3-func-filter.html)
        lanes = np.ones((self.dataset.max_lanes, 1 + 2 + 2 * self.dataset.max_points), dtype=np.float32) * -1e5 # for the first anno, max_lanes=5 and max_points=56, so lanes is a (5, 115) array filled with -1e5; it stores the transformed lane points
        # layout per row:
        # 1 column: lane-present flag (1 means there is a lane in this slot)
        # 2 columns: the lane's lower and upper vertical bounds
        # 2 * max_points columns: the x coordinates followed by the y coordinates

        lanes[:, 0] = 0 # the flag column defaults to 0 for every row
        old_lanes = sorted(old_lanes, key=lambda x: x[0][0][0]) # sort lanes by the x coordinate of their bottom-most point; sorted returns a new list rather than modifying in place
        for lane_pos, (lane, category) in enumerate(old_lanes): 
            lower, upper = lane[0][1], lane[-1][1] # y of the first and last points: the lane's vertical extent
            xs = np.array([p[0] for p in lane]) / img_w # normalize x by image width, e.g. [0.38984375, 0.378125, ..., 0.009375]
            ys = np.array([p[1] for p in lane]) / img_h # normalize y by image height, e.g. [0.33333333, 0.34722222, ..., 0.61111111]
            lanes[lane_pos, 0] = category # column 0: the lane's category (1 means a lane is present)
            lanes[lane_pos, 1] = lower / img_h # column 1: normalized lower bound
            lanes[lane_pos, 2] = upper / img_h # column 2: normalized upper bound
            lanes[lane_pos, 3:3 + len(xs)] = xs # columns 3 .. 3+max_points: normalized xs (slots beyond len(xs) stay -1e5)
            lanes[lane_pos, (3 + self.dataset.max_points):(3 + self.dataset.max_points + len(ys))] = ys # the following max_points columns: normalized ys

        new_anno = { # the new annotation dict
            'path': anno['path'],
            'label': lanes,
            'old_anno': anno,
            'categories': [cat for _, cat in old_lanes] # e.g. 'categories': [1, 1, 1, 1]
        }
        

        return new_anno

    @property
    def annotations(self):
        return self.dataset.annotations

    def transform_annotations(self):
        print('Transforming annotations...')
        self.dataset.annotations = np.array(list(map(self.transform_annotation, self.dataset.annotations))) # apply transform_annotation to every annotation; map alone yields an iterator, list(map(...)) materializes it (see https://www.runoob.com/python/python-func-map.html)
        print('Done.')

    def draw_annotation(self, idx, pred=None, img=None, cls_pred=None):
        if img is None:
            img, label, _ = self.__getitem__(idx, transform=True)
            # Tensor to opencv image
            img = img.permute(1, 2, 0).numpy()
            # Unnormalize
            if self.normalize:
                img = img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
            img = (img * 255).astype(np.uint8)
        else:
            _, label, _ = self.__getitem__(idx)

        img_h, img_w, _ = img.shape

        # Draw label
        for i, lane in enumerate(label):
            if lane[0] == 0:  # Skip invalid lanes
                continue
            lane = lane[3:]  # remove conf, upper and lower positions
            xs = lane[:len(lane) // 2]
            ys = lane[len(lane) // 2:]
            ys = ys[xs >= 0]
            xs = xs[xs >= 0]

            # draw GT points
            for p in zip(xs, ys):
                p = (int(p[0] * img_w), int(p[1] * img_h))
                img = cv2.circle(img, p, 5, color=GT_COLOR, thickness=-1)

            # draw GT lane ID
            cv2.putText(img,
                        str(i), (int(xs[0] * img_w), int(ys[0] * img_h)),
                        fontFace=cv2.FONT_HERSHEY_COMPLEX,
                        fontScale=1,
                        color=(0, 255, 0))

        if pred is None:
            return img

        # Draw predictions
        pred = pred[pred[:, 0] != 0]  # filter invalid lanes
        matches, accs, _ = self.dataset.get_metrics(pred, idx)
        overlay = img.copy()
        for i, lane in enumerate(pred):
            if matches[i]:
                color = PRED_HIT_COLOR
            else:
                color = PRED_MISS_COLOR
            lane = lane[1:]  # remove conf
            lower, upper = lane[0], lane[1]
            lane = lane[2:]  # remove upper, lower positions

            # generate points from the polynomial
            ys = np.linspace(lower, upper, num=100)
            points = np.zeros((len(ys), 2), dtype=np.int32)
            points[:, 1] = (ys * img_h).astype(int)
            points[:, 0] = (np.polyval(lane, ys) * img_w).astype(int)
            points = points[(points[:, 0] > 0) & (points[:, 0] < img_w)]

            # draw lane with a polyline on the overlay
            for current_point, next_point in zip(points[:-1], points[1:]):
                overlay = cv2.line(overlay, tuple(current_point), tuple(next_point), color=color, thickness=2)

            # draw class icon
            if cls_pred is not None and len(points) > 0:
                class_icon = self.dataset.get_class_icon(cls_pred[i])
                class_icon = cv2.resize(class_icon, (32, 32))
                mid = tuple(points[len(points) // 2] - 60)
                x, y = mid

                img[y:y + class_icon.shape[0], x:x + class_icon.shape[1]] = class_icon

            # draw lane ID
            if len(points) > 0:
                cv2.putText(img, str(i), tuple(points[0]), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=color)

            # draw lane accuracy
            if len(points) > 0:
                cv2.putText(img,
                            '{:.2f}'.format(accs[i] * 100),
                            tuple(points[len(points) // 2] - 30),
                            fontFace=cv2.FONT_HERSHEY_COMPLEX,
                            fontScale=.75,
                            color=color)
        # Add lanes overlay
        w = 0.6
        img = ((1. - w) * img + w * overlay).astype(np.uint8)

        return img

    def lane_to_linestrings(self, lanes):
        lines = []
        for lane in lanes:
            lines.append(LineString(lane))

        return lines

    def linestrings_to_lanes(self, lines):
        lanes = []
        for line in lines:
            lanes.append(line.coords)

        return lanes

    def __getitem__(self, idx, transform=True):
    # fetches one sample, e.g. idx=112; self.dataset is a {Tusimple:358} split

        item = self.dataset[idx] # indexing self.dataset this way invokes TuSimple.__getitem__
        # An example item (abridged):
        # {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492630158697138522/20.jpg',
        #  'label': a (5, 115) float32 array -- one row per lane slot, laid out as
        #           [flag, lower, upper, 56 normalized xs, 56 normalized ys] and padded
        #           with -1e5; rows for unused lane slots keep flag 0 and stay all -1e5,
        #  'old_anno': the original annotation dict (org_lanes, lanes, y_samples, ...),
        #  'categories': [1, 1, 1]}
        
        img = cv2.imread(item['path']) # read the image, a (720, 1280, 3) ndarray
        label = item['label'] # {ndarray:(5,115)}
        if transform:
            line_strings = self.lane_to_linestrings(item['old_anno']['lanes']) # wrap each annotated lane's points in an imgaug LineString ({list:3} of point lists -> {list:3} of LineString objects)
            # e.g. [LineString([(680.00, 200.00), (669.00, 210.00), ..., (278.00, 700.00)], label=None), ...] -- one LineString per lane
            
            line_strings = LineStringsOnImage(line_strings, shape=img.shape) # wrap the LineStrings together with the image shape
            # type: {LineStringsOnImage:3}, e.g.
            # LineStringsOnImage([LineString([(680.00, 200.00), ...], label=None), ...], shape=(720, 1280, 3))
            img, line_strings = self.transform(image=img, line_strings=line_strings) # run the resize/augmentation pipeline on image and line strings together
            # img: {ndarray:(720,1280,3)} -> {ndarray:(360,640,3)}
            # line_strings: point coordinates are scaled by the same factors, e.g. (680.00, 200.00) -> (340.00, 100.00)
            line_strings.clip_out_of_image_() # Clip off all parts of the LSs that are outside of the image, in-place
            new_anno = {'path': item['path'], 'lanes': self.linestrings_to_lanes(line_strings)} # an example (abridged):
            # {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492630158697138522/20.jpg',
            #  'lanes': [array([[340. , 100.], [334.5, 105.], ..., [139. , 350.]], dtype=float32),
            #            array([[346.5, 100.], [343.5, 105.], ..., [636. , 330.]], dtype=float32),
            #            array([[355.5, 100.], [354. , 105.], ..., [628. , 195.]], dtype=float32)]}
            new_anno['categories'] = item['categories'] # carry the categories over
            label = self.transform_annotation(new_anno, img_wh=(self.img_w, self.img_h))['label'] # call transform_annotation again, this time normalizing against the target size (640, 360) instead of the original image size

        img = img / 255. # scale to [0, 1]; {ndarray:(360,640,3)}
        if self.normalize:
            img = (img - IMAGENET_MEAN) / IMAGENET_STD # standardize with the ImageNet mean and std defined at the top of the file
        img = self.to_tensor(img.astype(np.float32))
        return (img, label, idx) # label: {ndarray:(5,115)}; idx: index of the image

    def __len__(self):
        return len(self.dataset)


def main():
    import torch
    from lib.config import Config
    np.random.seed(0)
    torch.manual_seed(0)
    cfg = Config('config.yaml')
    train_dataset = cfg.get_dataset('train')
    for idx in range(len(train_dataset)):
        img = train_dataset.draw_annotation(idx)
        cv2.imshow('sample', img)
        cv2.waitKey(0)


if __name__ == "__main__":
    main()
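Because the (1 flag + 2 bounds + max_points xs + max_points ys) label layout built above drives both training and visualization, a short sketch of how one label row decodes back into pixel points may help; decode_label_row is a hypothetical helper and all values below are made up:

import numpy as np

max_points = 56
img_w, img_h = 640, 360

def decode_label_row(row, max_points, img_w, img_h):
    # row layout: [flag, lower, upper, max_points xs, max_points ys], padded with -1e5
    flag = row[0]
    if flag == 0:                                  # unused lane slot
        return None
    xs = row[3:3 + max_points]
    ys = row[3 + max_points:3 + 2 * max_points]
    valid = xs >= 0                                # padded slots hold -1e5
    return [(x * img_w, y * img_h) for x, y in zip(xs[valid], ys[valid])]

row = np.full(1 + 2 + 2 * max_points, -1e5, dtype=np.float32)
row[0], row[1], row[2] = 1, 0.28, 0.97             # flag, lower, upper
row[3:6] = [0.53, 0.52, 0.51]                      # three normalized xs
row[3 + max_points:3 + max_points + 3] = [0.28, 0.29, 0.31]  # the matching ys
print(decode_label_row(row, max_points, img_w, img_h))
# [(339.2, 100.8), (332.8, 104.4), (326.4, 111.6)] (approximately)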

Appendix: a walkthrough of the tusimple.py module imported by the code above:

import os
import json
import random

import numpy as np
from tabulate import tabulate

from utils.lane import LaneEval
from utils.metric import eval_json

SPLIT_FILES = {
    'train+val': ['label_data_0313.json', 'label_data_0601.json', 'label_data_0531.json'],
    'train': ['label_data_0313.json', 'label_data_0601.json'],
    'val': ['label_data_0531.json'],
    'test': ['test_label.json'],
}


class TuSimple(object):
    def __init__(self, split='train', max_lanes=None, root=None, metric='default'):
        self.split = split # 'val'
        self.root = root # '/home/wqf/PolyLaneNet/PolyLaneNet-master'
        self.metric = metric # 'default'

        if split not in SPLIT_FILES.keys():
            raise Exception('Split `{}` does not exist.'.format(split))

        self.anno_files = [os.path.join(self.root, path) for path in SPLIT_FILES[split]] # ['/home/wqf/PolyLaneNet/PolyLaneNet-master/label_data_0531.json']

        if root is None:
            raise Exception('Please specify the root directory')

        self.img_w, self.img_h = 1280, 720
        self.max_points = 0
        self.load_annotations() # load the annotation files

        # Force max_lanes, used when evaluating testing with models trained on other datasets
        if max_lanes is not None:
            self.max_lanes = max_lanes # 5

    def get_img_heigth(self, path):
        return 720

    def get_img_width(self, path):
        return 1280

    def get_metrics(self, lanes, idx):
        label = self.annotations[idx]
        org_anno = label['old_anno']
        pred = self.pred2lanes(org_anno['path'], lanes, org_anno['y_samples'])
        _, _, _, matches, accs, dist = LaneEval.bench(pred, org_anno['org_lanes'], org_anno['y_samples'], 0, True)

        return matches, accs, dist

    def pred2lanes(self, path, pred, y_samples):
        ys = np.array(y_samples) / self.img_h
        lanes = []
        for lane in pred: # one lane of one image, e.g. [ 1., 0.35354868, 0.60016412, -0.21744023, -1.12229943, -0.06798992, 0.52022016]
        # layout: [conf, lower, upper, polynomial coefficients]
            if lane[0] == 0:
                continue
            lane_pred = np.polyval(lane[3:], ys) * self.img_w # evaluate the polynomial at the normalized ys and scale back to pixel x; from numpy's docstring:
            # Evaluate a polynomial at specific values.
            # If `p` is of length N, this function returns the value:
            # ``p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]``
            # i.e. each coefficient multiplies the matching power of y
            lane_pred[(ys < lane[1]) | (ys > lane[2])] = -2 # xs whose y is below the lower bound or above the upper bound are set to -2, TuSimple's "no point" marker
            lanes.append(list(lane_pred))

        return lanes # lane x coordinates in TuSimple format
        # An example (abridged): one list of 56 xs per lane, with -2 wherever the lane has no point,
        # e.g. [[-2.0, ..., 434.02, 416.56, 398.41, ..., 42.24, -2.0, ...], [...], [...], [...]]
        
        

    def load_annotations(self): # after this runs, self.annotations is a {list:358}
        self.annotations = []
        max_lanes = 0
        for anno_file in self.anno_files:
            with open(anno_file, 'r') as anno_obj:
                lines = anno_obj.readlines() # read line by line; here a {list:358}
            for line in lines:
                data = json.loads(line)
                y_samples = data['h_samples'] # {list:56}: [160, 170, ..., 710]
                gt_lanes = data['lanes'] # {list:4}
                lanes = [[(x, y) for (x, y) in zip(lane, y_samples) if x >= 0] for lane in gt_lanes] # the (x, y) points of each lane; here 4 lanes, e.g. [[(499, 240), (484, 250), ..., (12, 440)], ...]
                lanes = [lane for lane in lanes if len(lane) > 0] # drop lanes with no valid points; 4 remain
                max_lanes = max(max_lanes, len(lanes)) # 4
                self.max_points = max(self.max_points, max([len(l) for l in gt_lanes])) # 56
                self.annotations.append({
                    'path': os.path.join(self.root, data['raw_file']), # absolute image path
                    'org_path': data['raw_file'], # original relative path
                    'org_lanes': gt_lanes, # raw TuSimple x coordinates per lane
                    'lanes': lanes, # valid lane points (entries with x = -2 removed)
                    'aug': False,
                    'y_samples': y_samples # the sampled y values
                })
                # An example entry (abridged):
                # {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492626287507231547/20.jpg',
                #  'org_path': 'clips/0531/1492626287507231547/20.jpg',
                #  'org_lanes': [[-2, -2, ..., 499, 484, 470, ..., 12, -2, ...], ...],
                #  'lanes': [[(499, 240), (484, 250), ..., (12, 440)], ...],
                #  'aug': False, 'y_samples': [160, 170, ..., 710]}

        if self.split == 'train': # here the split is 'val'
            random.shuffle(self.annotations)
        print('total annos', len(self.annotations)) # 358
        self.max_lanes = max_lanes # the most lanes seen in any single image of the dataset; 5

    def transform_annotations(self, transform):
        self.annotations = list(map(transform, self.annotations))

    def pred2tusimpleformat(self, idx, pred, runtime):
        runtime *= 1000.  # s to ms
        img_name = self.annotations[idx]['old_anno']['org_path']
        h_samples = self.annotations[idx]['old_anno']['y_samples']
        lanes = self.pred2lanes(img_name, pred, h_samples)
        output = {'raw_file': img_name, 'lanes': lanes, 'run_time': runtime}
        return json.dumps(output) # serialize to a JSON string

    def save_tusimple_predictions(self, predictions, runtimes, filename):
        lines = []
        for idx in range(len(predictions)):
            line = self.pred2tusimpleformat(idx, predictions[idx], runtimes[idx])
            lines.append(line)
        with open(filename, 'w') as output_file: # e.g. '/tmp/tusimple_predictions_default_2695.json'; mind where the file gets written
            output_file.write('\n'.join(lines)) # one JSON line per image

    def eval(self, exp_dir, predictions, runtimes, label=None, only_metrics=False): # exp_dir: 'experiments/default'
    # predictions: (358, 5, 7); runtimes: 358 entries
        pred_filename = '/tmp/tusimple_predictions_{}.json'.format(label) # '/tmp/tusimple_predictions_default_2695.json'
        self.save_tusimple_predictions(predictions, runtimes, pred_filename)
        if self.metric == 'default':
            result = json.loads(LaneEval.bench_one_submit(pred_filename, self.anno_files[0])) # pred_filename:'/tmp/tusimple_predictions_default_2695.json'
            # self.anno_files:['/home/wqf/PolyLaneNet/PolyLaneNet-master/label_data_0531.json']
        elif self.metric == 'ours':
            result = json.loads(eval_json(pred_filename, self.anno_files[0], json_type='tusimple'))
        table = {}
        for metric in result:
            table[metric['name']] = [metric['value']]
        table = tabulate(table, headers='keys')

        if not only_metrics:
            filename = 'tusimple_{}_eval_result_{}.json'.format(self.split, label)
            with open(os.path.join(exp_dir, filename), 'w') as out_file:
                json.dump(result, out_file)

        return table, result

    def __getitem__(self, idx):
        return self.annotations[idx]

    def __len__(self):
        return len(self.annotations) # 358 annotations
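To see pred2lanes in isolation, the sketch below evaluates a single predicted lane on TuSimple's h_samples, reusing the example coefficient vector from the comment inside the loop above (a standalone rewrite for illustration, not the class method itself):

import numpy as np

img_w, img_h = 1280, 720
y_samples = np.arange(160, 720, 10)            # TuSimple's 56 h_samples
lane = np.array([1., 0.35354868, 0.60016412, -0.21744023, -1.12229943, -0.06798992, 0.52022016])
# layout: [conf, lower, upper, a, b, c, d]

ys = y_samples / img_h                          # normalized ys, as during training
xs = np.polyval(lane[3:], ys) * img_w           # a*y^3 + b*y^2 + c*y + d, scaled to pixel x
xs[(ys < lane[1]) | (ys > lane[2])] = -2        # outside [lower, upper] -> TuSimple's -2 marker
print(len(xs), xs[:6])                          # 56 values; -2 until y/img_h reaches lane[1]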

Appendix: a walkthrough of the related evaluator.py file:

import sys

import numpy as np

from lib.datasets.lane_dataset import LaneDataset

EXPS_DIR = 'experiments'


class Evaluator(object):
    def __init__(self, dataset, exp_dir, poly_degree=3):
        self.dataset = dataset # {LaneDataset:358}
        # self.predictions = np.zeros((len(dataset.annotations), dataset.max_lanes, 4 + poly_degree))
        self.predictions = None
        self.runtimes = np.zeros(len(dataset)) # len() calls the dataset's __len__ magic method; {ndarray:(358,)}
        self.loss = np.zeros(len(dataset)) # {ndarray:(358,)}
        self.exp_dir = exp_dir # 'experiments/default'
        self.new_preds = False

    def add_prediction(self, idx, pred, runtime):
        if self.predictions is None:
            self.predictions = np.zeros((len(self.dataset.annotations), pred.shape[1], pred.shape[2])) # (358, 5, 7)
        self.predictions[idx, :pred.shape[1], :] = pred # store this batch's predictions at their image indices
        self.runtimes[idx] = runtime
        self.new_preds = True

    def eval(self, **kwargs):# kwargs:{'label': 'default_2695'}
        return self.dataset.dataset.eval(self.exp_dir, self.predictions, self.runtimes, **kwargs) # self.exp_dir:'experiments/default'
        # self.predictions:(358, 5, 7) self.runtimes:358


if __name__ == "__main__":
    evaluator = Evaluator(LaneDataset(split='test'), exp_dir=sys.argv[1])
    evaluator.tusimple_eval()  # note: Evaluator only defines eval(), so this call would raise AttributeError as written
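As a rough usage sketch of how the prediction buffer gets filled batch by batch (the batch data here is fake; shapes follow the (358, 5, 7) noted in the comments above):

import numpy as np

num_images, max_lanes, lane_dim = 358, 5, 7     # lane_dim = conf + lower + upper + 4 coefficients
predictions = np.zeros((num_images, max_lanes, lane_dim))

batch_idx = np.arange(0, 16)                    # image indices covered by one batch of 16
batch_pred = np.random.rand(16, max_lanes, lane_dim)  # stand-in for one batch of model outputs
predictions[batch_idx, :batch_pred.shape[1], :] = batch_pred  # same indexing as add_prediction
print(predictions.shape)                        # (358, 5, 7)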

A walkthrough of the lane.py file imported by the code above:

import numpy as np
import ujson as json
from sklearn.linear_model import LinearRegression


class LaneEval(object):
    lr = LinearRegression()
    pixel_thresh = 20
    pt_thresh = 0.85

    @staticmethod
    def get_angle(xs, y_samples):
        xs, ys = xs[xs >= 0], y_samples[xs >= 0]
        if len(xs) > 1:
            LaneEval.lr.fit(ys[:, None], xs)
            k = LaneEval.lr.coef_[0]
            theta = np.arctan(k)
        else:
            theta = 0
        return theta

    @staticmethod
    def line_accuracy(pred, gt, thresh):
        pred = np.array([p if p >= 0 else -100 for p in pred])
        gt = np.array([g if g >= 0 else -100 for g in gt])
        return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)

    @staticmethod
    def distances(pred, gt):
        return np.abs(pred - gt)

    @staticmethod
    def bench(pred, gt, y_samples, running_time, get_matches=False):
        if any(len(p) != len(y_samples) for p in pred):
            raise Exception('Format of lanes error.')
        if running_time > 20000 or len(gt) + 2 < len(pred):
            return 0., 0., 1.
        angles = [LaneEval.get_angle(np.array(x_gts), np.array(y_samples)) for x_gts in gt]
        threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
        line_accs = []
        fp, fn = 0., 0.
        matched = 0.
        my_matches = [False] * len(pred)
        my_accs = [0] * len(pred)
        my_dists = [None] * len(pred)
        for x_gts, thresh in zip(gt, threshs):
            accs = [LaneEval.line_accuracy(np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
            my_accs = np.maximum(my_accs, accs)
            max_acc = np.max(accs) if len(accs) > 0 else 0.
            my_dist = [LaneEval.distances(np.array(x_preds), np.array(x_gts)) for x_preds in pred]
            if len(accs) > 0:
                my_dists[np.argmax(accs)] = {
                    'y_gts': list(np.array(y_samples)[np.array(x_gts) >= 0].astype(int)),
                    'dists': list(my_dist[np.argmax(accs)])
                }

            if max_acc < LaneEval.pt_thresh:
                fn += 1
            else:
                my_matches[np.argmax(accs)] = True
                matched += 1
            line_accs.append(max_acc)
        fp = len(pred) - matched
        if len(gt) > 4 and fn > 0:
            fn -= 1
        s = sum(line_accs)
        if len(gt) > 4:
            s -= min(line_accs)
        if get_matches:
            return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(
                min(len(gt), 4.), 1.), my_matches, my_accs, my_dists
        return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.)

    @staticmethod
    def bench_one_submit(pred_file, gt_file): # evaluates a whole TuSimple submission; used earlier -- skipping the details for now and will study it properly when needed
    # pred_file: '/tmp/tusimple_predictions_default_2695.json'
    # gt_file: '/home/wqf/PolyLaneNet/PolyLaneNet-master/label_data_0531.json'
        try:
            json_pred = [json.loads(line) for line in open(pred_file).readlines()]
        except BaseException as e:
            raise Exception('Fail to load json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}
        accuracy, fp, fn = 0., 0., 0.
        run_times = []
        for pred in json_pred:
            if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
                raise Exception('raw_file or lanes or run_time not in some predictions.')
            raw_file = pred['raw_file']
            pred_lanes = pred['lanes']
            run_time = pred['run_time']
            run_times.append(run_time)
            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_lanes = gt['lanes']
            y_samples = gt['h_samples']
            try:
                a, p, n = LaneEval.bench(pred_lanes, gt_lanes, y_samples, run_time)
            except BaseException as e:
                raise Exception('Format of lanes error.')
            accuracy += a
            fp += p
            fn += n
        num = len(gts)
        # the first return parameter is the default ranking parameter
        return json.dumps([{
            'name': 'Accuracy',
            'value': accuracy / num,
            'order': 'desc'
        }, {
            'name': 'FP',
            'value': fp / num,
            'order': 'asc'
        }, {
            'name': 'FN',
            'value': fn / num,
            'order': 'asc'
        }, {
            'name': 'FPS',
            'value': 1000. / np.mean(run_times)
        }])


if __name__ == '__main__':
    import sys
    try:
        if len(sys.argv) != 3:
            raise Exception('Invalid input arguments')
        print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
    except Exception as e:
        print(e)
        # sys.exit(e.message)
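Finally, a self-contained toy rerun of the matching rule used by LaneEval.bench: the slope of the ground-truth lane widens the pixel threshold, then per-point hits are counted exactly as in line_accuracy (np.polyfit stands in for the sklearn LinearRegression fit, and all arrays are made up):

import numpy as np

pixel_thresh = 20
y_samples = np.array([160, 170, 180, 190, 200, 210])
gt = np.array([500, 480, 460, 440, -2, -2])     # ground-truth xs; -2 = no point
pred = np.array([510, 470, 465, 430, -2, -2])   # predicted xs at the same y_samples

# slope-adjusted threshold: the steeper the lane, the more horizontal error is tolerated
xs, ys = gt[gt >= 0], y_samples[gt >= 0]
k = np.polyfit(ys, xs, 1)[0]                    # slope; stands in for LaneEval's LinearRegression
thresh = pixel_thresh / np.cos(np.arctan(k))

# per-point accuracy, as in LaneEval.line_accuracy (missing points map to -100 on both sides)
p = np.where(pred >= 0, pred, -100)
g = np.where(gt >= 0, gt, -100)
acc = np.sum(np.abs(p - g) < thresh) / len(g)
print(thresh, acc)                              # a lane counts as matched when acc >= pt_thresh (0.85)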

The end.
