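(Global information may be ignored, yet global information is very important under occlusion and shadows.)
Predicted confidence: cj ∈ [0, 1]
Backbone network: extracts features
Fully connected layer: Mmax + 1 outputs
(outputs 1, …, Mmax: lane-marking predictions)
(output Mmax + 1: outputs h)
Because PolyLaneNet represents each lane marking with a polynomial rather than a set of points, for each output j, j = 1, …, Mmax, the model estimates the coefficients of a polynomial.
For a given image, M is the number of annotated lane markings. In general, M <= 4 covers most traffic scenes. Output unit j corresponds to lane marking j (j = 1, …, M), so outputs M+1 through Mmax should be ignored in the loss function.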
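The polynomial representation and the masking of unused outputs can be sketched in plain Python. This is only an illustration: the helper names `eval_lane` and `masked_mean`, the highest-degree-first coefficient order, and the sample numbers are assumptions made for this sketch, not taken from the PolyLaneNet code.

```python
def eval_lane(coeffs, y):
    """Evaluate x = p(y) with Horner's rule; coeffs are highest degree first (assumed order)."""
    x = 0.0
    for a in coeffs:
        x = x * y + a
    return x


def masked_mean(per_lane_errors, num_annotated):
    """Average the error over the first M outputs only; outputs M+1..Mmax are ignored."""
    valid = per_lane_errors[:num_annotated]
    return sum(valid) / len(valid) if valid else 0.0


# A straight line x = 0.5 * y + 0.1 written as a degree-3 polynomial.
x_at_y = eval_lane([0.0, 0.0, 0.5, 0.1], 0.4)

# Mmax = 5 outputs, but only M = 3 lanes are annotated in this image.
loss = masked_mean([0.2, 0.4, 0.6, 9.0, 9.0], num_annotated=3)
```

The large errors of the two unused outputs do not affect the result, which is the point of excluding them from the loss.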
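(3,268 images for training, 358 for validation, 2,782 for testing)
Training ran for 2695 epochs and took 35 hours on a Titan V.
batch size = 16
Ws = Wc = Wh = 1; Wp = 300 (found by tuning)
The evaluation metrics come from TuSimple's benchmark:
accuracy (Acc), false positive (FP) and false negative (FN) rates.
A predicted lane marking is counted as a true positive (i.e., correct) if it satisfies the following condition:
Lane Position Deviation (LPD): better captures the model's accuracy in the near and far fields of view of the ego vehicle. It is the error between the predicted ego-lane markings and the ground truth.
Regarding the degree of the polynomial used to represent lane markings, the small accuracy difference when using lower-order polynomials shows how imbalanced the dataset is. Using a first-order polynomial (i.e., a straight line) lowers accuracy by only 0.35 p.p.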
R. K. Satzoda and M. M. Trivedi, “On Performance Evaluation Metrics for Lane Estimation,” 2014 22nd International Conference on Pattern Recognition, Stockholm, Sweden, 2014, pp. 2625-2630, doi: 10.1109/ICPR.2014.453.
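LPD is used to determine the accuracy of the estimated lane in the road scene, especially from the viewpoint of the ego vehicle. Referring to Figure 2, lane features ((b) in Figure 2) are used to estimate the lane, which is shown by the dashed line (d) in Figure 2. The lane position deviation ((e) in Figure 2) measures how far the detected lane (d) deviates from the actual lane obtained by connecting the actual lane markings ((h) in Figure 2).
FLOPS (all uppercase) stands for floating-point operations per second. It measures hardware performance.
FLOPs (lowercase s) stands for floating-point operations, i.e., the number of floating-point operations. It can be used to measure the complexity of an algorithm or model.
Besides FLOPs, computation can also be measured in MACs (multiply-accumulate operations). One multiplication followed by adding the product to another number counts as one MAC, so a MAC count differs from a FLOP count by a factor of about 2.
MACS is the number of fixed-point multiply-accumulate operations executed per second. It measures a computer's fixed-point processing capability and is often used for scientific workloads that need large numbers of fixed-point multiply-accumulate operations.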
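As a small worked example of the factor of 2 between MACs and FLOPs, the sketch below counts both for a single dense layer. The function name `linear_layer_cost` is hypothetical, the bias is ignored, and the 1280-to-35 size is borrowed from the `Linear(in_features=1280, out_features=35)` layer that appears in the model code of these notes.

```python
def linear_layer_cost(in_features, out_features):
    """Return (MACs, FLOPs) for one forward pass of a dense layer, ignoring the bias."""
    macs = in_features * out_features  # one multiply-accumulate per weight
    flops = 2 * macs                   # each MAC is one multiply plus one add
    return macs, flops


macs, flops = linear_layer_cost(1280, 35)
```

Tools that report model complexity differ in whether they print MACs or FLOPs, so it is worth checking which one a given number refers to.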
The paper was accepted at ICPR 2020.
The official environment uses Python 3.5.2, but later versions should be compatible in principle.
pip install -r requirements.txt
# Training settings
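exps_dir: 'experiments' # Path to the root of the experiments directory (not only the one you are going to run)
iter_log_interval: 1 # Log training iterations every N iterations
iter_time_window: 100 # Moving-average iteration window used when printing the loss metrics
model_save_interval: 1 # Save the model every N epochs
seed: 0 # Random seed
backup: drive:polylanenet-experiments # After training ends, the experiment directory is uploaded automatically with rclone. Leave empty if you do not want this.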
name: PolyRegression
num_outputs: 35 # (5 lanes) * (1 conf + 2 (upper & lower) + 4 poly coeffs)
pretrained: true
backbone: 'efficientnet-b0'
pred_category: false
conf_weight: 1
lower_weight: 1
upper_weight: 1
cls_weight: 0
poly_weight: 300
batch_size: 16
epochs: 2695
name: Adam
lr: 3.0e-4
name: CosineAnnealingLR
T_max: 385
# Testing settings
conf_threshold: 0.5 # Set predictions with confidence below this value to 0 (i.e., treat them as invalid for the metrics)
# Dataset settings
type: PointsDataset
dataset: tusimple
split: train
img_size: [360, 640]
normalize: true
aug_chance: 0.9090909090909091 # 10/11
augmentations: # ImgAug augmentations
- name: Affine
rotate: !!python/tuple [-10, 10]
- name: HorizontalFlip
p: 0.5
- name: CropToFixedSize
width: 1152
height: 648
root: "datasets/tusimple" # Dataset root
test: &test
type: PointsDataset
dataset: tusimple
split: val
img_size: [360, 640]
root: "datasets/tusimple"
normalize: true
augmentations: []
# val = test
val:
<<: *test
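To make `num_outputs: 35` and `conf_threshold: 0.5` concrete, here is a minimal pure-Python sketch of splitting a flat 35-value output into 5 lanes of 7 values and zeroing low-confidence lanes. The function `split_outputs`, the per-lane order (confidence, lower, upper, 4 coefficients) and the sample values are assumptions for illustration; the real model does this with tensor operations.

```python
import math

NUM_LANES = 5
VALUES_PER_LANE = 7  # 1 conf + 2 (lower & upper) + 4 polynomial coefficients


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def split_outputs(flat, conf_threshold=0.5):
    """Split a flat 35-vector into per-lane rows and zero out low-confidence lanes."""
    assert len(flat) == NUM_LANES * VALUES_PER_LANE
    lanes = []
    for i in range(NUM_LANES):
        row = list(flat[i * VALUES_PER_LANE:(i + 1) * VALUES_PER_LANE])
        row[0] = sigmoid(row[0])           # raw score -> confidence in [0, 1]
        if row[0] < conf_threshold:
            row = [0.0] * VALUES_PER_LANE  # treated as invalid for the metrics
        lanes.append(row)
    return lanes


# Hypothetical raw outputs: lane 0 has a high score, lanes 1-4 have low scores.
raw = [4.0, 0.3, 0.9, 0.0, 0.0, 0.5, 0.1] + [-4.0, 0.3, 0.9, 0.0, 0.0, 0.5, 0.1] * 4
lanes = split_outputs(raw)
```

Only the first lane survives the threshold here; the other four rows are all zeros.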
python train.py --exp_name tusimple --cfg config.yaml
--exp_name Experiment name
--cfg Config file for the training (.yaml)
--resume Resume training. If the training process was interrupted, run it again with the same arguments plus this option to resume from the last checkpoint.
--validate Whether to validate during the training session. Was not used in our experiments, which means it has not been thoroughly tested.
--deterministic set cudnn.deterministic = True and cudnn.benchmark = False
python test.py --exp_name tusimple --cfg config.yaml --epoch 2695
--exp_name Experiment name.
--cfg Config file for the test (.yaml). (probably the same one used in the training)
--epoch EPOCH Epoch to test the model on
--batch_size Number of images per batch
--view Show predictions. Will draw the predictions in an image and then show it (cv.imshow)
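To reproduce the results, you can either retrain the model with the same settings (the results should be very close to those reported in the paper) or just test the model. If you retrain, you only need to modify the corresponding YAML settings file, which you can find in the cfgs directory. If you only want to reproduce the exact reported metrics by testing the model, you must:
modify all path-related fields (i.e., dataset paths and exps_dir) in the config file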
python test.py --exp_name $exp_name --cfg $exps_dir/$exp_name/config.yaml --epoch 2695
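Replace $exp_name with the name of the directory you downloaded (the name of the experiment) and $exps_dir with the exps_dir value. The script will look for a directory named $exps_dir/$exp_name/models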
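The path lookup described above can be sketched as follows. It mirrors the `os.path.join` calls and the `model_{:03d}.pt` pattern used in `test.py` in these notes; the helper name `checkpoint_path` is hypothetical, and `posixpath` is used only so the separator is the same on every platform.

```python
import posixpath


def checkpoint_path(exps_dir, exp_name, epoch):
    """Build the checkpoint path that the test script would load for a given epoch."""
    exp_root = posixpath.join(exps_dir, posixpath.basename(posixpath.normpath(exp_name)))
    return posixpath.join(exp_root, "models", "model_{:03d}.pt".format(epoch))


path = checkpoint_path("experiments", "tusimple", 2695)
early = checkpoint_path("experiments", "tusimple/", 7)
```

The `{:03d}` format pads short epoch numbers with zeros and leaves four-digit epochs such as 2695 unchanged.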
import os
import sys
import random
import logging
import argparse  # argparse usage: https://docs.python.org/zh-cn/dev/library/argparse.html
import subprocess
from time import time
import cv2
import numpy as np
import torch
from lib.config import Config
from utils.evaluator import Evaluator
def test(model, test_loader, evaluator, exp_root, cfg, view, epoch, max_batches=None, verbose=True):
if verbose:
logging.info("Starting testing.")
# Test the model
if epoch > 0:
model.load_state_dict(torch.load(os.path.join(exp_root, "models", "model_{:03d}.pt".format(epoch)))['model'])
criterion_parameters = cfg.get_loss_parameters()
test_parameters = cfg.get_test_parameters()
criterion = model.loss
loss = 0
total_iters = 0
test_t0 = time()
loss_dict = {}
with torch.no_grad():
for idx, (images, labels, img_idxs) in enumerate(test_loader):
if max_batches is not None and idx >= max_batches:
    break
if idx % 1 == 0 and verbose:
logging.info("Testing iteration: {}/{}".format(idx + 1, len(test_loader)))
images = images.to(device)
labels = labels.to(device)
t0 = time()
outputs = model(images)
t = time() - t0
loss_i, loss_dict_i = criterion(outputs, labels, **criterion_parameters)
loss += loss_i.item()
total_iters += 1
for key in loss_dict_i:
if key not in loss_dict:
loss_dict[key] = 0
loss_dict[key] += loss_dict_i[key]
outputs = model.decode(outputs, labels, **test_parameters)
if evaluator is not None:
lane_outputs, _ = outputs
evaluator.add_prediction(img_idxs, lane_outputs.cpu().numpy(), t / images.shape[0])
if view:
outputs, extra_outputs = outputs
preds = test_loader.dataset.draw_annotation(
cls_pred=extra_outputs[0].cpu().numpy() if extra_outputs is not None else None)
cv2.imshow('pred', preds)
if verbose:
logging.info("Testing time: {:.4f}".format(time() - test_t0))
out_line = []
for key in loss_dict:
loss_dict[key] /= total_iters
out_line.append('{}: {:.4f}'.format(key, loss_dict[key]))
if verbose:
logging.info(', '.join(out_line))
return evaluator, loss / total_iters
def parse_args():
parser = argparse.ArgumentParser(description="Lane regression")
parser.add_argument("--exp_name", default="default", help="Experiment name", required=True)
parser.add_argument("--cfg", default="config.yaml", help="Config file", required=True)
parser.add_argument("--epoch", type=int, default=None, help="Epoch to test the model on")
parser.add_argument("--batch_size", type=int, help="Number of images per batch")
parser.add_argument("--view", action="store_true", help="Show predictions")
return parser.parse_args()
def get_code_state():
state = "Git hash: {}".format(
subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
state += '\n*************\nGit diff:\n*************\n'
state += subprocess.run(['git', 'diff'], stdout=subprocess.PIPE).stdout.decode('utf-8')
return state
def log_on_exception(exc_type, exc_value, exc_traceback):
logging.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
if __name__ == "__main__":
args = parse_args()
cfg = Config(args.cfg)
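# Set up seeds
torch.manual_seed(cfg['seed'])
np.random.seed(cfg['seed'])
# Set up logging
exp_root = os.path.join(cfg['exps_dir'], os.path.basename(os.path.normpath(args.exp_name)))
logging.basicConfig(
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[logging.FileHandler(os.path.join(exp_root, "test_log.txt")), logging.StreamHandler()],
)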
sys.excepthook = log_on_exception
logging.info("Experiment name: {}".format(args.exp_name))
logging.info("Config:\n" + str(cfg))
logging.info("Args:\n" + str(args))
# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Hyper parameters
num_epochs = cfg["epochs"]
batch_size = cfg["batch_size"] if args.batch_size is None else args.batch_size
# Model
model = cfg.get_model().to(device)
test_epoch = args.epoch
# Get data set
test_dataset = cfg.get_dataset("test")
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size if args.view is False else 1,
                                          num_workers=8)
# Eval results
evaluator = Evaluator(test_loader.dataset, exp_root)
logging.basicConfig(
    format="[%(asctime)s] [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[logging.FileHandler(os.path.join(exp_root, "test_log.txt")), logging.StreamHandler()],
)
logging.info('Code state:\n {}'.format(get_code_state()))
_, mean_loss = test(model, test_loader, evaluator, exp_root, cfg, epoch=test_epoch, view=args.view)
logging.info("Mean test loss: {:.4f}".format(mean_loss))
evaluator.exp_name = args.exp_name
eval_str, _ = evaluator.eval(label='{}_{}'.format(os.path.basename(args.exp_name), test_epoch))
import os
import sys
import random
import logging
import argparse
import subprocess
from time import time
import cv2
import numpy as np
import torch
from lib.config import Config
from utils.evaluator import Evaluator
def test(model, test_loader, evaluator, exp_root, cfg, view, epoch, max_batches=None, verbose=True):
    # model: the model structure
    # test_loader: {DataLoader:23}
    # evaluator:
    # exp_root: 'experiments/default'
    # cfg: the YAML config file
    # view: False
    # epoch: 2695
    # max_batches: None
    # verbose: True
    if verbose:
        logging.info("Starting testing.")

    # Test the model
    if epoch > 0:
        # Load the model weights; adjust this line according to any error message.
        model.load_state_dict(torch.load(os.path.join(exp_root, "models", "model_{:03d}.pt".format(epoch)))['model'])

    model.eval()  # put the model in inference mode
    criterion_parameters = cfg.get_loss_parameters()  # {'conf_weight': 1, 'lower_weight': 1, 'upper_weight': 1, 'cls_weight': 0, 'poly_weight': 300}
    test_parameters = cfg.get_test_parameters()  # {'conf_threshold': 0.5}
    criterion = model.loss  # the model's loss function, used for the loss computation below
    loss = 0
    total_iters = 0
    test_t0 = time()  # record the test start time
    loss_dict = {}
    with torch.no_grad():  # the wrapped code still runs the forward computation, but no gradients are tracked
        # https://blog.csdn.net/weixin_46559271/article/details/105658654
        # https://www.jianshu.com/p/1cea017f5d11
        # https://blog.csdn.net/weixin_44134757/article/details/105775027
        for idx, (images, labels, img_idxs) in enumerate(test_loader):  # on enumerate: https://www.runoob.com/python/python-func-enumerate.html
            # test_loader:
            # test_loader.dataset:
            # Call order seen in the debugger: test34 lane_dataset222 tusimple137 lane_dataset203.
            # Why this order? My guess: iterating the loader first needs the dataset length and then fetches items with [].
            # enumerate runs first and returns an enumerate object.
            # idx is the batch index (batch size is 16, so there are ceil(358 / 16) = 23 batches here).
            # images: {Tensor:16}
            # labels: {Tensor:16}
            # img_idxs holds the image indices.
            if max_batches is not None and idx >= max_batches:
                break
            if idx % 1 == 0 and verbose:
                logging.info("Testing iteration: {}/{}".format(idx + 1, len(test_loader)))
            images = images.to(device)
            labels = labels.to(device)  # move to the target device (the GPU when available)

            t0 = time()
            outputs = model(images)  # feed the images to the model. The result is a {tuple:2} whose first element is 16x35 (16 images, 35 outputs per image)
            # An example:
# (tensor([[ 1.9717e+01, 3.5665e-01, 5.9289e-01, 6.0696e-01, -1.2504e+00,
-9.5759e-01, 8.3573e-01, 1.9728e+01, 1.7501e-02, 8.4806e-01,
3.0594e-01, -5.8771e-01, -3.6744e-01, 6.2564e-01, 1.4671e+01,
1.8156e-02, 8.9165e-01, -2.3603e-01, 1.1759e-01, 7.8406e-01,
1.9520e-01, 5.1091e-02, -3.0347e-02, 6.6394e-01, -8.3053e-01,
8.3323e-01, 1.3598e+00, -2.2746e-02, -1.1667e+01, 2.9069e-02,
5.9165e-01, -1.0211e-01, 7.8156e-01, 9.9158e-01, 1.5408e-01],
[ 2.0524e+01, 3.5466e-01, 5.9829e-01, 9.8226e-01, -1.1990e+00,
-1.4157e+00, 1.0521e+00, 2.0533e+01, -5.5082e-04, 1.0035e+00,
1.8302e-01, -2.6475e-01, -5.1310e-01, 6.9090e-01, 1.3714e+01,
3.1270e-02, 9.0050e-01, -2.2248e-01, 4.2548e-01, 5.4658e-01,
3.0025e-01, -5.9779e+00, -2.1180e-02, 5.4276e-01, -5.6866e-01,
1.2482e+00, 1.2983e+00, 2.0630e-02, -1.2991e+01, 1.8783e-02,
5.1402e-01, 1.3041e-01, 9.1163e-01, 1.0488e+00, 2.3146e-01],
[ 1.9811e+01, 3.4818e-01, 6.0123e-01, 1.2543e+00, -1.1373e+00,
-1.9976e+00, 1.3299e+00, 1.9818e+01, -1.3726e-02, 9.8255e-01,
1.6025e-01, -4.9490e-04, -9.2998e-01, 8.9230e-01, 9.5679e+00,
3.7305e-02, 9.9070e-01, -1.2178e-01, 5.5367e-01, -1.1417e-01,
5.9704e-01, -1.2919e+01, -1.1187e-02, 5.6699e-01, 2.8218e-01,
9.4523e-01, 2.8753e-01, 4.6897e-01, -1.2240e+01, 6.4192e-03,
5.0729e-01, 4.6846e-01, 4.4547e-01, 4.8933e-01, 5.2881e-01],
[ 1.8849e+01, 3.5481e-01, 5.9489e-01, 1.1236e+00, -1.0921e+00,
-1.9187e+00, 1.2948e+00, 1.8853e+01, -8.1959e-03, 9.8278e-01,
1.4345e-01, -1.6455e-02, -9.2424e-01, 8.8395e-01, 9.0486e+00,
3.7280e-02, 1.0009e+00, -6.9186e-02, 5.1333e-01, -1.8387e-01,
6.1838e-01, -6.3384e+00, -7.4000e-03, 6.0970e-01, 3.7614e-01,
7.7696e-01, 1.2253e-01, 5.2114e-01, -1.0654e+01, 2.0078e-03,
5.2485e-01, 3.7526e-01, 4.0945e-01, 4.7751e-01, 5.0353e-01],
[ 1.8772e+01, 3.2760e-01, 5.7178e-01, 1.3995e+00, -1.1920e+00,
-2.1902e+00, 1.3575e+00, 1.8774e+01, -8.2448e-03, 9.6515e-01,
2.2256e-01, -1.8641e-02, -1.1020e+00, 9.3922e-01, 8.3400e+00,
3.5112e-02, 9.8313e-01, -3.7752e-02, 5.1563e-01, -3.2519e-01,
6.7639e-01, -9.9023e+00, -7.2143e-03, 6.2546e-01, 5.8917e-01,
6.4413e-01, -1.6515e-01, 6.6375e-01, -7.5666e+00, 1.9182e-03,
5.3885e-01, 4.7531e-01, 2.1594e-01, 2.6025e-01, 6.4260e-01],
[ 1.8206e+01, 3.2264e-01, 6.0505e-01, 1.1304e+00, -9.7306e-01,
-1.8966e+00, 1.2448e+00, 1.8205e+01, -1.2127e-02, 9.9538e-01,
1.3683e-01, 8.2247e-02, -9.3255e-01, 8.7522e-01, 7.1296e+00,
3.7168e-02, 9.7917e-01, -7.7272e-02, 6.0907e-01, -2.0614e-01,
6.3528e-01, -1.3433e+01, -8.9997e-03, 5.7927e-01, 5.7465e-01,
7.5690e-01, -4.6105e-02, 6.2752e-01, -9.0991e+00, 2.3277e-03,
5.1762e-01, 5.5586e-01, 2.6600e-01, 2.8972e-01, 6.5162e-01],
[ 1.8724e+01, 3.1786e-01, 5.7596e-01, 1.3195e+00, -1.0658e+00,
-2.0516e+00, 1.2803e+00, 1.8728e+01, -1.3836e-02, 9.7149e-01,
1.7835e-01, -1.4811e-03, -1.0227e+00, 9.0505e-01, 7.4657e+00,
3.5839e-02, 9.9364e-01, -8.7356e-02, 5.4268e-01, -2.3358e-01,
6.4212e-01, -1.4904e+01, -7.8026e-03, 6.0794e-01, 5.7523e-01,
6.7591e-01, -1.1579e-01, 6.5063e-01, -1.0407e+01, 1.2761e-03,
5.3218e-01, 6.0676e-01, 1.9660e-01, 1.8991e-01, 6.7666e-01],
[ 1.9729e+01, 3.8188e-01, 1.0091e+00, 2.2574e-01, -3.9591e-01,
-3.3652e-01, 6.8122e-01, 1.9735e+01, 1.4888e-02, 9.7211e-01,
-1.8558e-01, 3.4326e-01, 5.1835e-01, 3.1772e-01, 1.1779e+01,
1.6677e-02, 6.0190e-01, -8.7055e-01, 9.8156e-01, 1.5660e+00,
-9.3855e-02, -7.9949e+00, -1.7439e-02, 5.0544e-01, -5.9118e-01,
1.1279e+00, 1.4808e+00, -4.8462e-03, -1.1008e+01, 1.4023e-02,
4.9886e-01, 1.2144e-01, 9.3424e-01, 1.1006e+00, 1.8501e-01],
[ 2.0828e+01, 3.7389e-01, 9.5909e-01, 4.0849e-01, -6.5877e-01,
-5.4334e-01, 7.9542e-01, 2.0835e+01, 1.8135e-02, 1.0151e+00,
-9.3961e-02, 1.8864e-01, 3.3228e-01, 4.2197e-01, 1.4412e+01,
1.3447e-02, 5.9736e-01, -8.0056e-01, 8.6799e-01, 1.4857e+00,
-2.8108e-02, 2.2950e+00, -1.7361e-02, 4.9760e-01, -9.8458e-01,
1.5259e+00, 1.9876e+00, -2.2419e-01, -9.7860e-01, 1.4131e-02,
4.5145e-01, -2.6719e-01, 1.6454e+00, 1.9343e+00, -1.7758e-01],
[ 1.8541e+01, 3.6052e-01, 9.8302e-01, 1.2630e-01, -3.6989e-01,
-4.0694e-01, 6.9298e-01, 1.8542e+01, 1.2207e-02, 1.0018e+00,
-1.6262e-01, 3.3378e-01, 3.3623e-01, 4.0727e-01, 1.0124e+01,
1.9962e-02, 5.8260e-01, -8.2206e-01, 9.8333e-01, 1.3858e+00,
3.0251e-02, -6.4354e+00, -1.2738e-02, 4.9564e-01, -4.4351e-01,
1.0422e+00, 1.2766e+00, 1.4124e-01, -5.1403e+00, 1.1081e-02,
4.7627e-01, 6.2715e-02, 9.4187e-01, 1.1648e+00, 2.3920e-01],
[ 2.1700e+01, 3.6963e-01, 1.0161e+00, 3.5365e-01, -5.1549e-01,
-3.8895e-01, 7.3650e-01, 2.1710e+01, 1.2069e-02, 9.6929e-01,
-1.0482e-01, 2.9372e-01, 4.4866e-01, 3.7781e-01, 1.8020e+01,
1.0207e-02, 5.3495e-01, -8.2141e-01, 1.0396e+00, 1.6766e+00,
-1.0593e-01, 6.8139e+00, -2.0878e-02, 4.6487e-01, -1.3351e+00,
1.8629e+00, 2.4661e+00, -4.0683e-01, -3.9403e+00, 8.8815e-03,
4.3651e-01, -3.1911e-01, 2.0130e+00, 2.3060e+00, -3.2863e-01],
[ 1.9982e+01, 3.5527e-01, 9.9771e-01, 1.8310e-01, -3.2837e-01,
-2.6487e-01, 6.1877e-01, 1.9988e+01, 1.3263e-02, 9.2778e-01,
-2.2666e-01, 3.8196e-01, 6.1604e-01, 2.7222e-01, 1.1386e+01,
2.2135e-02, 5.5518e-01, -1.1819e+00, 1.1790e+00, 1.9540e+00,
-2.2913e-01, -8.7199e+00, -1.3203e-02, 5.0026e-01, -6.2586e-01,
1.1869e+00, 1.5888e+00, -8.7219e-03, -9.9509e+00, 1.3846e-02,
4.7484e-01, 1.0665e-01, 1.0456e+00, 1.2277e+00, 1.8260e-01],
[ 1.9321e+01, 3.5748e-01, 1.0011e+00, 8.5000e-02, -3.1382e-01,
-5.8064e-02, 5.0936e-01, 1.9330e+01, 1.5940e-02, 9.3347e-01,
-2.7937e-01, 3.4139e-01, 7.9552e-01, 1.9267e-01, 1.1115e+01,
1.9466e-02, 5.5904e-01, -1.1650e+00, 1.0820e+00, 2.0302e+00,
-2.5906e-01, -8.1233e+00, -1.4825e-02, 5.1039e-01, -6.9422e-01,
1.1615e+00, 1.6638e+00, -7.1307e-02, -1.1717e+01, 1.7474e-02,
4.8936e-01, 1.3536e-01, 1.0647e+00, 1.1938e+00, 1.5716e-01],
[ 1.9986e+01, 3.4426e-01, 1.0036e+00, 6.0444e-02, -3.1636e-01,
-5.3784e-02, 5.1158e-01, 1.9998e+01, 1.5422e-02, 9.3486e-01,
-2.3981e-01, 3.1167e-01, 7.6142e-01, 2.0912e-01, 1.2843e+01,
1.9240e-02, 5.4574e-01, -1.3053e+00, 1.1310e+00, 2.1718e+00,
-3.1241e-01, -9.2920e+00, -1.1794e-02, 5.1373e-01, -7.2432e-01,
1.2056e+00, 1.7307e+00, -8.6010e-02, -1.2046e+01, 1.8683e-02,
4.7902e-01, 9.8536e-02, 1.0589e+00, 1.2306e+00, 1.6226e-01],
[ 1.9049e+01, 3.6031e-01, 9.8259e-01, 1.7886e-01, -3.9499e-01,
-3.9496e-01, 6.7416e-01, 1.9055e+01, 1.4615e-02, 9.9518e-01,
-1.4237e-01, 3.1150e-01, 3.9308e-01, 3.6029e-01, 1.0483e+01,
2.2869e-02, 6.1555e-01, -8.7037e-01, 9.7719e-01, 1.4753e+00,
-4.7680e-02, -1.0114e+01, -1.3152e-02, 5.2934e-01, -3.8377e-01,
1.0160e+00, 1.2257e+00, 1.1184e-01, -9.7097e+00, 1.3663e-02,
4.9965e-01, 2.1355e-01, 8.3867e-01, 9.5725e-01, 2.7290e-01],
[ 2.0449e+01, 3.6672e-01, 9.9850e-01, 2.1467e-01, -3.9144e-01,
-3.1768e-01, 6.5340e-01, 2.0459e+01, 1.5070e-02, 9.6099e-01,
-1.6301e-01, 3.2221e-01, 5.1998e-01, 3.1643e-01, 1.2633e+01,
2.5129e-02, 5.4299e-01, -1.2838e+00, 1.2315e+00, 2.0434e+00,
-2.6675e-01, -9.3589e+00, -9.6726e-03, 5.0069e-01, -5.9143e-01,
1.2203e+00, 1.5694e+00, -2.2036e-02, -1.2009e+01, 1.6047e-02,
4.7439e-01, 1.4823e-01, 1.0121e+00, 1.1582e+00, 2.1145e-01]],
device='cuda:0'), None)
            t = time() - t0
            loss_i, loss_dict_i = criterion(outputs, labels, **criterion_parameters)  # outputs: {16, 35}; labels: {16, 5, 115}
            # **criterion_parameters: {'conf_weight': 1, 'lower_weight': 1, 'upper_weight': 1, 'cls_weight': 0, 'poly_weight': 300}
            # loss_i: tensor(3.9538, device='cuda:0')
            # loss_dict_i: {'conf': tensor(0.2224, device='cuda:0'), 'lower': tensor(0.0003, device='cuda:0'), 'upper': tensor(0.0023, device='cuda:0'), 'poly': tensor(3.7288, device='cuda:0'), 'cls_loss': 0}
            loss += loss_i.item()  # accumulate the loss over the batches
            total_iters += 1  # number of iterations
            for key in loss_dict_i:
                if key not in loss_dict:
                    loss_dict[key] = 0
                loss_dict[key] += loss_dict_i[key]  # accumulate each loss term

            outputs = model.decode(outputs, labels, **test_parameters)  # decode the raw outputs

            if evaluator is not None:
                lane_outputs, _ = outputs  # torch.Size([16, 5, 7])
                evaluator.add_prediction(img_idxs, lane_outputs.cpu().numpy(), t / images.shape[0])
                # img_idxs: image indices, tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
                # t: time the model needed to predict the outputs for this batch, 0.7983336448669434
                # images: torch.Size([16, 3, 360, 640])
            if view:
                outputs, extra_outputs = outputs
                preds = test_loader.dataset.draw_annotation(
                    cls_pred=extra_outputs[0].cpu().numpy() if extra_outputs is not None else None)
                cv2.imshow('pred', preds)

    if verbose:
        logging.info("Testing time: {:.4f}".format(time() - test_t0))
    out_line = []
    for key in loss_dict:
        loss_dict[key] /= total_iters
        out_line.append('{}: {:.4f}'.format(key, loss_dict[key]))
    if verbose:
        logging.info(', '.join(out_line))

    return evaluator, loss / total_iters
def parse_args():
    parser = argparse.ArgumentParser(description="Lane regression")  # 1. create a parser
    parser.add_argument("--exp_name", default="default", help="Experiment name")  # 2. this call and the following ones add arguments; argparse is covered in a separate post
    parser.add_argument("--cfg", default="tusimple.yaml", help="Config file")  # config file; the default was changed here
    parser.add_argument("--epoch", type=int, default=2695, help="Epoch to test the model on")  # epoch; the default was changed here as well
    parser.add_argument("--batch_size", type=int, help="Number of images per batch")
    parser.add_argument("--view", action="store_true", help="Show predictions")

    return parser.parse_args()
def get_code_state():
state = "Git hash: {}".format(
subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
state += '\n*************\nGit diff:\n*************\n'
state += subprocess.run(['git', 'diff'], stdout=subprocess.PIPE).stdout.decode('utf-8')
return state
if __name__ == "__main__":  # start debugging from the entry point
    args = parse_args()  # args receives the return value of parse_args
    cfg = Config(args.cfg)  # args.cfg = 'tusimple.yaml', passed to Config

    # Set up seeds
    torch.manual_seed(cfg['seed'])  # these calls set the random seeds
    np.random.seed(cfg['seed'])  # Config defines the __getitem__ magic method, so values can be read with []

    # Set up logging
    exp_root = os.path.join(cfg['exps_dir'], os.path.basename(os.path.normpath(args.exp_name)))  # exp_root = 'experiments/default'
    logging.basicConfig(  # the simplest way to use the logging module; covered in a separate post
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )

    sys.excepthook = log_on_exception  # install log_on_exception (the function itself, not its return value) as the hook for uncaught exceptions

    logging.info("Experiment name: {}".format(args.exp_name))  # write log messages
    logging.info("Config:\n" + str(cfg))
    logging.info("Args:\n" + str(args))

    # Device configuration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU when available. This still needs work on my machine: it used to run, but the software versions may no longer be compatible.

    # Hyper parameters
    num_epochs = cfg["epochs"]  # 2695
    batch_size = cfg["batch_size"] if args.batch_size is None else args.batch_size  # 16. Both hyperparameters are set in the YAML file.

    # Model
    model = cfg.get_model().to(device)  # calls the get_model() method of Config
    test_epoch = args.epoch  # 2695

    # Get data set
    test_dataset = cfg.get_dataset("test")  # {LaneDataset:358}

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size if args.view is False else 1,
                                              num_workers=8)  # {DataLoader:23}
    # Eval results
    evaluator = Evaluator(test_loader.dataset, exp_root)  # exp_root = 'experiments/default'; instantiates the class

    logging.basicConfig(  # logging again
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(os.path.join(exp_root, "test_log.txt")),
            logging.StreamHandler(),
        ],
    )
    logging.info('Code state:\n {}'.format(get_code_state()))  # also written to the log. Code state:
    # Git hash:
    # *************
    # Git diff:
    # *************
    _, mean_loss = test(model, test_loader, evaluator, exp_root, cfg, epoch=test_epoch, view=args.view)
    logging.info("Mean test loss: {:.4f}".format(mean_loss))  # 5.2436

    evaluator.exp_name = args.exp_name

    eval_str, _ = evaluator.eval(label='{}_{}'.format(os.path.basename(args.exp_name), test_epoch))  # https://www.cnblogs.com/baxianhua/p/10214263.html
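The numbers seen in the debugger (23 batches for 358 images, a per-image time of `t / images.shape[0]`, and a mean loss over iterations) can be checked with a few lines of plain Python. This is a sketch under the assumption that the last partial batch is kept, which is the `DataLoader` default; the helper names are hypothetical.

```python
import math


def num_batches(dataset_size, batch_size):
    """Number of batches a loader yields when the last partial batch is kept."""
    return math.ceil(dataset_size / batch_size)


def mean_loss(batch_losses):
    """Average of per-batch losses, as test() returns loss / total_iters."""
    return sum(batch_losses) / len(batch_losses)


batches = num_batches(358, 16)             # 22 full batches + 1 partial batch
last_batch = 358 - 16 * (batches - 1)      # size of the final batch
per_image_time = 0.7983336448669434 / 16   # t / images.shape[0] for one full batch
```

Because the mean is taken over batches, the final batch of 6 images weighs as much as a full batch of 16 in the reported loss.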
# Training settings
exps_dir: 'experiments'
iter_log_interval: 1
iter_time_window: 100
model_save_interval: 1
seed: 1
name: PolyRegression
num_outputs: 35 # (5 lanes) * (1 conf + 2 (upper & lower) + 4 poly coeffs)
pretrained: true
backbone: 'efficientnet-b0'
pred_category: false
curriculum_steps: [0, 0, 0, 0]
conf_weight: 1 # Wc
lower_weight: 1 # Ws
upper_weight: 1 # Wh
cls_weight: 0
poly_weight: 300 # Wp
batch_size: 16
epochs: 2695
name: Adam
lr: 3.0e-4
name: CosineAnnealingLR
T_max: 385
# Testing settings
conf_threshold: 0.5
# Dataset settings
type: LaneDataset
dataset: tusimple
split: train
img_size: [360, 640]
normalize: true
aug_chance: 0.9090909090909091 # 10/11
- name: Affine
rotate: !!python/tuple [-10, 10]
- name: HorizontalFlip
p: 0.5
- name: CropToFixedSize
width: 1152
height: 648
root: "/home/wqf/PolyLaneNet/PolyLaneNet-master"
test: &test
type: LaneDataset
dataset: tusimple
split: val
max_lanes: 5
img_size: [360, 640]
root: "/home/wqf/PolyLaneNet/PolyLaneNet-master"
normalize: true
augmentations: []
# val = test
<<: *test
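The `CosineAnnealingLR` entry with `T_max: 385` and `lr: 3.0e-4` can be explored with the closed-form cosine annealing expression from the PyTorch documentation. The sketch below is pure Python and assumes `eta_min = 0` (the PyTorch default, since the config sets only `T_max`); `cosine_lr` is a hypothetical helper, not the scheduler itself.

```python
import math


def cosine_lr(epoch, base_lr=3.0e-4, t_max=385, eta_min=0.0):
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / T_max)) / 2."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


lr_start = cosine_lr(0)            # full learning rate
lr_half_period = cosine_lr(385)    # minimum after T_max epochs
lr_full_period = cosine_lr(770)    # back to the full rate after 2 * T_max epochs
lr_end = cosine_lr(2695)           # 2695 = 7 * 385
```

Under this closed form the rate oscillates with a period of 2 × 385 epochs, and since 2695 is an odd multiple of 385, training would end at a minimum of the schedule.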
import yaml  # the package is installed as PyYAML but imported as yaml
import torch
import lib.models as models
import lib.datasets as datasets


class Config(object):
    def __init__(self, config_path):  # config_path = 'tusimple.yaml'
        self.config = {}
        self.load(config_path)  # load the file on construction

    def load(self, path):
        with open(path, 'r') as file:  # with ... as ...: context manager
            self.config_str = file.read()  # read the file
        # Parse the YAML text. On the Loader argument (a safer way to load): https://blog.csdn.net/ly021499/article/details/89026860
        self.config = yaml.load(self.config_str, Loader=yaml.FullLoader)

    def __repr__(self):
        return self.config_str

    def get_dataset(self, split):
        # Returns a lib.datasets.lane_dataset.LaneDataset object.
        # getattr(datasets, <type>) looks up the class named in the config, so this is equivalent to calling datasets.LaneDataset(...).
        # The keyword arguments are: {'dataset': 'tusimple', 'split': 'val', 'max_lanes': 5, 'img_size': [360, 640], 'root': '/home/wqf/PolyLaneNet/PolyLaneNet-master', 'normalize': True, 'augmentations': []}
        # split is 'val' here because the test entry of the YAML config sets split: val.
        return getattr(datasets,
                       self.config['datasets'][split]['type'])(**self.config['datasets'][split]['parameters'])

    def get_model(self):
        name = self.config['model']['name']  # 'PolyRegression'
        parameters = self.config['model']['parameters']  # {'num_outputs': 35, 'pretrained': True, 'backbone': 'efficientnet-b0', 'pred_category': False, 'curriculum_steps': [0, 0, 0, 0]}
        # getattr(object, name[, default]) -> value
        # Gets a named attribute from an object; getattr(x, 'y') is equivalent to x.y.
        # If a default is given, it is returned when the attribute does not exist; without it, an exception is raised in that case.
        # Here it returns models.PolyRegression, and **parameters unpacks the dict into keyword arguments.
        return getattr(models, name)(**parameters)

    def get_optimizer(self, model_parameters):
        return getattr(torch.optim, self.config['optimizer']['name'])(model_parameters,
                                                                      **self.config['optimizer']['parameters'])

    def get_lr_scheduler(self, optimizer):
        return getattr(torch.optim.lr_scheduler,
                       self.config['lr_scheduler']['name'])(optimizer, **self.config['lr_scheduler']['parameters'])

    def get_loss_parameters(self):
        return self.config['loss_parameters']

    def get_test_parameters(self):
        return self.config['test_parameters']

    def __getitem__(self, item):
        return self.config[item]
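The `getattr(namespace, name)(**parameters)` pattern used throughout `Config` can be tried out without PyTorch. In this sketch `Adam` and `Registry` are stand-ins invented for the example (they are not `torch.optim`), and `build` is a hypothetical helper.

```python
class Adam:
    """Stand-in for an optimizer class; not the real torch.optim.Adam."""

    def __init__(self, params, lr=1e-3):
        self.params, self.lr = params, lr


class Registry:
    """Hypothetical namespace playing the role of a module such as torch.optim."""
    Adam = Adam


def build(namespace, name, *args, **parameters):
    """Look the class up by name, then call it with the unpacked keyword arguments."""
    return getattr(namespace, name)(*args, **parameters)


config = {"optimizer": {"name": "Adam", "parameters": {"lr": 3.0e-4}}}
opt = build(Registry, config["optimizer"]["name"], ["w1", "w2"],
            **config["optimizer"]["parameters"])
```

The benefit of this design is that switching optimizers, schedulers, datasets or models only requires editing names and parameters in the YAML file, at the cost of errors surfacing at run time if a name is misspelled.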
import torch
import torch.nn as nn
from torchvision.models import resnet34, resnet50, resnet101
from efficientnet_pytorch import EfficientNet  # EfficientNet implementation for PyTorch


class OutputLayer(nn.Module):  # wraps the backbone's fc layer; if extra outputs are requested, a second linear layer is applied to the same input features
    def __init__(self, fc, num_extra):
        super(OutputLayer, self).__init__()  # initialize the parent class
        self.regular_outputs_layer = fc  # Linear(in_features=1280, out_features=35, bias=True)
        self.num_extra = num_extra  # 0
        if num_extra > 0:
            self.extra_outputs_layer = nn.Linear(fc.in_features, num_extra)

    def forward(self, x):
        regular_outputs = self.regular_outputs_layer(x)
        if self.num_extra > 0:
            extra_outputs = self.extra_outputs_layer(x)
        else:
            extra_outputs = None

        return regular_outputs, extra_outputs
class PolyRegression(nn.Module):
    def __init__(self,
                 num_outputs,  # 35
                 backbone,  # 'efficientnet-b0'
                 pretrained,  # True
                 curriculum_steps=None,  # [0, 0, 0, 0]
                 extra_outputs=0,  # 0
                 share_top_y=True,  # True
                 pred_category=False):  # False
        super(PolyRegression, self).__init__()  # initialize the parent class
        if 'efficientnet' in backbone:
            if pretrained:
                # Load EfficientNet with pretrained weights (downloaded on first use).
                self.model = EfficientNet.from_pretrained(backbone, num_classes=num_outputs)
            else:
                # Build EfficientNet without pretrained weights (architecture only).
                self.model = EfficientNet.from_name(backbone, override_params={'num_classes': num_outputs})
            # Wrap EfficientNet's fully connected layer in OutputLayer. (_fc): Linear(in_features=1280, out_features=35, bias=True)
            self.model._fc = OutputLayer(self.model._fc, extra_outputs)
        elif backbone == 'resnet34':
            self.model = resnet34(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_outputs)
            self.model.fc = OutputLayer(self.model.fc, extra_outputs)
        elif backbone == 'resnet50':
            self.model = resnet50(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_outputs)
            self.model.fc = OutputLayer(self.model.fc, extra_outputs)
        elif backbone == 'resnet101':
            self.model = resnet101(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_outputs)
            self.model.fc = OutputLayer(self.model.fc, extra_outputs)
        else:
            raise NotImplementedError()

        # [0, 0, 0, 0]. Used in forward() to zero some outputs during early epochs; with all zeros it has no effect.
        self.curriculum_steps = [0, 0, 0, 0] if curriculum_steps is None else curriculum_steps
        self.share_top_y = share_top_y  # True
        self.extra_outputs = extra_outputs  # 0
        self.pred_category = pred_category  # False
        self.sigmoid = nn.Sigmoid()  # the lines above simply store the constructor arguments
    def forward(self, x, epoch=None, **kwargs):
        output, extra_outputs = self.model(x, **kwargs)
        for i in range(len(self.curriculum_steps)):
            if epoch is not None and epoch < self.curriculum_steps[i]:
                output[-len(self.curriculum_steps) + i] = 0
        return output, extra_outputs

    def decode(self, all_outputs, labels, conf_threshold=0.5):
        outputs, extra_outputs = all_outputs  # outputs: torch.Size([16, 35])
        if extra_outputs is not None:  # extra_outputs is None here
            extra_outputs = extra_outputs.reshape(labels.shape[0], 5, -1)
            extra_outputs = extra_outputs.argmax(dim=2)
        outputs = outputs.reshape(len(outputs), -1, 7)  # score + upper + lower + 4 coeffs = 7; torch.Size([16, 5, 7])
        outputs[:, :, 0] = self.sigmoid(outputs[:, :, 0])  # the score goes through a sigmoid
        outputs[outputs[:, :, 0] < conf_threshold] = 0  # conf_threshold = 0.5

        if False and self.share_top_y:
            outputs[:, :, 0] = outputs[:, 0, 0].expand(outputs.shape[0], outputs.shape[1])

        return outputs, extra_outputs

    def loss(self,
             outputs,  # {16 35}
             target,  # {16 5 115}
             conf_weight=1,  # 1
             lower_weight=1,  # 1
             upper_weight=1,  # 1
             cls_weight=1,  # 0 in the config, presumably because pred_category is false and no category is predicted
             poly_weight=300,  # 300
             threshold=15 / 720.):  # 0.020833333333333332. According to the author's reply, this corresponds to tloss in Eq. (5).
        pred, extra_outputs = outputs  # extra_outputs = None; pred: {16 35}
        bce = nn.BCELoss()  # instantiate the loss functions
        mse = nn.MSELoss()
        s = nn.Sigmoid()  # activation function
        threshold = nn.Threshold(threshold**2, 0.)  # nn.Threshold: https://samuel92.blog.csdn.net/article/details/105888269
        # Threshold(threshold=0.00043402777777777775, value=0.0)
        # Values greater than threshold are kept; all others are replaced by value.
        pred = pred.reshape(-1, target.shape[1], 1 + 2 + 4)  # torch.Size([16, 5, 7])
        # 16 images, 5 lanes per image, 7 predicted values per lane
        # (I just noticed that the debugger's variable view also shows the shape attribute.)
        # Take everything along the first two dimensions and index 0 along the third:
        # the lane category (1 if the lane exists, 0 otherwise; set from the number of annotated lanes).
        target_categories, pred_confs = target[:, :, 0].reshape((-1, 1)), s(pred[:, :, 0]).reshape((-1, 1))
        # target_categories: torch.Size([80, 1]). Ground-truth category (whether a lane exists).
        # pred_confs: torch.Size([80, 1]). Predicted confidence. The network output goes through a sigmoid (the paper does not mention this, but it seems right).
        target_uppers, pred_uppers = target[:, :, 2].reshape((-1, 1)), pred[:, :, 2].reshape((-1, 1))  # upper limit
        # target_uppers: torch.Size([80, 1]); pred_uppers: torch.Size([80, 1])
        target_points, pred_polys = target[:, :, 3:].reshape((-1, target.shape[2] - 3)), pred[:, :, 3:].reshape(-1, 4)  # ground-truth lane point coordinates and predicted polynomial coefficients
        # target_points: torch.Size([80, 112]); pred_polys: torch.Size([80, 4])
        target_lowers, pred_lowers = target[:, :, 1], pred[:, :, 1]  # lower limit
        # target_lowers: torch.Size([16, 5]); pred_lowers: torch.Size([16, 5])
if self.share_top_y:
# inexistent lanes have -1e-5 as lower
# i'm just setting it to a high value here so that the .min below works fine
target_lowers[target_lowers < 0] = 1  # explained by the English comments above
target_lowers[...] = target_lowers.min(dim=1, keepdim=True)[0]  # target_lowers before this step: tensor([[0.3333, 0.3472, 0.3333, 0.3333, 1.0000],
[0.3889, 0.3750, 0.3889, 1.0000, 1.0000],
[0.3194, 0.3333, 0.3472, 1.0000, 1.0000],
[0.3750, 0.3750, 0.3611, 1.0000, 1.0000],
[0.3889, 0.3889, 0.3750, 1.0000, 1.0000],
[0.3194, 0.3194, 0.3472, 1.0000, 1.0000],
[0.3750, 0.3750, 0.3750, 1.0000, 1.0000],
[0.3333, 0.3333, 0.3472, 1.0000, 1.0000],
[0.4028, 0.4028, 0.4028, 0.4028, 1.0000],
[0.3889, 0.3889, 0.4028, 0.3750, 1.0000],
[0.3611, 0.3611, 0.3472, 0.3611, 1.0000],
[0.3611, 0.3472, 0.3611, 0.3611, 1.0000],
[0.3750, 0.3750, 0.3750, 0.3750, 1.0000],
[0.3611, 0.3472, 0.3611, 0.3611, 1.0000],
[0.3472, 0.3611, 0.3472, 0.3472, 1.0000],
[0.3472, 0.3611, 0.3611, 1.0000, 1.0000]], device='cuda:0')
# After this step:
tensor([[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],
[0.3750, 0.3750, 0.3750, 0.3750, 0.3750],
[0.3194, 0.3194, 0.3194, 0.3194, 0.3194],
[0.3611, 0.3611, 0.3611, 0.3611, 0.3611],
[0.3750, 0.3750, 0.3750, 0.3750, 0.3750],
[0.3194, 0.3194, 0.3194, 0.3194, 0.3194],
[0.3750, 0.3750, 0.3750, 0.3750, 0.3750],
[0.3333, 0.3333, 0.3333, 0.3333, 0.3333],
[0.4028, 0.4028, 0.4028, 0.4028, 0.4028],
[0.3750, 0.3750, 0.3750, 0.3750, 0.3750],
[0.3472, 0.3472, 0.3472, 0.3472, 0.3472],
[0.3472, 0.3472, 0.3472, 0.3472, 0.3472],
[0.3750, 0.3750, 0.3750, 0.3750, 0.3750],
[0.3472, 0.3472, 0.3472, 0.3472, 0.3472],
[0.3472, 0.3472, 0.3472, 0.3472, 0.3472],
[0.3472, 0.3472, 0.3472, 0.3472, 0.3472]], device='cuda:0')
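# In effect, the lowest (ground-truth) lower limit among the lanes of each image is found and assigned to all lanes of that image.
pred_lowers[...] = pred_lowers[:, 0].reshape(-1, 1).expand(pred.shape[0], pred.shape[1])
# In effect, the predicted lower limit of the first lane is assigned to all other lanes.
# Before this step: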
# tensor([[ 0.3535, 0.0249, 0.0093, -0.0395, 0.0213],
[ 0.3805, 0.0134, 0.0160, -0.0178, 0.0140],
[ 0.3622, 0.0081, 0.0220, -0.0137, 0.0179],
[ 0.3494, 0.0137, 0.0083, -0.0081, -0.0029],
[ 0.3700, 0.0095, 0.0172, -0.0104, 0.0139],
[ 0.3412, 0.0227, 0.0132, -0.0182, 0.0284],
[ 0.3597, 0.0178, 0.0105, -0.0178, 0.0143],
[ 0.3421, 0.0309, -0.0078, -0.0274, 0.0260],
[ 0.3914, 0.0276, 0.0022, -0.0368, 0.0225],
[ 0.3634, 0.0131, 0.0176, -0.0090, 0.0079],
[ 0.3482, 0.0251, 0.0068, -0.0437, 0.0261],
[ 0.3638, 0.0144, 0.0193, -0.0222, 0.0148],
[ 0.3742, 0.0059, 0.0230, -0.0074, 0.0050],
[ 0.3800, 0.0265, 0.0094, -0.0416, 0.0269],
[ 0.3521, 0.0156, 0.0247, -0.0218, 0.0224],
[ 0.3536, 0.0077, 0.0365, -0.0382, 0.0281]], device='cuda:0')
# After:
# tensor([[0.3535, 0.3535, 0.3535, 0.3535, 0.3535],
[0.3805, 0.3805, 0.3805, 0.3805, 0.3805],
[0.3622, 0.3622, 0.3622, 0.3622, 0.3622],
[0.3494, 0.3494, 0.3494, 0.3494, 0.3494],
[0.3700, 0.3700, 0.3700, 0.3700, 0.3700],
[0.3412, 0.3412, 0.3412, 0.3412, 0.3412],
[0.3597, 0.3597, 0.3597, 0.3597, 0.3597],
[0.3421, 0.3421, 0.3421, 0.3421, 0.3421],
[0.3914, 0.3914, 0.3914, 0.3914, 0.3914],
[0.3634, 0.3634, 0.3634, 0.3634, 0.3634],
[0.3482, 0.3482, 0.3482, 0.3482, 0.3482],
[0.3638, 0.3638, 0.3638, 0.3638, 0.3638],
[0.3742, 0.3742, 0.3742, 0.3742, 0.3742],
[0.3800, 0.3800, 0.3800, 0.3800, 0.3800],
[0.3521, 0.3521, 0.3521, 0.3521, 0.3521],
[0.3536, 0.3536, 0.3536, 0.3536, 0.3536]], device='cuda:0')
target_lowers = target_lowers.reshape((-1, 1)) # torch.Size([80, 1])
pred_lowers = pred_lowers.reshape((-1, 1)) # torch.Size([80, 1])
target_confs = (target_categories > 0).float()
valid_lanes_idx = target_confs == 1 # mask marking the positions of the valid lanes
valid_lanes_idx_flat = valid_lanes_idx.reshape(-1)
lower_loss = mse(target_lowers[valid_lanes_idx], pred_lowers[valid_lanes_idx]) # MSE on the lower bounds
upper_loss = mse(target_uppers[valid_lanes_idx], pred_uppers[valid_lanes_idx]) # MSE on the upper bounds
# classification loss
if self.pred_category and self.extra_outputs > 0: # entered only when categories are predicted and extra outputs exist (not the case here, since pred_category is false)
    ce = nn.CrossEntropyLoss()
    pred_categories = extra_outputs.reshape(target.shape[0] * target.shape[1], -1)
    target_categories = target_categories.reshape(pred_categories.shape[:-1]).long()
    pred_categories = pred_categories[target_categories > 0]
    target_categories = target_categories[target_categories > 0]
    cls_loss = ce(pred_categories, target_categories - 1)
else:
    cls_loss = 0 # no classification loss in this configuration
# poly loss calc
target_xs = target_points[valid_lanes_idx_flat, :target_points.shape[1] // 2] # see the example below
# target_points:torch.Size([80, 112])
tensor([[ 3.8984e-01, 3.7813e-01, 3.6719e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 4.1328e-01, 4.1484e-01, 4.1641e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 4.3203e-01, 4.4375e-01, 4.5547e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 4.0859e-01, 4.3906e-01, 4.7813e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[-1.0000e+05, -1.0000e+05, -1.0000e+05, ..., -1.0000e+05, 1.0000e+05, -1.0000e+05],
[-1.0000e+05, -1.0000e+05, -1.0000e+05, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05]], device='cuda:0')
# valid_lanes_idx_flat: True, True, True, ..., True, False, False
# target_xs: torch.Size([56, 56])
# The first 56 is the number of rows where valid_lanes_idx_flat is True; the second 56 is the first half of each target_points row (the x coordinates).
tensor([[ 3.8984e-01, 3.7813e-01, 3.6719e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 4.1328e-01, 4.1484e-01, 4.1641e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 4.3203e-01, 4.4375e-01, 4.5547e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 3.5156e-01, 3.5547e-01, 3.5547e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 3.8438e-01, 4.0078e-01, 4.1563e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05],
[ 4.0859e-01, 4.3906e-01, 4.7813e-01, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05]], device='cuda:0')
# target_points[0, :target_points.shape[1] // 2]:
tensor([ 3.8984e-01, 3.7813e-01, 3.6719e-01, 3.5391e-01, 3.3984e-01, 3.2656e-01, 3.1250e-01, 2.9219e-01, 2.7031e-01, 2.4844e-01, 2.2656e-01, 2.0469e-01, 1.8359e-01, 1.6172e-01, 1.3984e-01, 1.1797e-01, 9.6094e-02, 7.5000e-02, 5.3125e-02, 3.1250e-02, 9.3750e-03, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05], device='cuda:0')
# target_points[0, :]
tensor([ 3.8984e-01, 3.7813e-01, 3.6719e-01, 3.5391e-01, 3.3984e-01, 3.2656e-01, 3.1250e-01, 2.9219e-01, 2.7031e-01, 2.4844e-01, 2.2656e-01, 2.0469e-01, 1.8359e-01, 1.6172e-01, 1.3984e-01, 1.1797e-01, 9.6094e-02, 7.5000e-02, 5.3125e-02, 3.1250e-02, 9.3750e-03, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, 3.3333e-01, 3.4722e-01, 3.6111e-01, 3.7500e-01, 3.8889e-01, 4.0278e-01, 4.1667e-01, 4.3056e-01, 4.4444e-01, 4.5833e-01, 4.7222e-01, 4.8611e-01, 5.0000e-01, 5.1389e-01, 5.2778e-01, 5.4167e-01, 5.5556e-01, 5.6944e-01, 5.8333e-01, 5.9722e-01, 6.1111e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05], device='cuda:0')
ys = target_points[valid_lanes_idx_flat, target_points.shape[1] // 2:].t()
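# ys is transposed here (points x lanes) so that it broadcasts against the per-lane coefficients in the computation below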
# tensor([[ 3.3333e-01, 3.4722e-01, 3.3333e-01, ..., 3.4722e-01, 3.6111e-01, 3.6111e-01], [ 3.4722e-01, 3.6111e-01, 3.4722e-01, ..., 3.6111e-01, 3.7500e-01, 3.7500e-01], [ 3.6111e-01, 3.7500e-01, 3.6111e-01, ..., 3.7500e-01, 3.8889e-01, 3.8889e-01], ..., [-1.0000e+05, -1.0000e+05, -1.0000e+05, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05], [-1.0000e+05, -1.0000e+05, -1.0000e+05, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05], [-1.0000e+05, -1.0000e+05, -1.0000e+05, ..., -1.0000e+05, -1.0000e+05, -1.0000e+05]], device='cuda:0')
valid_xs = target_xs >= 0 # mask of the valid x points (padding is -1e5)
pred_polys = pred_polys[valid_lanes_idx_flat] # polynomials predicted for the valid lanes
pred_xs = pred_polys[:, 0] * ys**3 + pred_polys[:, 1] * ys**2 + pred_polys[:, 2] * ys + pred_polys[:, 3] # predicted x coordinate at each y
# Example: tensor([[ 3.6480e-01, 4.1295e-01, 4.3892e-01, ..., 3.8540e-01,
4.1487e-01, 4.9135e-01],
[ 3.5220e-01, 4.1199e-01, 4.5362e-01, ..., 3.6899e-01,
4.1176e-01, 5.0227e-01],
[ 3.3908e-01, 4.1083e-01, 4.6825e-01, ..., 3.5225e-01,
4.0856e-01, 5.1314e-01],
[ 2.1743e+14, -7.2232e+13, 1.3829e+14, ..., -1.0635e+14,
-1.2338e+14, 2.1186e+14],
[ 2.1743e+14, -7.2232e+13, 1.3829e+14, ..., -1.0635e+14,
-1.2338e+14, 2.1186e+14],
[ 2.1743e+14, -7.2232e+13, 1.3829e+14, ..., -1.0635e+14,
-1.2338e+14, 2.1186e+14]], device='cuda:0')
pred_xs.t_() # in-place transpose
# tensor([[ 3.6480e-01, 3.5220e-01, 3.3908e-01, ..., 2.1743e+14,
2.1743e+14, 2.1743e+14],
[ 4.1295e-01, 4.1199e-01, 4.1083e-01, ..., -7.2232e+13,
-7.2232e+13, -7.2232e+13],
[ 4.3892e-01, 4.5362e-01, 4.6825e-01, ..., 1.3829e+14,
1.3829e+14, 1.3829e+14],
[ 3.8540e-01, 3.6899e-01, 3.5225e-01, ..., -1.0635e+14,
-1.0635e+14, -1.0635e+14],
[ 4.1487e-01, 4.1176e-01, 4.0856e-01, ..., -1.2338e+14,
-1.2338e+14, -1.2338e+14],
[ 4.9135e-01, 5.0227e-01, 5.1314e-01, ..., 2.1186e+14,
2.1186e+14, 2.1186e+14]], device='cuda:0')
weights = (torch.sum(valid_xs, dtype=torch.float32) / torch.sum(valid_xs, dim=1, dtype=torch.float32))**0.5 # see https://blog.csdn.net/qq_39463274/article/details/105145029
# dim=1 sums along each row, i.e. it counts the valid points of each lane
# the English comment below explains why the weights are needed
pred_xs = (pred_xs.t_() *
           weights).t() # without this, lanes with more points would have more weight on the cost function
# tensor([[ 3.3652e+00, 3.2490e+00, 3.1279e+00, ..., 2.0057e+15,
2.0057e+15, 2.0057e+15],
[ 2.5463e+00, 2.5404e+00, 2.5332e+00, ..., -4.4539e+14,
-4.4539e+14, -4.4539e+14],
[ 2.7972e+00, 2.8908e+00, 2.9841e+00, ..., 8.8132e+14,
8.8132e+14, 8.8132e+14],
[ 3.5552e+00, 3.4039e+00, 3.2494e+00, ..., -9.8106e+14,
-9.8106e+14, -9.8106e+14],
[ 2.5858e+00, 2.5664e+00, 2.5465e+00, ..., -7.6901e+14,
-7.6901e+14, -7.6901e+14],
[ 3.0963e+00, 3.1651e+00, 3.2336e+00, ..., 1.3351e+15,
1.3351e+15, 1.3351e+15]], device='cuda:0')
target_xs = (target_xs.t_() * weights).t()
poly_loss = mse(pred_xs[valid_xs], target_xs[valid_xs]) / valid_lanes_idx.sum() # e.g. tensor(0.0124, device='cuda:0'); this value is overwritten by the next statement
poly_loss = threshold(
    (pred_xs[valid_xs] - target_xs[valid_xs])**2).sum() / (valid_lanes_idx.sum() * valid_xs.sum())
# applying weights to partial losses
poly_loss = poly_loss * poly_weight
lower_loss = lower_loss * lower_weight
upper_loss = upper_loss * upper_weight
cls_loss = cls_loss * cls_weight
conf_loss = bce(pred_confs, target_confs) * conf_weight
loss = conf_loss + lower_loss + upper_loss + poly_loss + cls_loss
return loss, {
    'conf': conf_loss,
    'lower': lower_loss,
    'upper': upper_loss,
    'poly': poly_loss,
    'cls_loss': cls_loss
} # the combined loss, as described in the paper
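To make the effect of the `weights` factor in the polynomial loss above concrete, here is a minimal NumPy sketch (not the PyTorch code itself). The helper name `lane_balanced_mse` and the toy numbers are made up for illustration, and the sketch assumes every lane has at least one valid point.

```python
import numpy as np

def lane_balanced_mse(pred_xs, target_xs):
    """Hypothetical helper: MSE over the valid points with per-lane weights.

    pred_xs, target_xs: arrays of shape (num_lanes, num_points).
    Padded target entries are negative (e.g. -1e5) and are ignored.
    Assumes each lane has at least one valid point.
    """
    valid = target_xs >= 0
    # one weight per lane: sqrt(total valid points / valid points of this lane)
    weights = np.sqrt(valid.sum() / valid.sum(axis=1))
    diff = (pred_xs - target_xs) * weights[:, None]
    return np.mean(diff[valid] ** 2)

pad = -1e5
target = np.array([[0.5, 0.5, 0.5, 0.5],
                   [0.2, 0.2, pad, pad]])
pred = np.array([[0.6, 0.6, 0.6, 0.6],   # error 0.1 on 4 points
                 [0.4, 0.4, 0.0, 0.0]])  # error 0.2 on 2 points

valid = target >= 0
plain = np.mean((pred[valid] - target[valid]) ** 2)
balanced = lane_balanced_mse(pred, target)
print(plain, balanced)
```

With a plain MSE the 4-point lane dominates the average. After weighting, each lane contributes its own mean squared error (0.01 + 0.04 in this toy case), regardless of how many points it has.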
import cv2
import numpy as np
import imgaug.augmenters as iaa # data augmentation library
from imgaug.augmenters import Resize
from torchvision.transforms import ToTensor
from torch.utils.data.dataset import Dataset
from imgaug.augmentables.lines import LineString, LineStringsOnImage
from .elas import ELAS
from .llamas import LLAMAS
from .tusimple import TuSimple
from .nolabel_dataset import NoLabelDataset
GT_COLOR = (255, 0, 0)
PRED_HIT_COLOR = (0, 255, 0)
PRED_MISS_COLOR = (0, 0, 255)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
class LaneDataset(Dataset):
    def __init__(self,
                 dataset='tusimple', # 'tusimple'
                 augmentations=None, # []
                 normalize=False, # True
                 split='train', # 'val'
                 img_size=(360, 640), # [360, 640]
                 aug_chance=1., # 1.0
                 **kwargs): # {'max_lanes': 5, 'root': '/home/wqf/PolyLaneNet/PolyLaneNet-master'}
        super(LaneDataset, self).__init__() # initialise the parent class
        if dataset == 'tusimple':
            self.dataset = TuSimple(split=split, **kwargs)
        elif dataset == 'llamas':
            self.dataset = LLAMAS(split=split, **kwargs)
        elif dataset == 'elas':
            self.dataset = ELAS(split=split, **kwargs)
        elif dataset == 'nolabel_dataset':
            self.dataset = NoLabelDataset(**kwargs)
        else:
            raise NotImplementedError()
        self.transform_annotations() # call this class's own method
        self.img_h, self.img_w = img_size # [360, 640]
        if augmentations is not None: # augmentations=[] here, so the branch is taken but the comprehension yields an empty list
            # add augmentations
            augmentations = [getattr(iaa, aug['name'])(**aug['parameters'])
                             for aug in augmentations] # add augmentation
        self.normalize = normalize # True
        transformations = iaa.Sequential([Resize({'height': self.img_h, 'width': self.img_w})]) # build an augmenter that resizes to a fixed size
        self.to_tensor = ToTensor() # instantiate ToTensor
        self.transform = iaa.Sequential([iaa.Sometimes(then_list=augmentations, p=aug_chance), transformations]) # see https://blog.csdn.net/u012897374/article/details/80142744. Here augmentations=[] and aug_chance=1, so effectively no data augmentation is applied
        self.max_lanes = self.dataset.max_lanes
    def transform_annotation(self, anno, img_wh=None): # converts one annotation: because the image size changes, the annotated point coordinates have to be converted as well
        # An example anno:
# {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492626287507231547/20.jpg', 'org_path': 'clips/0531/1492626287507231547/20.jpg', 'org_lanes': [[-2, -2, -2, -2, -2, -2, -2, -2, 499, 484, 470, 453, 435, 418, 400, 374, 346, 318, 290, 262, 235, 207, 179, 151, 123, 96, 68, 40, 12, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, -2, -2, 529, 531, 533, 536, 538, 540, 540, 538, 536, 531, 525, 519, 513, 507, 499, 491, 483, 475, 467, 459, 451, 443, 435, 426, 418, 410, 402, 394, 386, 378, 370, 362, 354, 346, 338, 330, 322, 314, 306, 297, 289, 281, 273, 265, 257, 249, 241], [-2, -2, -2, -2, -2, -2, -2, -2, 553, 568, 583, 598, 613, 640, 667, 693, 719, 740, 761, 783, 804, 825, 846, 868, 883, 897, 912, 926, 941, 955, 969, 984, 998, 1013, 1027, 1042, 1056, 1070, 1085, 1099, 1114, 1128, 1143, 1157, 1171, 1186, 1200, 1215, 1229, 1244, 1258, 1272, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, -2, 558, 585, 613, 646, 679, 714, 770, 817, 865, 912, 954, 994, 1033, 1073, 1113, 1153, 1193, 1232, 1272, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2]], 'lanes': [[(499, 240), (484, 250), (470, 260), (453, 270), (435, 280), (418, 290), (400, 300), (374, 310), (346, 320), (318, 330), (290, 340), (262, 350), (235, 360), (207, 370), (179, 380), (151, 390), (123, 400), (96, 410), (68, 420), (40, 430), (12, 440)], [(529, 250), (531, 260), (533, 270), (536, 280), (538, 290), (540, 300), (540, 310), (538, 320), (536, 330), (531, 340), (525, 350), (519, 360), (513, 370), (507, 380), (499, 390), (491, 400), (483, 410), (475, 420), (467, 430), (459, 440), (451, 450), (443, 460), (435, 470), (426, 480), (418, 490), (410, 500), (402, 510), (394, 520), (386, 530), (378, 540), (370, 550), (362, 560), (354, 570), (346, 580), (338, 590), (330, 600), (322, 610), (314, 620), (306, 630), (297, 640), (289, 650), (281, 660), (273, 670), 
(265, 680), (257, 690), (249, 700), (241, 710)], [(553, 240), (568, 250), (583, 260), (598, 270), (613, 280), (640, 290), (667, 300), (693, 310), (719, 320), (740, 330), (761, 340), (783, 350), (804, 360), (825, 370), (846, 380), (868, 390), (883, 400), (897, 410), (912, 420), (926, 430), (941, 440), (955, 450), (969, 460), (984, 470), (998, 480), (1013, 490), (1027, 500), (1042, 510), (1056, 520), (1070, 530), (1085, 540), (1099, 550), (1114, 560), (1128, 570), (1143, 580), (1157, 590), (1171, 600), (1186, 610), (1200, 620), (1215, 630), (1229, 640), (1244, 650), (1258, 660), (1272, 670)], [(558, 240), (585, 250), (613, 260), (646, 270), (679, 280), (714, 290), (770, 300), (817, 310), (865, 320), (912, 330), (954, 340), (994, 350), (1033, 360), (1073, 370), (1113, 380), (1153, 390), (1193, 400), (1232, 410), (1272, 420)]], 'aug': False, 'y_samples': [160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710]}
        if img_wh is None:
            img_h = self.dataset.get_img_heigth(anno['path']) # 720
            img_w = self.dataset.get_img_width(anno['path']) # 1280
        else:
            img_w, img_h = img_wh
        old_lanes = anno['lanes'] # coordinates of the valid lane points
        categories = anno['categories'] if 'categories' in anno else [1] * len(old_lanes) # lane categories; defaults to 1 for every lane
        old_lanes = zip(old_lanes, categories) # a zip object
        old_lanes = filter(lambda x: len(x[0]) > 0, old_lanes) # drops empty lanes; filter returns an iterator (a filter object), see https://www.runoob.com/python3/python3-func-filter.html
        lanes = np.ones((self.dataset.max_lanes, 1 + 2 + 2 * self.dataset.max_points), dtype=np.float32) * -1e5 # for the first anno, self.dataset.max_lanes=5 and self.dataset.max_points=56, so lanes is an ndarray of shape (5, 115) filled with -1e5; it holds the converted lane point coordinates
        # 2 * self.dataset.max_points: the x and y coordinates
        # 2: the lane's lower and upper bounds
        # 1: whether this slot holds a lane (1 if it does)
        lanes[:, 0] = 0 # set column 0 of every row to 0
        old_lanes = sorted(old_lanes, key=lambda x: x[0][0][0]) # sort by the x of each lane's first point; the built-in sorted returns a new list instead of sorting in place
        for lane_pos, (lane, category) in enumerate(old_lanes):
            lower, upper = lane[0][1], lane[-1][1] # y of the lane's first and last points
            xs = np.array([p[0] for p in lane]) / img_w # normalised x coordinates. [0.38984375 0.378125 0.3671875 0.35390625 0.33984375 0.3265625, 0.3125 0.2921875 0.2703125 0.2484375 0.2265625 0.2046875, 0.18359375 0.16171875 0.13984375 0.11796875 0.09609375 0.075, 0.053125 0.03125 0.009375 ]
            ys = np.array([p[1] for p in lane]) / img_h # normalised y coordinates. [0.33333333 0.34722222 0.36111111 0.375 0.38888889 0.40277778, 0.41666667 0.43055556 0.44444444 0.45833333 0.47222222 0.48611111, 0.5 0.51388889 0.52777778 0.54166667 0.55555556 0.56944444, 0.58333333 0.59722222 0.61111111]
            lanes[lane_pos, 0] = category # column 0 stores the lane category (1 means a lane is present)
            lanes[lane_pos, 1] = lower / img_h # column 1 stores the lane's lower bound
            lanes[lane_pos, 2] = upper / img_h # column 2 stores the lane's upper bound
            lanes[lane_pos, 3:3 + len(xs)] = xs # columns 3 to 3 + self.dataset.max_points store the normalised x coordinates (unused entries stay at -1e5)
            lanes[lane_pos, (3 + self.dataset.max_points):(3 + self.dataset.max_points + len(ys))] = ys # columns from 3 + self.dataset.max_points onward store the normalised y coordinates
        new_anno = { # the new annotation
            'path': anno['path'],
            'label': lanes,
            'old_anno': anno,
            'categories': [cat for _, cat in old_lanes] # e.g. 'categories': [1, 1, 1, 1]
        }
        return new_anno
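The row layout built by `transform_annotation` (category, lower bound, upper bound, then `max_points` x values followed by `max_points` y values) can be sketched in isolation as follows. This is a simplified, standalone illustration: `encode_lane` and `decode_lane` are hypothetical helper names that do not exist in the repository, and only a single lane is handled.

```python
import numpy as np

def encode_lane(points, img_w, img_h, max_points, category=1):
    """Hypothetical helper: pack one lane into a row laid out as
    [category, lower, upper, xs (max_points), ys (max_points)]."""
    row = np.full(3 + 2 * max_points, -1e5, dtype=np.float32)
    xs = np.array([p[0] for p in points], dtype=np.float32) / img_w
    ys = np.array([p[1] for p in points], dtype=np.float32) / img_h
    row[0] = category
    row[1] = points[0][1] / img_h   # y of the first point
    row[2] = points[-1][1] / img_h  # y of the last point
    row[3:3 + len(xs)] = xs
    row[3 + max_points:3 + max_points + len(ys)] = ys
    return row

def decode_lane(row, img_w, img_h):
    """Inverse of encode_lane: pixel coordinates of the valid points."""
    coords = row[3:]
    xs, ys = coords[:len(coords) // 2], coords[len(coords) // 2:]
    keep = xs >= 0  # padding entries are -1e5
    return [(int(round(float(x) * img_w)), int(round(float(y) * img_h)))
            for x, y in zip(xs[keep], ys[keep])]

lane = [(499, 240), (484, 250), (470, 260)]
row = encode_lane(lane, img_w=1280, img_h=720, max_points=56)
print(row.shape, row[:6])
print(decode_lane(row, 1280, 720))
```

The decode step mirrors what `draw_annotation` does when it splits a label row into `xs` and `ys` and keeps only the entries with `xs >= 0`.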
    def annotations(self):
        return self.dataset.annotations
    def transform_annotations(self):
        print('Transforming annotations...')
        self.dataset.annotations = np.array(list(map(self.transform_annotation, self.dataset.annotations))) # map passes each element of self.dataset.annotations to self.transform_annotation; see the examples at https://www.runoob.com/python/python-func-map.html. In Python 3, map alone returns an iterator, while list(map(...)) produces a list.
    def draw_annotation(self, idx, pred=None, img=None, cls_pred=None):
        if img is None:
            img, label, _ = self.__getitem__(idx, transform=True)
            # Tensor to opencv image
            img = img.permute(1, 2, 0).numpy()
            # Unnormalize
            if self.normalize:
                img = img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
            img = (img * 255).astype(np.uint8)
        else:
            _, label, _ = self.__getitem__(idx)
        img_h, img_w, _ = img.shape
        # Draw label
        for i, lane in enumerate(label):
            if lane[0] == 0: # Skip invalid lanes
                continue
            lane = lane[3:] # remove conf, upper and lower positions
            xs = lane[:len(lane) // 2]
            ys = lane[len(lane) // 2:]
            ys = ys[xs >= 0]
            xs = xs[xs >= 0]
            # draw GT points
            for p in zip(xs, ys):
                p = (int(p[0] * img_w), int(p[1] * img_h))
                img = cv2.circle(img, p, 5, color=GT_COLOR, thickness=-1)
            # draw GT lane ID
            cv2.putText(img,
                        str(i), (int(xs[0] * img_w), int(ys[0] * img_h)),
                        fontFace=cv2.FONT_HERSHEY_COMPLEX,
                        fontScale=1,
                        color=(0, 255, 0))
        if pred is None:
            return img
        # Draw predictions
        pred = pred[pred[:, 0] != 0] # filter invalid lanes
        matches, accs, _ = self.dataset.get_metrics(pred, idx)
        overlay = img.copy()
        for i, lane in enumerate(pred):
            if matches[i]:
                color = PRED_HIT_COLOR
            else:
                color = PRED_MISS_COLOR
            lane = lane[1:] # remove conf
            lower, upper = lane[0], lane[1]
            lane = lane[2:] # remove upper, lower positions
            # generate points from the polynomial
            ys = np.linspace(lower, upper, num=100)
            points = np.zeros((len(ys), 2), dtype=np.int32)
            points[:, 1] = (ys * img_h).astype(int)
            points[:, 0] = (np.polyval(lane, ys) * img_w).astype(int)
            points = points[(points[:, 0] > 0) & (points[:, 0] < img_w)]
            # draw lane with a polyline on the overlay
            for current_point, next_point in zip(points[:-1], points[1:]):
                overlay = cv2.line(overlay, tuple(current_point), tuple(next_point), color=color, thickness=2)
            # draw class icon
            if cls_pred is not None and len(points) > 0:
                class_icon = self.dataset.get_class_icon(cls_pred[i])
                class_icon = cv2.resize(class_icon, (32, 32))
                mid = tuple(points[len(points) // 2] - 60)
                x, y = mid
                img[y:y + class_icon.shape[0], x:x + class_icon.shape[1]] = class_icon
            # draw lane ID
            if len(points) > 0:
                cv2.putText(img, str(i), tuple(points[0]), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=color)
            # draw lane accuracy
            if len(points) > 0:
                cv2.putText(img,
                            '{:.2f}'.format(accs[i] * 100),
                            tuple(points[len(points) // 2] - 30),
                            fontFace=cv2.FONT_HERSHEY_COMPLEX,
                            fontScale=.75,
                            color=color)
        # Add lanes overlay
        w = 0.6
        img = ((1. - w) * img + w * overlay).astype(np.uint8)
        return img
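The "generate points from the polynomial" step in `draw_annotation` can be tried on its own. Below is a minimal sketch under the same conventions (coefficients ordered highest degree first, all quantities normalised to [0, 1]); the function name `sample_lane_points` and the example coefficients are assumptions made for illustration.

```python
import numpy as np

def sample_lane_points(coeffs, lower, upper, img_w, img_h, num=100):
    """Hypothetical helper: pixel points of x = polyval(coeffs, y)
    for y sampled uniformly in [lower, upper]."""
    ys = np.linspace(lower, upper, num=num)
    points = np.zeros((len(ys), 2), dtype=np.int32)
    points[:, 1] = (ys * img_h).astype(int)
    points[:, 0] = (np.polyval(coeffs, ys) * img_w).astype(int)
    # keep only the points whose x falls strictly inside the image
    return points[(points[:, 0] > 0) & (points[:, 0] < img_w)]

# x = 1.0 * y + 0.25: the lane leaves the image on the right once y >= 0.75
pts = sample_lane_points([0.0, 0.0, 1.0, 0.25], lower=0.0, upper=1.0,
                         img_w=640, img_h=360, num=5)
print(pts)
```

Since `np.polyval` expects the highest-degree coefficient first, the order matches `pred_polys[:, 0] * ys**3 + ... + pred_polys[:, 3]` in the loss. Points that fall outside the image width are dropped before drawing.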
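    def lane_to_linestrings(self, lanes):
        lines = []
        for lane in lanes:
            lines.append(LineString(lane)) # wrap each lane's points in an imgaug LineString
        return lines
    def linestrings_to_lanes(self, lines):
        lanes = []
        for line in lines:
            lanes.append(line.coords) # back to plain point coordinates
        return lanes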
    def __getitem__(self, idx, transform=True):
        # idx: 112
        # self.dataset: {TuSimple: 358}
        # This method fetches one sample; idx is the sample index, as is usual for a Dataset's __getitem__.
        item = self.dataset[idx] # indexing like this calls the __getitem__ method of the TuSimple class
        # An example item:
# {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492630158697138522/20.jpg', 'label': array([[ 1.0000000e+00, 2.7777779e-01, 9.7222221e-01, 5.3125000e-01,
5.2265626e-01, 5.1406252e-01, 5.1093751e-01, 5.1015627e-01,
5.0781250e-01, 5.0078124e-01, 4.9453124e-01, 4.8828125e-01,
4.8203126e-01, 4.7499999e-01, 4.6875000e-01, 4.6250001e-01,
4.5625001e-01, 4.4921875e-01, 4.4296876e-01, 4.3671876e-01,
4.2968750e-01, 4.2343751e-01, 4.1718751e-01, 4.1093749e-01,
4.0390626e-01, 3.9765626e-01, 3.9140624e-01, 3.8515624e-01,
3.7812501e-01, 3.7187499e-01, 3.6562499e-01, 3.5937500e-01,
3.5234374e-01, 3.4609374e-01, 3.3984375e-01, 3.3359376e-01,
3.2656249e-01, 3.2031250e-01, 3.1406251e-01, 3.0781251e-01,
3.0078125e-01, 2.9453126e-01, 2.8828126e-01, 2.8203124e-01,
2.7500001e-01, 2.6875001e-01, 2.6249999e-01, 2.5546876e-01,
2.4921875e-01, 2.4296875e-01, 2.3671874e-01, 2.2968750e-01,
2.2343750e-01, 2.1718749e-01, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, 2.7777779e-01,
2.9166666e-01, 3.0555555e-01, 3.1944445e-01, 3.3333334e-01,
3.4722221e-01, 3.6111110e-01, 3.7500000e-01, 3.8888890e-01,
4.0277779e-01, 4.1666666e-01, 4.3055555e-01, 4.4444445e-01,
4.5833334e-01, 4.7222221e-01, 4.8611110e-01, 5.0000000e-01,
5.1388890e-01, 5.2777779e-01, 5.4166669e-01, 5.5555558e-01,
5.6944442e-01, 5.8333331e-01, 5.9722221e-01, 6.1111110e-01,
6.2500000e-01, 6.3888890e-01, 6.5277779e-01, 6.6666669e-01,
6.8055558e-01, 6.9444442e-01, 7.0833331e-01, 7.2222221e-01,
7.3611110e-01, 7.5000000e-01, 7.6388890e-01, 7.7777779e-01,
7.9166669e-01, 8.0555558e-01, 8.1944442e-01, 8.3333331e-01,
8.4722221e-01, 8.6111110e-01, 8.7500000e-01, 8.8888890e-01,
9.0277779e-01, 9.1666669e-01, 9.3055558e-01, 9.4444442e-01,
9.5833331e-01, 9.7222221e-01, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05],
[ 1.0000000e+00, 2.7777779e-01, 9.1666669e-01, 5.4140627e-01,
5.3671873e-01, 5.3281248e-01, 5.3046876e-01, 5.3437501e-01,
5.4062498e-01, 5.4765624e-01, 5.5859375e-01, 5.6953126e-01,
5.8125001e-01, 5.9218752e-01, 6.0312498e-01, 6.1406249e-01,
6.2578124e-01, 6.3671875e-01, 6.4765626e-01, 6.5859377e-01,
6.7031252e-01, 6.8124998e-01, 6.9218749e-01, 7.0390624e-01,
7.1484375e-01, 7.2578126e-01, 7.3671877e-01, 7.4843752e-01,
7.5937498e-01, 7.7031249e-01, 7.8203124e-01, 7.9296875e-01,
8.0390626e-01, 8.1484377e-01, 8.2656252e-01, 8.3749998e-01,
8.4843749e-01, 8.6015624e-01, 8.7109375e-01, 8.8203126e-01,
8.9296877e-01, 9.0468752e-01, 9.1562498e-01, 9.2656249e-01,
9.3828124e-01, 9.4921875e-01, 9.6015626e-01, 9.7109377e-01,
9.8281252e-01, 9.9374998e-01, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, 2.7777779e-01,
2.9166666e-01, 3.0555555e-01, 3.1944445e-01, 3.3333334e-01,
3.4722221e-01, 3.6111110e-01, 3.7500000e-01, 3.8888890e-01,
4.0277779e-01, 4.1666666e-01, 4.3055555e-01, 4.4444445e-01,
4.5833334e-01, 4.7222221e-01, 4.8611110e-01, 5.0000000e-01,
5.1388890e-01, 5.2777779e-01, 5.4166669e-01, 5.5555558e-01,
5.6944442e-01, 5.8333331e-01, 5.9722221e-01, 6.1111110e-01,
6.2500000e-01, 6.3888890e-01, 6.5277779e-01, 6.6666669e-01,
6.8055558e-01, 6.9444442e-01, 7.0833331e-01, 7.2222221e-01,
7.3611110e-01, 7.5000000e-01, 7.6388890e-01, 7.7777779e-01,
7.9166669e-01, 8.0555558e-01, 8.1944442e-01, 8.3333331e-01,
8.4722221e-01, 8.6111110e-01, 8.7500000e-01, 8.8888890e-01,
9.0277779e-01, 9.1666669e-01, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05],
[ 1.0000000e+00, 2.7777779e-01, 5.4166669e-01, 5.5546874e-01,
5.5312502e-01, 5.5156249e-01, 5.5078125e-01, 5.6015623e-01,
5.6953126e-01, 5.8906251e-01, 6.1171877e-01, 6.3437498e-01,
6.6640627e-01, 6.9765627e-01, 7.2890627e-01, 7.6093751e-01,
7.9218751e-01, 8.2343751e-01, 8.5468751e-01, 8.8671875e-01,
9.1796875e-01, 9.4921875e-01, 9.8124999e-01, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, 2.7777779e-01,
2.9166666e-01, 3.0555555e-01, 3.1944445e-01, 3.3333334e-01,
3.4722221e-01, 3.6111110e-01, 3.7500000e-01, 3.8888890e-01,
4.0277779e-01, 4.1666666e-01, 4.3055555e-01, 4.4444445e-01,
4.5833334e-01, 4.7222221e-01, 4.8611110e-01, 5.0000000e-01,
5.1388890e-01, 5.2777779e-01, 5.4166669e-01, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05],
[ 0.0000000e+00, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05],
[ 0.0000000e+00, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05, -1.0000000e+05,
-1.0000000e+05, -1.0000000e+05, -1.0000000e+05]], dtype=float32), 'old_anno': {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492630158697138522/20.jpg', 'org_path': 'clips/0531/1492630158697138522/20.jpg', 'org_lanes': [[-2, -2, -2, -2, 680, 669, 658, 654, 653, 650, 641, 633, 625, 617, 608, 600, 592, 584, 575, 567, 559, 550, 542, 534, 526, 517, 509, 501, 493, 484, 476, 468, 460, 451, 443, 435, 427, 418, 410, 402, 394, 385, 377, 369, 361, 352, 344, 336, 327, 319, 311, 303, 294, 286, 278, -2], [-2, -2, -2, -2, 693, 687, 682, 679, 684, 692, 701, 715, 729, 744, 758, 772, 786, 801, 815, 829, 843, 858, 872, 886, 901, 915, 929, 943, 958, 972, 986, 1001, 1015, 1029, 1043, 1058, 1072, 1086, 1101, 1115, 1129, 1143, 1158, 1172, 1186, 1201, 1215, 1229, 1243, 1258, 1272, -2, -2, -2, -2, -2], [-2, -2, -2, -2, 711, 708, 706, 705, 717, 729, 754, 783, 812, 853, 893, 933, 974, 1014, 1054, 1094, 1135, 1175, 1215, 1256, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2]], 'lanes': [[(680, 200), (669, 210), (658, 220), (654, 230), (653, 240), (650, 250), (641, 260), (633, 270), (625, 280), (617, 290), (608, 300), (600, 310), (592, 320), (584, 330), (575, 340), (567, 350), (559, 360), (550, 370), (542, 380), (534, 390), (526, 400), (517, 410), (509, 420), (501, 430), (493, 440), (484, 450), (476, 460), (468, 470), (460, 480), (451, 490), (443, 500), (435, 510), (427, 520), (418, 530), (410, 540), (402, 550), (394, 560), (385, 570), (377, 580), (369, 590), (361, 600), (352, 610), (344, 620), (336, 630), (327, 640), (319, 650), (311, 660), (303, 670), (294, 680), (286, 690), (278, 700)], [(693, 200), (687, 210), (682, 220), (679, 230), (684, 240), (692, 250), (701, 260), (715, 270), (729, 280), (744, 290), (758, 300), (772, 310), (786, 320), (801, 330), (815, 340), (829, 350), (843, 360), (858, 370), (872, 380), (886, 390), (901, 400), (915, 410), (929, 420), (943, 430), (958, 440), (972, 450), 
(986, 460), (1001, 470), (1015, 480), (1029, 490), (1043, 500), (1058, 510), (1072, 520), (1086, 530), (1101, 540), (1115, 550), (1129, 560), (1143, 570), (1158, 580), (1172, 590), (1186, 600), (1201, 610), (1215, 620), (1229, 630), (1243, 640), (1258, 650), (1272, 660)], [(711, 200), (708, 210), (706, 220), (705, 230), (717, 240), (729, 250), (754, 260), (783, 270), (812, 280), (853, 290), (893, 300), (933, 310), (974, 320), (1014, 330), (1054, 340), (1094, 350), (1135, 360), (1175, 370), (1215, 380), (1256, 390)]], 'aug': False, 'y_samples': [160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710]}, 'categories': [1, 1, 1]}
        img = cv2.imread(item['path']) # the image that was read: {ndarray: (720, 1280, 3)}
        label = item['label'] # {ndarray: (5, 115)}
        if transform:
            line_strings = self.lane_to_linestrings(item['old_anno']['lanes']) # takes the originally annotated lane points; the container is still a list of 3, but each element is now a LineString
# Example: [LineString([(680.00, 200.00), (669.00, 210.00), (658.00, 220.00), (654.00, 230.00), (653.00, 240.00), (650.00, 250.00), (641.00, 260.00), (633.00, 270.00), (625.00, 280.00), (617.00, 290.00), (608.00, 300.00), (600.00, 310.00), (592.00, 320.00), (584.00, 330.00), (575.00, 340.00), (567.00, 350.00), (559.00, 360.00), (550.00, 370.00), (542.00, 380.00), (534.00, 390.00), (526.00, 400.00), (517.00, 410.00), (509.00, 420.00), (501.00, 430.00), (493.00, 440.00), (484.00, 450.00), (476.00, 460.00), (468.00, 470.00), (460.00, 480.00), (451.00, 490.00), (443.00, 500.00), (435.00, 510.00), (427.00, 520.00), (418.00, 530.00), (410.00, 540.00), (402.00, 550.00), (394.00, 560.00), (385.00, 570.00), (377.00, 580.00), (369.00, 590.00), (361.00, 600.00), (352.00, 610.00), (344.00, 620.00), (336.00, 630.00), (327.00, 640.00), (319.00, 650.00), (311.00, 660.00), (303.00, 670.00), (294.00, 680.00), (286.00, 690.00), (278.00, 700.00)], label=None), LineString([(693.00, 200.00), (687.00, 210.00), (682.00, 220.00), (679.00, 230.00), (684.00, 240.00), (692.00, 250.00), (701.00, 260.00), (715.00, 270.00), (729.00, 280.00), (744.00, 290.00), (758.00, 300.00), (772.00, 310.00), (786.00, 320.00), (801.00, 330.00), (815.00, 340.00), (829.00, 350.00), (843.00, 360.00), (858.00, 370.00), (872.00, 380.00), (886.00, 390.00), (901.00, 400.00), (915.00, 410.00), (929.00, 420.00), (943.00, 430.00), (958.00, 440.00), (972.00, 450.00), (986.00, 460.00), (1001.00, 470.00), (1015.00, 480.00), (1029.00, 490.00), (1043.00, 500.00), (1058.00, 510.00), (1072.00, 520.00), (1086.00, 530.00), (1101.00, 540.00), (1115.00, 550.00), (1129.00, 560.00), (1143.00, 570.00), (1158.00, 580.00), (1172.00, 590.00), (1186.00, 600.00), (1201.00, 610.00), (1215.00, 620.00), (1229.00, 630.00), (1243.00, 640.00), (1258.00, 650.00), (1272.00, 660.00)], label=None), LineString([(711.00, 200.00), (708.00, 210.00), (706.00, 220.00), (705.00, 230.00), (717.00, 240.00), (729.00, 250.00), (754.00, 260.00), (783.00, 
270.00), (812.00, 280.00), (853.00, 290.00), (893.00, 300.00), (933.00, 310.00), (974.00, 320.00), (1014.00, 330.00), (1054.00, 340.00), (1094.00, 350.00), (1135.00, 360.00), (1175.00, 370.00), (1215.00, 380.00), (1256.00, 390.00)], label=None)]
line_strings = LineStringsOnImage(line_strings, shape=img.shape) # Convert the type once more, wrapping the list in a LineStringsOnImage.
# Type: {LineStringsOnImage:3}. An example:
# LineStringsOnImage([LineString([(680.00, 200.00), (669.00, 210.00), (658.00, 220.00), (654.00, 230.00), (653.00, 240.00), (650.00, 250.00), (641.00, 260.00), (633.00, 270.00), (625.00, 280.00), (617.00, 290.00), (608.00, 300.00), (600.00, 310.00), (592.00, 320.00), (584.00, 330.00), (575.00, 340.00), (567.00, 350.00), (559.00, 360.00), (550.00, 370.00), (542.00, 380.00), (534.00, 390.00), (526.00, 400.00), (517.00, 410.00), (509.00, 420.00), (501.00, 430.00), (493.00, 440.00), (484.00, 450.00), (476.00, 460.00), (468.00, 470.00), (460.00, 480.00), (451.00, 490.00), (443.00, 500.00), (435.00, 510.00), (427.00, 520.00), (418.00, 530.00), (410.00, 540.00), (402.00, 550.00), (394.00, 560.00), (385.00, 570.00), (377.00, 580.00), (369.00, 590.00), (361.00, 600.00), (352.00, 610.00), (344.00, 620.00), (336.00, 630.00), (327.00, 640.00), (319.00, 650.00), (311.00, 660.00), (303.00, 670.00), (294.00, 680.00), (286.00, 690.00), (278.00, 700.00)], label=None), LineString([(693.00, 200.00), (687.00, 210.00), (682.00, 220.00), (679.00, 230.00), (684.00, 240.00), (692.00, 250.00), (701.00, 260.00), (715.00, 270.00), (729.00, 280.00), (744.00, 290.00), (758.00, 300.00), (772.00, 310.00), (786.00, 320.00), (801.00, 330.00), (815.00, 340.00), (829.00, 350.00), (843.00, 360.00), (858.00, 370.00), (872.00, 380.00), (886.00, 390.00), (901.00, 400.00), (915.00, 410.00), (929.00, 420.00), (943.00, 430.00), (958.00, 440.00), (972.00, 450.00), (986.00, 460.00), (1001.00, 470.00), (1015.00, 480.00), (1029.00, 490.00), (1043.00, 500.00), (1058.00, 510.00), (1072.00, 520.00), (1086.00, 530.00), (1101.00, 540.00), (1115.00, 550.00), (1129.00, 560.00), (1143.00, 570.00), (1158.00, 580.00), (1172.00, 590.00), (1186.00, 600.00), (1201.00, 610.00), (1215.00, 620.00), (1229.00, 630.00), (1243.00, 640.00), (1258.00, 650.00), (1272.00, 660.00)], label=None), LineString([(711.00, 200.00), (708.00, 210.00), (706.00, 220.00), (705.00, 230.00), (717.00, 240.00), (729.00, 250.00), (754.00, 260.00), 
(783.00, 270.00), (812.00, 280.00), (853.00, 290.00), (893.00, 300.00), (933.00, 310.00), (974.00, 320.00), (1014.00, 330.00), (1054.00, 340.00), (1094.00, 350.00), (1135.00, 360.00), (1175.00, 370.00), (1215.00, 380.00), (1256.00, 390.00)], label=None)], shape=(720, 1280, 3))
img, line_strings = self.transform(image=img, line_strings=line_strings) # Apply the transform to both the image and the lane points.
# img: {ndarray:(720,1280,3)} to {ndarray:(360,640,3)}
# line_strings: from the value above to LineStringsOnImage([LineString([(340.00, 100.00), (334.50, 105.00), (329.00, 110.00), (327.00, 115.00), (326.50, 120.00), (325.00, 125.00), (320.50, 130.00), (316.50, 135.00), (312.50, 140.00), (308.50, 145.00), (304.00, 150.00), (300.00, 155.00), (296.00, 160.00), (292.00, 165.00), (287.50, 170.00), (283.50, 175.00), (279.50, 180.00), (275.00, 185.00), (271.00, 190.00), (267.00, 195.00), (263.00, 200.00), (258.50, 205.00), (254.50, 210.00), (250.50, 215.00), (246.50, 220.00), (242.00, 225.00), (238.00, 230.00), (234.00, 235.00), (230.00, 240.00), (225.50, 245.00), (221.50, 250.00), (217.50, 255.00), (213.50, 260.00), (209.00, 265.00), (205.00, 270.00), (201.00, 275.00), (197.00, 280.00), (192.50, 285.00), (188.50, 290.00), (184.50, 295.00), (180.50, 300.00), (176.00, 305.00), (172.00, 310.00), (168.00, 315.00), (163.50, 320.00), (159.50, 325.00), (155.50, 330.00), (151.50, 335.00), (147.00, 340.00), (143.00, 345.00), (139.00, 350.00)], label=None), LineString([(346.50, 100.00), (343.50, 105.00), (341.00, 110.00), (339.50, 115.00), (342.00, 120.00), (346.00, 125.00), (350.50, 130.00), (357.50, 135.00), (364.50, 140.00), (372.00, 145.00), (379.00, 150.00), (386.00, 155.00), (393.00, 160.00), (400.50, 165.00), (407.50, 170.00), (414.50, 175.00), (421.50, 180.00), (429.00, 185.00), (436.00, 190.00), (443.00, 195.00), (450.50, 200.00), (457.50, 205.00), (464.50, 210.00), (471.50, 215.00), (479.00, 220.00), (486.00, 225.00), (493.00, 230.00), (500.50, 235.00), (507.50, 240.00), (514.50, 245.00), (521.50, 250.00), (529.00, 255.00), (536.00, 260.00), (543.00, 265.00), (550.50, 270.00), (557.50, 275.00), (564.50, 280.00), (571.50, 285.00), (579.00, 290.00), (586.00, 295.00), (593.00, 300.00), (600.50, 305.00), (607.50, 310.00), (614.50, 315.00), (621.50, 320.00), (629.00, 325.00), (636.00, 330.00)], label=None), LineString([(355.50, 100.00), (354.00, 105.00), (353.00, 110.00), (352.50, 115.00), (358.50, 120.00), (364.50, 125.00), (377.00, 130.00), 
(391.50, 135.00), (406.00, 140.00), (426.50, 145.00), (446.50, 150.00), (466.50, 155.00), (487.00, 160.00), (507.00, 165.00), (527.00, 170.00), (547.00, 175.00), (567.50, 180.00), (587.50, 185.00), (607.50, 190.00), (628.00, 195.00)], label=None)], shape=(360, 640, 3))
line_strings.clip_out_of_image_() # Clip off all parts of the LSs that are outside of an image in-place.
# i.e. trim away the portions that fall outside the image
new_anno = {'path': item['path'], 'lanes': self.linestrings_to_lanes(line_strings)} # An example:
# {'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492630158697138522/20.jpg', 'lanes': [array([[340. , 100.00001],
[334.5 , 105. ],
[329. , 110. ],
[327. , 115. ],
[326.5 , 120. ],
[325. , 124.99999],
[320.5 , 130. ],
[316.5 , 135. ],
[312.5 , 140. ],
[308.5 , 145. ],
[304. , 150. ],
[300. , 155. ],
[296. , 160. ],
[292. , 165. ],
[287.5 , 170. ],
[283.5 , 175. ],
[279.5 , 180. ],
[275. , 185. ],
[271. , 190. ],
[267. , 195. ],
[263. , 200.00002],
[258.5 , 204.99998],
[254.5 , 210. ],
[250.5 , 215. ],
[246.5 , 220. ],
[242. , 225. ],
[238. , 230. ],
[234. , 235. ],
[230. , 240. ],
[225.5 , 245.00002],
[221.5 , 249.99998],
[217.5 , 255. ],
[213.5 , 260. ],
[209. , 265. ],
[205. , 270. ],
[201. , 275. ],
[197. , 280. ],
[192.5 , 285. ],
[188.5 , 290. ],
[184.5 , 295. ],
[180.5 , 300. ],
[176. , 305. ],
[172. , 310. ],
[168. , 315. ],
[163.5 , 320. ],
[159.5 , 325. ],
[155.5 , 330. ],
[151.5 , 335. ],
[147. , 340. ],
[143. , 345. ],
[139. , 350. ]], dtype=float32), array([[346.5 , 100.00001],
[343.5 , 105. ],
[341. , 110. ],
[339.5 , 115. ],
[342. , 120. ],
[346. , 124.99999],
[350.5 , 130. ],
[357.5 , 135. ],
[364.5 , 140. ],
[372. , 145. ],
[379. , 150. ],
[386. , 155. ],
[393. , 160. ],
[400.5 , 165. ],
[407.5 , 170. ],
[414.5 , 175. ],
[421.5 , 180. ],
[429. , 185. ],
[436. , 190. ],
[443. , 195. ],
[450.5 , 200.00002],
[457.5 , 204.99998],
[464.5 , 210. ],
[471.5 , 215. ],
[479. , 220. ],
[486. , 225. ],
[493. , 230. ],
[500.5 , 235. ],
[507.5 , 240. ],
[514.5 , 245.00002],
[521.5 , 249.99998],
[529. , 255. ],
[536. , 260. ],
[543. , 265. ],
[550.5 , 270. ],
[557.5 , 275. ],
[564.5 , 280. ],
[571.5 , 285. ],
[579. , 290. ],
[586. , 295. ],
[593. , 300. ],
[600.5 , 305. ],
[607.5 , 310. ],
[614.5 , 315. ],
[621.5 , 320. ],
[629. , 325. ],
[636. , 330. ]], dtype=float32), array([[355.5 , 100.00001],
[354. , 105. ],
[353. , 110. ],
[352.5 , 115. ],
[358.5 , 120. ],
[364.5 , 124.99999],
[377. , 130. ],
[391.5 , 135. ],
[406. , 140. ],
[426.5 , 145. ],
[446.5 , 150. ],
[466.5 , 155. ],
[487. , 160. ],
[507. , 165. ],
[527. , 170. ],
[547. , 175. ],
[567.5 , 180. ],
[587.5 , 185. ],
[607.5 , 190. ],
[628. , 195. ]], dtype=float32)]}
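new_anno['categories'] = item['categories']
label = self.transform_annotation(new_anno, img_wh=(self.img_w, self.img_h))['label'] # Transform the label. self.transform_annotation is called a second time here: the earlier call normalized the annotation from the original image size to [0,1]; this call does the same for the target size [640,360].
img = img / 255. # {ndarray:(360,640,3)}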
if self.normalize:
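img = (img - IMAGENET_MEAN) / IMAGENET_STD # Standardization: subtract the per-channel mean and divide by the per-channel standard deviation. Both constants are defined at the top of the file; judging by their names they are ImageNet statistics rather than values measured on this dataset.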
img = self.to_tensor(img.astype(np.float32))
return (img, label, idx) # label: {ndarray:(5,115)}. idx: index of the image
def __len__(self):
return len(self.dataset) #
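The pixel preprocessing at the end of `__getitem__` can be tried in isolation. The sketch below is only an illustration: the `preprocess` helper and the gray test image are hypothetical, and the mean/std values are the commonly quoted ImageNet statistics, assumed here to match the `IMAGENET_MEAN` and `IMAGENET_STD` constants the code refers to.

```python
import numpy as np

# Commonly quoted ImageNet channel statistics. Assumption: these match the
# IMAGENET_MEAN / IMAGENET_STD constants defined at the top of the dataset file.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def preprocess(img_uint8, normalize=True):
    """Scale an HxWx3 uint8 image to [0, 1], then optionally standardize it."""
    img = img_uint8 / 255.0
    if normalize:
        # The (3,) constants broadcast over the last (channel) axis.
        img = (img - IMAGENET_MEAN) / IMAGENET_STD
    return img.astype(np.float32)


# Hypothetical input: a uniform mid-gray image at the resized resolution.
gray = np.full((360, 640, 3), 128, dtype=np.uint8)
out = preprocess(gray)
```

After standardization the values are no longer confined to [0, 1]; each channel is centered around zero relative to the assumed dataset mean.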
def main():
    import torch
    from lib.config import Config
    cfg = Config('config.yaml')
    train_dataset = cfg.get_dataset('train')
    for idx in range(len(train_dataset)):
        img = train_dataset.draw_annotation(idx)
        cv2.imshow('sample', img)
if __name__ == "__main__":
    main()
import os
import json
import random
import numpy as np
from tabulate import tabulate
from utils.lane import LaneEval
from utils.metric import eval_json
SPLIT_FILES = {
    'train+val': ['label_data_0313.json', 'label_data_0601.json', 'label_data_0531.json'],
    'train': ['label_data_0313.json', 'label_data_0601.json'],
    'val': ['label_data_0531.json'],
    'test': ['test_label.json'],
}
class TuSimple(object):
def __init__(self, split='train', max_lanes=None, root=None, metric='default'):
self.split = split # 'val'
self.root = root # '/home/wqf/PolyLaneNet/PolyLaneNet-master'
self.metric = metric # 'default'
if split not in SPLIT_FILES.keys():
raise Exception('Split `{}` does not exist.'.format(split))
self.anno_files = [os.path.join(self.root, path) for path in SPLIT_FILES[split]] # ['/home/wqf/PolyLaneNet/PolyLaneNet-master/label_data_0531.json']
if root is None:
raise Exception('Please specify the root directory')
self.img_w, self.img_h = 1280, 720
self.max_points = 0
self.load_annotations() # load the annotation data
# Force max_lanes, used when evaluating testing with models trained on other datasets
if max_lanes is not None:
self.max_lanes = max_lanes # 5
def get_img_heigth(self, path):
return 720
def get_img_width(self, path):
return 1280
def get_metrics(self, lanes, idx):
label = self.annotations[idx]
org_anno = label['old_anno']
pred = self.pred2lanes(org_anno['path'], lanes, org_anno['y_samples'])
_, _, _, matches, accs, dist = LaneEval.bench(pred, org_anno['org_lanes'], org_anno['y_samples'], 0, True)
return matches, accs, dist
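def pred2lanes(self, path, pred, y_samples):
    ys = np.array(y_samples) / self.img_h
    lanes = []
    for lane in pred: # lane:[ 1. 0.35354868 0.60016412 -0.21744023 -1.12229943 -0.06798992, 0.52022016]
        # One lane line of one image: [confidence, lower bound, upper bound, 4 polynomial coefficients]
        if lane[0] == 0: # confidence 0 means this output slot holds no lane, so skip it
            continue
        lane_pred = np.polyval(lane[3:], ys) * self.img_w # Online tutorials do not explain np.polyval well; the docstring in its source is clearer:
        # Evaluate a polynomial at specific values.
        # If `p` is of length N, this function returns the value:
        # ``p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]``
        # In other words, each coefficient is multiplied by the matching power of y and the terms are summed.
        lane_pred[(ys < lane[1]) | (ys > lane[2])] = -2 # x is set to -2 wherever y is below the lower bound or above the upper bound
        lanes.append(list(lane_pred))
    return lanes # lane x coordinates in TuSimple format
# An example: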
# [[-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 434.0222707652424, 416.55553232878447, 398.41377182098415, 379.59251516298053, 360.087288275913, 339.89361708092054, 319.0070274991427, 297.42304545171834, 275.13719685978685, 252.14500764448744, 228.44200372695923, 204.0237110283416, 178.88565546977338, 153.0233629723943, 126.43235945734297, 99.10817084575925, 71.0463230587817, 42.24234201755017, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0], [-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 527.3455002368396, 525.8638570457697, 524.135191365361, 522.1609893227596, 519.9427370451115, 517.4819206595625, 514.7800262932587, 511.8385400733462, 508.65894812697087, 505.2427365812786, 501.5913915634155, 497.7063992005276, 493.58924561976085, 489.2414169482611, 484.66439931317444, 479.85967884164694, 474.82874166082456, 469.57307389785313, 464.09416167987877, 458.3934911340475, 452.4725483875052, 446.3328195673979, 439.9757908008717, 433.40294821507234, 426.615777937146, 419.61576609423867, 412.40439881349624, 404.9831622220647, 397.35354244709015, 389.51702561571847, 381.47509785509567, 373.22924529236775, 364.7809540546807, 356.13171026918053, 347.2830000630131, 338.2363095633246, 328.9931248972608, 319.55493219196796, 309.92321757459183, 300.0994671722784, 290.08516711217385, 279.881803521424, 269.4908625271749, 258.9138302565725, 248.15219283676277, 237.20743639489172], [-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 599.3538418015662, 617.9884305596352, 636.5290950322511, 654.9729896887666, 673.3172689985347, 691.5590874309075, 709.6955994552383, 727.7239595408793, 745.6413221571835, 763.4448417735034, 781.1316728591919, 798.6989698836015, 816.143887316085, 833.463579625995, 850.6552012826844, 867.7159067555056, 884.6428505138115, 901.4331870269547, 918.084070764288, 934.5926561951637, 950.956097788935, 
967.1715500149544, 983.2361673425744, 999.1471042411481, 1014.9015151800279, 1030.4965546285664, 1045.9293770561164, 1061.1971369320308, 1076.2969887256622, 1091.2260869063632, 1105.9815859434864, 1120.5606403063846, 1134.9604044644102, 1149.1780328869165, 1163.210680043256, 1177.0555004027808, 1190.7096484348444, 1204.170278608799, 1217.4345453939973, 1230.4996032597924, 1243.3626066755367, 1256.0207101105825, 1268.4710680342832, -2.0, -2.0, -2.0], [-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, 619.0527163349878, 663.2424223423004, 707.2017337919426, 750.9011041562089, 794.310986907394, 837.4018355177918, 880.1441034596971, 922.5082442054043, 964.4647112272077, 1005.9839579974018, 1047.0364379882812, 1087.5926046721404, 1127.6229115212732, 1167.0978120079744, 1205.987759604539, 1244.263207783261, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0]]
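The conversion performed by `pred2lanes` can be reproduced without NumPy. The sketch below is illustrative only: `polyval` is a hand-written Horner evaluation standing in for `np.polyval`, `lane_to_tusimple_xs` is a hypothetical helper name, and the row layout `[conf, lower, upper, coefficients...]` is read off the indexing in the code above. The input row is the example lane quoted in the comment inside `pred2lanes`.

```python
def polyval(coeffs, x):
    """Horner evaluation; coefficients run from the highest degree to the constant."""
    result = 0.0
    for c in coeffs:
        result = result * x + c
    return result


def lane_to_tusimple_xs(lane, y_samples, img_w=1280, img_h=720):
    """Turn one [conf, lower, upper, c3, c2, c1, c0] row into TuSimple x values."""
    conf, lower, upper = lane[0], lane[1], lane[2]
    if conf == 0:
        return None  # this output slot holds no lane
    xs = []
    for y in y_samples:
        t = y / img_h  # the polynomial is defined on normalized y
        if t < lower or t > upper:
            xs.append(-2)  # TuSimple marker for "no lane at this row"
        else:
            xs.append(polyval(lane[3:], t) * img_w)
    return xs


lane = [1.0, 0.35354868, 0.60016412, -0.21744023, -1.12229943, -0.06798992, 0.52022016]
y_samples = list(range(160, 720, 10))  # 56 rows: 160, 170, ..., 710
xs = lane_to_tusimple_xs(lane, y_samples)
```

With this row, the first valid entry is at y = 260 and evaluates to roughly 434.02, which agrees with the first lane of the example output above.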
def load_annotations(self): # After this function runs, self.annotations is a {list:358}
    self.annotations = []
    max_lanes = 0
    for anno_file in self.anno_files:
        with open(anno_file, 'r') as anno_obj:
            lines = anno_obj.readlines() # Read the file line by line. Here the result is a {list:358}
        for line in lines:
            data = json.loads(line)
            y_samples = data['h_samples'] # {list:56} [160, 170, ..., 710]
            gt_lanes = data['lanes'] # {list:4}
            lanes = [[(x, y) for (x, y) in zip(lane, y_samples) if x >= 0] for lane in gt_lanes] # The (x, y) points of each lane; there are 4 lanes here. {list:4} [[(499, 240), (484, 250), ..., (1272, 420)]]
            lanes = [lane for lane in lanes if len(lane) > 0] # Drop lanes that have no valid points. 4
            max_lanes = max(max_lanes, len(lanes)) # 4
            self.max_points = max(self.max_points, max([len(l) for l in gt_lanes])) # 56
            self.annotations.append({
                'path': os.path.join(self.root, data['raw_file']), # image path inside the dataset root
                'org_path': data['raw_file'], # original image path
                'org_lanes': gt_lanes, # original lane x coordinates from the TuSimple dataset
                'lanes': lanes, # valid lane coordinates (entries such as -2 removed)
                'aug': False,
                'y_samples': y_samples # y coordinate values
            })
# An example: [{'path': '/home/wqf/PolyLaneNet/PolyLaneNet-master/clips/0531/1492626287507231547/20.jpg', 'org_path': 'clips/0531/1492626287507231547/20.jpg', 'org_lanes': [[-2, -2, -2, -2, -2, -2, -2, -2, 499, 484, 470, 453, 435, 418, 400, 374, 346, 318, 290, 262, 235, 207, 179, 151, 123, 96, 68, 40, 12, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, -2, -2, 529, 531, 533, 536, 538, 540, 540, 538, 536, 531, 525, 519, 513, 507, 499, 491, 483, 475, 467, 459, 451, 443, 435, 426, 418, 410, 402, 394, 386, 378, 370, 362, 354, 346, 338, 330, 322, 314, 306, 297, 289, 281, 273, 265, 257, 249, 241], [-2, -2, -2, -2, -2, -2, -2, -2, 553, 568, 583, 598, 613, 640, 667, 693, 719, 740, 761, 783, 804, 825, 846, 868, 883, 897, 912, 926, 941, 955, 969, 984, 998, 1013, 1027, 1042, 1056, 1070, 1085, 1099, 1114, 1128, 1143, 1157, 1171, 1186, 1200, 1215, 1229, 1244, 1258, 1272, -2, -2, -2, -2], [-2, -2, -2, -2, -2, -2, -2, -2, 558, 585, 613, 646, 679, 714, 770, 817, 865, 912, 954, 994, 1033, 1073, 1113, 1153, 1193, 1232, 1272, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -2]], 'lanes': [[(499, 240), (484, 250), (470, 260), (453, 270), (435, 280), (418, 290), (400, 300), (374, 310), (346, 320), (318, 330), (290, 340), (262, 350), (235, 360), (207, 370), (179, 380), (151, 390), (123, 400), (96, 410), (68, 420), (40, 430), (12, 440)], [(529, 250), (531, 260), (533, 270), (536, 280), (538, 290), (540, 300), (540, 310), (538, 320), (536, 330), (531, 340), (525, 350), (519, 360), (513, 370), (507, 380), (499, 390), (491, 400), (483, 410), (475, 420), (467, 430), (459, 440), (451, 450), (443, 460), (435, 470), (426, 480), (418, 490), (410, 500), (402, 510), (394, 520), (386, 530), (378, 540), (370, 550), (362, 560), (354, 570), (346, 580), (338, 590), (330, 600), (322, 610), (314, 620), (306, 630), (297, 640), (289, 650), (281, 660), (273, 
670), (265, 680), (257, 690), (249, 700), (241, 710)], [(553, 240), (568, 250), (583, 260), (598, 270), (613, 280), (640, 290), (667, 300), (693, 310), (719, 320), (740, 330), (761, 340), (783, 350), (804, 360), (825, 370), (846, 380), (868, 390), (883, 400), (897, 410), (912, 420), (926, 430), (941, 440), (955, 450), (969, 460), (984, 470), (998, 480), (1013, 490), (1027, 500), (1042, 510), (1056, 520), (1070, 530), (1085, 540), (1099, 550), (1114, 560), (1128, 570), (1143, 580), (1157, 590), (1171, 600), (1186, 610), (1200, 620), (1215, 630), (1229, 640), (1244, 650), (1258, 660), (1272, 670)], [(558, 240), (585, 250), (613, 260), (646, 270), (679, 280), (714, 290), (770, 300), (817, 310), (865, 320), (912, 330), (954, 340), (994, 350), (1033, 360), (1073, 370), (1113, 380), (1153, 390), (1193, 400), (1232, 410), (1272, 420)]], 'aug': False, 'y_samples': [160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710]}]
    if self.split == 'train': # in this run the split is 'val'
        print('total annos', len(self.annotations)) # 358
    self.max_lanes = max_lanes # Largest number of lanes found in any single image of the dataset. 5
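The per-line parsing inside `load_annotations` is easy to check on a small record. The following is a minimal sketch, not the repository code: `parse_label_line` is a hypothetical helper, and the JSON record is a made-up, shortened stand-in that only mimics the `raw_file` / `h_samples` / `lanes` fields used above.

```python
import json


def parse_label_line(line):
    """Parse one TuSimple-style label line into the fields kept by load_annotations."""
    record = json.loads(line)
    y_samples = record['h_samples']
    # Pair each x with its row y and drop the "no lane here" markers (x < 0).
    lanes = [[(x, y) for x, y in zip(lane, y_samples) if x >= 0]
             for lane in record['lanes']]
    # Drop lanes that have no valid point at all.
    lanes = [lane for lane in lanes if len(lane) > 0]
    return {
        'org_path': record['raw_file'],
        'org_lanes': record['lanes'],
        'lanes': lanes,
        'y_samples': y_samples,
    }


# Hypothetical, shortened record: two lanes, the second one entirely empty.
line = json.dumps({
    'raw_file': 'clips/example/20.jpg',
    'h_samples': [240, 250, 260],
    'lanes': [[-2, 484, 470], [-2, -2, -2]],
})
anno = parse_label_line(line)
```

Note that `org_lanes` keeps every lane, including empty ones, whereas `lanes` only keeps lanes with at least one valid point; this is why the two lists can differ in length.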
def transform_annotations(self, transform):
self.annotations = list(map(transform, self.annotations))
def pred2tusimpleformat(self, idx, pred, runtime):
runtime *= 1000. # s to ms
img_name = self.annotations[idx]['old_anno']['org_path']
h_samples = self.annotations[idx]['old_anno']['y_samples']
lanes = self.pred2lanes(img_name, pred, h_samples)
output = {'raw_file': img_name, 'lanes': lanes, 'run_time': runtime}
return json.dumps(output) # returns a JSON-formatted string
def save_tusimple_predictions(self, predictions, runtimes, filename):
    lines = []
    for idx in range(len(predictions)):
        line = self.pred2tusimpleformat(idx, predictions[idx], runtimes[idx])
        lines.append(line)
    with open(filename, 'w') as output_file: # '/tmp/tusimple_predictions_default_2695.json'. Mind the path you are writing to.
        output_file.write('\n'.join(lines))
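The prediction file written above holds one JSON object per line. A small round trip makes the format concrete; this is a sketch under assumptions, with hypothetical helper names (`write_predictions`, `read_predictions`), made-up records and a temporary path.

```python
import json
import os
import tempfile


def write_predictions(records, filename):
    """Write one JSON object per line, joined by newlines as in the code above."""
    with open(filename, 'w') as output_file:
        output_file.write('\n'.join(json.dumps(r) for r in records))


def read_predictions(filename):
    """Read the file back, one JSON object per non-empty line."""
    with open(filename) as input_file:
        return [json.loads(line) for line in input_file if line.strip()]


# Hypothetical records with the three keys the evaluation expects.
records = [
    {'raw_file': 'clips/example/20.jpg', 'lanes': [[-2, 434.0, 416.5]], 'run_time': 8.5},
    {'raw_file': 'clips/example/40.jpg', 'lanes': [[-2, -2, 527.25]], 'run_time': 9.0},
]
path = os.path.join(tempfile.mkdtemp(), 'predictions.json')
write_predictions(records, path)
loaded = read_predictions(path)
```

Writing to a temporary directory avoids clobbering a previous run, which is the concern behind the note about the fixed `/tmp/...` path.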
def eval(self, exp_dir, predictions, runtimes, label=None, only_metrics=False): # exp_dir:'experiments/default'
# predictions:(358, 5, 7) # runtimes:358
pred_filename = '/tmp/tusimple_predictions_{}.json'.format(label) # '/tmp/tusimple_predictions_default_2695.json'
self.save_tusimple_predictions(predictions, runtimes, pred_filename) #
if self.metric == 'default':
result = json.loads(LaneEval.bench_one_submit(pred_filename, self.anno_files[0])) # pred_filename:'/tmp/tusimple_predictions_default_2695.json'
# self.anno_files:['/home/wqf/PolyLaneNet/PolyLaneNet-master/label_data_0531.json']
elif self.metric == 'ours':
result = json.loads(eval_json(pred_filename, self.anno_files[0], json_type='tusimple'))
table = {}
for metric in result:
table[metric['name']] = [metric['value']]
table = tabulate(table, headers='keys')
if not only_metrics:
filename = 'tusimple_{}_eval_result_{}.json'.format(self.split, label)
with open(os.path.join(exp_dir, filename), 'w') as out_file:
json.dump(result, out_file)
return table, result
def __getitem__(self, idx):
return self.annotations[idx]
def __len__(self):
return len(self.annotations) # an int, 358 here
import sys
import numpy as np
from lib.datasets.lane_dataset import LaneDataset
EXPS_DIR = 'experiments'
class Evaluator(object):
def __init__(self, dataset, exp_dir, poly_degree=3):
self.dataset = dataset # {LaneDataset:358}
# self.predictions = np.zeros((len(dataset.annotations), dataset.max_lanes, 4 + poly_degree))
self.predictions = None
self.runtimes = np.zeros(len(dataset)) # len() calls the dataset's __len__() magic method, which returns its length. {ndarray:(358,)}
self.loss = np.zeros(len(dataset)) # {ndarray:(358,)}
self.exp_dir = exp_dir # 'experiments/default'
self.new_preds = False
def add_prediction(self, idx, pred, runtime):
if self.predictions is None:
self.predictions = np.zeros((len(self.dataset.annotations), pred.shape[1], pred.shape[2])) # (358, 5, 7)
self.predictions[idx, :pred.shape[1], :] = pred # store every prediction in self.predictions
self.runtimes[idx] = runtime #
self.new_preds = True
def eval(self, **kwargs):# kwargs:{'label': 'default_2695'}
return self.dataset.dataset.eval(self.exp_dir, self.predictions, self.runtimes, **kwargs) # self.exp_dir:'experiments/default'
# self.predictions:(358, 5, 7) self.runtimes:358
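`Evaluator.add_prediction` allocates its buffer only when the first prediction arrives, so the buffer shape follows the model output. A minimal sketch of that pattern, assuming a batch of one per call; `PredictionBuffer` is a hypothetical stand-in, not the class from the repository.

```python
import numpy as np


class PredictionBuffer:
    """Collects per-image predictions; the array is allocated on first use."""

    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.predictions = None
        self.runtimes = np.zeros(num_samples)

    def add(self, idx, pred, runtime):
        # pred is assumed to have shape (1, lanes, values_per_lane).
        if self.predictions is None:
            self.predictions = np.zeros((self.num_samples, pred.shape[1], pred.shape[2]))
        self.predictions[idx] = pred[0]  # drop the batch axis explicitly
        self.runtimes[idx] = runtime


buf = PredictionBuffer(num_samples=358)
buf.add(0, np.ones((1, 5, 7)), runtime=0.01)
```

Deferring the allocation means the buffer does not need to know the number of lanes or the polynomial degree in advance.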
if __name__ == "__main__":
evaluator = Evaluator(LaneDataset(split='test'), exp_dir=sys.argv[1])
import numpy as np
import ujson as json
from sklearn.linear_model import LinearRegression
class LaneEval(object):
lr = LinearRegression()
pixel_thresh = 20
pt_thresh = 0.85
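@staticmethod
def get_angle(xs, y_samples):
    # Keep only the rows where the lane exists (x >= 0).
    xs, ys = xs[xs >= 0], y_samples[xs >= 0]
    if len(xs) > 1:
        # Fit x as a linear function of y; the slope gives the lane angle.
        LaneEval.lr.fit(ys[:, None], xs)
        k = LaneEval.lr.coef_[0]
        theta = np.arctan(k)
    else:
        theta = 0
    return theta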
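`get_angle` feeds the per-lane pixel threshold used later in `bench`: the base threshold of 20 pixels is divided by the cosine of the lane angle, so slanted lanes are given a wider horizontal tolerance. The sketch below swaps the scikit-learn `LinearRegression` for a closed-form least-squares slope so that it runs with the standard library only; `lane_angle` and `pixel_threshold` are hypothetical names.

```python
import math

PIXEL_THRESH = 20  # same base value as LaneEval.pixel_thresh


def lane_angle(xs, ys):
    """Angle of the least-squares line x = k*y + b over the valid points (x >= 0)."""
    pts = [(x, y) for x, y in zip(xs, ys) if x >= 0]
    if len(pts) < 2:
        return 0.0
    mean_x = sum(x for x, _ in pts) / len(pts)
    mean_y = sum(y for _, y in pts) / len(pts)
    var_y = sum((y - mean_y) ** 2 for _, y in pts)
    if var_y == 0:
        return 0.0
    slope = sum((y - mean_y) * (x - mean_x) for x, y in pts) / var_y
    return math.atan(slope)


def pixel_threshold(xs, ys):
    """Horizontal tolerance in pixels, widened for slanted lanes."""
    return PIXEL_THRESH / math.cos(lane_angle(xs, ys))


vertical = pixel_threshold([100, 100, 100], [200, 210, 220])
diagonal = pixel_threshold([-2, 100, 110, 120], [190, 200, 210, 220])
```

A vertical lane keeps the base threshold of 20 pixels, while a lane at 45 degrees gets 20·√2, about 28.3 pixels.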
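@staticmethod
def line_accuracy(pred, gt, thresh):
    # Missing points (negative x) are mapped to -100 on both sides.
    pred = np.array([p if p >= 0 else -100 for p in pred])
    gt = np.array([g if g >= 0 else -100 for g in gt])
    # Fraction of rows whose horizontal error is below the threshold.
    return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)
@staticmethod
def distances(pred, gt):
    return np.abs(pred - gt)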
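A small worked example shows how `line_accuracy` treats missing points. This is a plain-Python restatement for illustration, with made-up coordinates; it is meant to mirror the NumPy version above, not replace it.

```python
def line_accuracy(pred, gt, thresh):
    """Fraction of rows where |pred - gt| < thresh, with missing x mapped to -100."""
    pred = [p if p >= 0 else -100 for p in pred]
    gt = [g if g >= 0 else -100 for g in gt]
    hits = sum(1 for p, g in zip(pred, gt) if abs(p - g) < thresh)
    return hits / len(gt)


gt = [-2, 500, 480, 460]    # hypothetical ground-truth x per row
pred = [-2, 505, 470, 400]  # hypothetical predicted x per row
acc = line_accuracy(pred, gt, 20)

# A prediction that is missing where the ground truth exists is a miss.
acc_missing = line_accuracy([-2, -2, -2, -2], gt, 20)
```

Because both sides map missing points to -100, a row where neither the prediction nor the ground truth has a lane counts as a hit. Here the first row is such a hit, rows two and three are within 20 pixels, and the last row is off by 60, giving 3/4.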
@staticmethod
def bench(pred, gt, y_samples, running_time, get_matches=False):
    if any(len(p) != len(y_samples) for p in pred):
        raise Exception('Format of lanes error.')
    if running_time > 20000 or len(gt) + 2 < len(pred):
        return 0., 0., 1.
    angles = [LaneEval.get_angle(np.array(x_gts), np.array(y_samples)) for x_gts in gt]
    threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
    line_accs = []
    fp, fn = 0., 0.
    matched = 0.
    my_matches = [False] * len(pred)
    my_accs = [0] * len(pred)
    my_dists = [None] * len(pred)
    for x_gts, thresh in zip(gt, threshs):
        # Accuracy of every predicted lane against this ground-truth lane.
        accs = [LaneEval.line_accuracy(np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
        my_accs = np.maximum(my_accs, accs)
        max_acc = np.max(accs) if len(accs) > 0 else 0.
        my_dist = [LaneEval.distances(np.array(x_preds), np.array(x_gts)) for x_preds in pred]
        if len(accs) > 0:
            my_dists[np.argmax(accs)] = {
                'y_gts': list(np.array(y_samples)[np.array(x_gts) >= 0].astype(int)),
                'dists': list(my_dist[np.argmax(accs)])
            }
        if max_acc < LaneEval.pt_thresh:
            fn += 1
        else:
            my_matches[np.argmax(accs)] = True
            matched += 1
        line_accs.append(max_acc)
    fp = len(pred) - matched
    if len(gt) > 4 and fn > 0:
        fn -= 1
    s = sum(line_accs)
    if len(gt) > 4:
        s -= min(line_accs)
    if get_matches:
        return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(
            min(len(gt), 4.), 1.), my_matches, my_accs, my_dists
    return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.)
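The matching logic of `bench` is easier to follow once the bookkeeping for `my_matches`, `my_accs` and `my_dists` is removed. The sketch below is a simplified restatement under assumptions: thresholds are passed in directly instead of being derived from lane angles, the run-time and lane-count guards are omitted, and `bench_sketch` is a hypothetical name.

```python
PT_THRESH = 0.85  # same value as LaneEval.pt_thresh


def line_accuracy(pred, gt, thresh):
    pred = [p if p >= 0 else -100 for p in pred]
    gt = [g if g >= 0 else -100 for g in gt]
    return sum(1 for p, g in zip(pred, gt) if abs(p - g) < thresh) / len(gt)


def bench_sketch(pred, gt, thresholds):
    """Return (accuracy, FP rate, FN rate) for one image."""
    matched, fn, line_accs = 0, 0, []
    for x_gts, thresh in zip(gt, thresholds):
        accs = [line_accuracy(x_preds, x_gts, thresh) for x_preds in pred]
        max_acc = max(accs) if accs else 0.0
        if max_acc < PT_THRESH:
            fn += 1       # no prediction explains this ground-truth lane
        else:
            matched += 1  # the best prediction counts as a true positive
        line_accs.append(max_acc)
    fp = len(pred) - matched
    if len(gt) > 4 and fn > 0:
        fn -= 1           # at most 4 ground-truth lanes are scored
    total = sum(line_accs)
    if len(gt) > 4:
        total -= min(line_accs)
    denom = max(min(4.0, len(gt)), 1.0)
    return total / denom, (fp / len(pred) if pred else 0.0), fn / denom


# Hypothetical image: the first lane is predicted well, the second is missed.
gt = [[-2, 500, 480, 460], [-2, 700, 720, 740]]
pred = [[-2, 502, 481, 458], [-2, 100, 100, 100]]
acc, fp_rate, fn_rate = bench_sketch(pred, gt, [20, 20])
```

In this example the first ground-truth lane is matched with accuracy 1.0, while the second reaches only 0.25 and counts as a false negative; one of the two predictions is left unmatched and counts as a false positive.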
@staticmethod
def bench_one_submit(pred_file, gt_file): # This function runs the TuSimple evaluation. I have used it before but have forgotten the details, so I am skipping it for now and will study it properly once it is needed.
    # pred_file:'/tmp/tusimple_predictions_default_2695.json'
    # gt_file:'/home/wqf/PolyLaneNet/PolyLaneNet-master/label_data_0531.json'
    try:
        json_pred = [json.loads(line) for line in open(pred_file).readlines()]
    except BaseException as e:
        raise Exception('Fail to load json file of the prediction.')
    json_gt = [json.loads(line) for line in open(gt_file).readlines()]
    if len(json_gt) != len(json_pred):
        raise Exception('We do not get the predictions of all the test tasks')
    gts = {l['raw_file']: l for l in json_gt}
    accuracy, fp, fn = 0., 0., 0.
    run_times = []
    for pred in json_pred:
        if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
            raise Exception('raw_file or lanes or run_time not in some predictions.')
        raw_file = pred['raw_file']
        pred_lanes = pred['lanes']
        run_time = pred['run_time']
        run_times.append(run_time)
        if raw_file not in gts:
            raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
        gt = gts[raw_file]
        gt_lanes = gt['lanes']
        y_samples = gt['h_samples']
        try:
            a, p, n = LaneEval.bench(pred_lanes, gt_lanes, y_samples, run_time)
        except BaseException as e:
            raise Exception('Format of lanes error.')
        accuracy += a
        fp += p
        fn += n
    num = len(gts)
    # the first return parameter is the default ranking parameter
    return json.dumps([{
        'name': 'Accuracy',
        'value': accuracy / num,
        'order': 'desc'
    }, {
        'name': 'FP',
        'value': fp / num,
        'order': 'asc'
    }, {
        'name': 'FN',
        'value': fn / num,
        'order': 'asc'
    }, {
        'name': 'FPS',
        'value': 1000. / np.mean(run_times)
    }])
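The final aggregation averages the per-image metrics and turns the mean run time, given in milliseconds, into frames per second. A minimal sketch with made-up numbers; `summarize` is a hypothetical helper that mirrors only the arithmetic of the `json.dumps` call above.

```python
from statistics import mean


def summarize(per_image, run_times_ms):
    """per_image holds (accuracy, fp, fn) tuples, one per evaluated image."""
    num = len(per_image)
    return {
        'Accuracy': sum(a for a, _, _ in per_image) / num,
        'FP': sum(p for _, p, _ in per_image) / num,
        'FN': sum(n for _, _, n in per_image) / num,
        # Run times are in milliseconds, hence 1000 / mean.
        'FPS': 1000.0 / mean(run_times_ms),
    }


# Hypothetical per-image results and run times.
summary = summarize([(1.0, 0.0, 0.0), (0.5, 0.5, 0.5)], [10.0, 10.0])
```

With a mean run time of 10 ms per image the reported speed is 100 FPS.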
if __name__ == '__main__':
    import sys
    try:
        if len(sys.argv) != 3:
            raise Exception('Invalid input arguments')
        print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
    except Exception as e:
        # sys.exit(e.message)