车道线检测实现Ultra-Fast-Lane-Detection

车道线检测实现Ultra-Fast-Lane-Detection

原始仓库使用 ResNet 系列作为骨干网络，网络输入分辨率为 288×800，在板卡上只能跑到 7–8 FPS，速度偏低。因此我在原始仓库的基础上做了以下修改：

  • 1.添加配置文件,去除原仓库复杂参数配置
  • 2.添加骨干网络mobilenetv2
  • 3.更换了网络训练分辨率
  • 4.在训练脚本中加入验证流程，选取验证效果最好的模型（best model）
  • 5.添加onnx转换代码
    修改后仓库代码:https://github.com/ycdhqzhiai/Ultra-Fast-Lane-Detection

1.配置文件

# Training / evaluation configuration. All parameters live in this one file;
# the scripts load it with yaml and index it as cfg['dataset'], cfg['network'], ...
dataset:
  name: CULane
  # Machine-specific dataset location; point this at your own CULane copy.
  data_root: '/opt/sda5/BL01_Data/Lane_Data/CULane'
  num_lanes: 4
  # Network input resolution. input_size is [height, width] and should agree
  # with h / w.
  w: 512
  h: 256
  input_size: [256, 512]
  batch_size: 128
  # Passed as griding_num to generate_lines() during evaluation; presumably the
  # number of horizontal cells each row anchor is classified into.
  griding_num: 200
  use_aux: False
  # Row anchors: y coordinates in pixels of the resized input, ascending
  # (presumably the rows at which lane positions are predicted). They depend on
  # h, so regenerate them with the anchor script whenever the resolution
  # changes. The commented-out list is the one for h = 288.
  # row_anchor: [121, 131, 141, 150, 160, 170, 180, 189, 199, 209, 219, 228, 238, 248, 258, 267, 277, 287] # 288
  row_anchor: [108, 116, 125, 134, 142, 151, 160, 168, 177, 186, 194, 203, 212, 220, 229, 238, 246, 255] # h = 256
  # Should equal len(row_anchor) (18 entries here).
  num_per_lane: 18
  num_workers: 8

train:
  epoch: 100
  optimizer: 'SGD'  # one of ['SGD','Adam']
  learning_rate: 0.1
  weight_decay: 1.0e-4
  momentum: 0.9

  scheduler: 'multi' # one of ['multi', 'cos']
  # For the 'multi' scheduler: presumably the epochs at which the learning rate
  # is multiplied by gamma -- confirm against the scheduler implementation.
  steps: [25,38]
  gamma: 0.1
  warmup: 'linear'
  warmup_iters: 2
  

network:
  # Backbone name handed to build_backbone(): 'mobilenetv2' or a resnet variant.
  backbone: 'mobilenetv2'
  pretrained: NULL
  # Channel counts of the backbone feature maps, as used with the mobilenetv2
  # backbone above. The trailing list presumably holds the values for a
  # different backbone -- confirm before switching.
  out_channel: [32, 96, 320] #[128,256,1024]

# Auxiliary loss weights; presumably the similarity and shape losses of
# Ultra-Fast-Lane-Detection (0.0 turns them off) -- confirm in the loss code.
sim_loss_w: 0.0
shp_loss_w: 0.0

test:
  # Checkpoint to evaluate (presumably read by the test script).
  test_model: 'weights/20210322_094501_lr_0.1/ep016.pth'
  # Presumably the work_dir under which evaluation output (culane_eval_tmp/)
  # is written -- confirm.
  test_work_dir: 'weights'
  # Presumably: run validation every N epochs during training.
  val_intervals: 1
# Experiment bookkeeping
note: ''
log_path: 'runs'
view: True
# Checkpoint paths for fine-tuning or resuming; NULL (None) presumably means unused.
finetune: NULL
resume: NULL

所有参数统一放在配置文件中（原始仓库中的参数分散在两个 py 文件里）。
2.骨干网络和分辨率

import copy
from .mobilenetv2 import MobileNetV2
from .resnet import resnet

def build_backbone(name):
    """Instantiate the backbone network selected by *name*.

    Args:
        name: backbone identifier from the config (``network.backbone``).
            ``'mobilenetv2'`` builds ``MobileNetV2()``; a name of the form
            ``'resnet_<depth>'`` (e.g. ``'resnet_18'``) builds
            ``resnet('<depth>')``.

    Returns:
        The constructed backbone module.

    Raises:
        NotImplementedError: if *name* is not a supported backbone name
            (including a bare ``'resnet'`` without a depth suffix).
    """
    if name.startswith('resnet_'):
        # The depth follows the underscore: 'resnet_18' -> '18'. Matching only
        # the exact string 'resnet' could never work, because 'resnet' has no
        # '_<depth>' part and indexing [1] raised IndexError.
        layer = name.split('_')[1]
        return resnet(layer)
    if name == 'mobilenetv2':
        return MobileNetV2()
    raise NotImplementedError('unsupported backbone: %r' % name)
import math
lane_num = 18  # number of row anchors per lane (dataset.num_per_lane), not the number of lanes


def gen_row_anchors(img_h=256, num_anchors=18, src_h=590, step=20):
    """Return row anchors rescaled from a *src_h*-tall image to *img_h*.

    Anchors are taken every *step* pixels upward, starting from the bottom row
    (``src_h - 1``) of the source image (590 by default, presumably the CULane
    frame height). Each anchor is then scaled by ``img_h / src_h`` with floor
    rounding.

    Integer floor division keeps the scaling exact, so the result cannot drift
    due to float rounding at other resolutions.

    Returns:
        list[int]: the anchors sorted ascending, i.e. in the order expected by
        ``dataset.row_anchor`` in the config.
    """
    rows = ((src_h - (i - 1) * step) - 1 for i in range(1, num_anchors + 1))
    return sorted(row * img_h // src_h for row in rows)


# Print the anchors for a 256-pixel-high input, bottom row first (one per line).
for anchor in reversed(gen_row_anchors(256, lane_num)):
    print(anchor)

通过配置文件中的 backbone 字段即可选择不同的骨干网络，通过 w、h、input_size 字段设置不同的分辨率。注意：改变分辨率时，需要先用上面的脚本为新分辨率生成 anchor，再用它替换配置文件中的 row_anchor（脚本从最底部的行开始、按从大到小的顺序输出，填入配置时需按从小到大排列）。
3.在训练脚本中添加test功能,选取最佳模型

def _f_measure(tp, fp, fn):
    """Return the F1 score for aggregated TP/FP/FN counts, or 0.0 when undefined.

    Precision is taken as 0 when there are no predictions (tp + fp == 0), and
    recall as 0 when there is no ground truth (tp + fn == 0). An empty or
    entirely wrong evaluation therefore scores 0 instead of raising
    ZeroDivisionError.
    """
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def test(net, data_loader, dataset, work_dir, logger, use_aux=True):
    """Run CULane evaluation and return the overall F-measure.

    Predictions are written as lane files under ``<work_dir>/culane_eval_tmp``
    via ``generate_lines``, then scored with ``call_culane_eval``. Per-category
    results are logged, and TP/FP/FN are summed across categories to compute
    the overall F-measure.

    Args:
        net: the lane network; called on CUDA image batches under no_grad.
        data_loader: yields ``(imgs, names)`` batches.
        dataset: dataset config dict; reads 'name', 'griding_num' and 'data_root'.
        work_dir: directory that receives the ``culane_eval_tmp`` output.
        logger: object exposing ``log(str)``.
        use_aux: if True and the network returns a 2-tuple, only its first
            element (the lane output) is used.

    Returns:
        float: overall F-measure for CULane; None for any other dataset name,
        since only CULane evaluation is implemented here.
    """
    output_path = os.path.join(work_dir, 'culane_eval_tmp')
    # makedirs also creates a missing work_dir and tolerates an existing dir.
    os.makedirs(output_path, exist_ok=True)
    # NOTE(review): the network is left in eval mode; the caller presumably
    # switches back with net.train() before the next training epoch -- confirm.
    net.eval()
    if dataset['name'] != 'CULane':
        return None

    for imgs, names in dist_tqdm(data_loader):
        imgs = imgs.cuda()
        with torch.no_grad():
            out = net(imgs)
        if len(out) == 2 and use_aux:
            out = out[0]  # drop the auxiliary segmentation output

        generate_lines(out, imgs[0, 0].shape, names, output_path, dataset['griding_num'],
                       localization_type='rel', flip_updown=True)

    res = call_culane_eval(dataset['data_root'], 'culane_eval_tmp', work_dir)
    TP = FP = FN = 0
    for k, v in res.items():
        # The evaluator reports 'nan' for categories it could not score.
        val = float(v['Fmeasure']) if 'nan' not in v['Fmeasure'] else 0
        TP += int(v['tp'])
        FP += int(v['fp'])
        FN += int(v['fn'])
        logger.log('k:{} val{}'.format(k, val))
    F = _f_measure(TP, FP, FN)
    logger.log('F:{}'.format(F))
    return F

4.onnx转换脚本

import torch, os, cv2
from model.model import parsingNet
import torch
import scipy.special, tqdm
import numpy as np
import argparse
import yaml
def get_args():
    """Create the (unparsed) command-line parser for the ONNX export script.

    Options:
        --params      path to the YAML config file.
        --batch_size  batch dimension of the dummy export input.
        --weights     path to the checkpoint to export.
        --img-size    one or more ints, height then width; the main script
                      expands a single value to [size, size].
    """
    arg_specs = (
        ('--params', {'default': 'configs/culane.yaml', 'type': str}),
        ('--batch_size', {'default': 1, 'type': int}),
        ('--weights', {'default': 'model_last.pth', 'type': str}),
        ('--img-size', {'nargs': '+', 'type': int, 'default': [256, 512],
                        'help': 'image size'}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in arg_specs:
        cli.add_argument(flag, **spec)
    return cli

def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from keys.

    Weights saved from a (Distributed)DataParallel wrapper carry a
    ``'module.'`` prefix that the bare network does not expect. Only a
    *leading* prefix is stripped; keys that merely contain ``'module.'``
    elsewhere are left untouched.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


if __name__ == '__main__':
    args = get_args().parse_args()
    # Expand a single --img-size value to [size, size] (height, width).
    args.img_size *= 2 if len(args.img_size) == 1 else 1

    # The config is plain YAML data, so safe_load is sufficient and never
    # constructs arbitrary Python objects.
    with open(args.params) as f:
        cfg = yaml.safe_load(f)

    # Use the GPU when available, otherwise fall back to CPU so the export also
    # runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = parsingNet(network=cfg['network'], datasets=cfg['dataset']).to(device)

    # The checkpoint is expected to keep the weights under its 'model' key.
    state_dict = torch.load(args.weights, map_location='cpu')['model']
    compatible_state_dict = _strip_module_prefix(state_dict)
    incompatible = net.load_state_dict(compatible_state_dict, strict=False)
    # strict=False tolerates mismatched keys silently; report them so a wrong
    # checkpoint/backbone pairing does not go unnoticed.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print('Warning: missing keys %s, unexpected keys %s'
              % (incompatible.missing_keys, incompatible.unexpected_keys))
    net.eval()
    print('val done!!!')

    # Dummy input of shape (batch, 3, height, width) used to trace the graph.
    img = torch.zeros(args.batch_size, 3, *args.img_size, device=device)
    with torch.no_grad():
        net(img)  # dry run: surface shape errors before attempting the export

    # ONNX export
    try:
        import onnx
        from onnxsim import simplify

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        # Derive the output path from the checkpoint stem. Using
        # str.replace('.pth', '.onnx') would return the checkpoint path itself
        # when it lacks a '.pth' suffix, and the export would overwrite it.
        onnx_path = os.path.splitext(args.weights)[0] + '.onnx'
        torch.onnx.export(net, img, onnx_path, verbose=False, opset_version=11, input_names=['images'],
                          output_names=['output'])

        # Simplify, validate, and overwrite the file with the simplified graph.
        onnx_model = onnx.load(onnx_path)
        model_simp, check = simplify(onnx_model)
        if not check:  # explicit check: asserts are stripped under python -O
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, onnx_path)
        print(onnx.helper.printable_graph(model_simp.graph))  # the graph actually saved
        print('ONNX export success, saved as %s' % onnx_path)
    except Exception as e:
        print('ONNX export failure: %s' % e)

你可能感兴趣的:(自动驾驶,自动驾驶,pytorch)