Ultra-Fast-Lane-Detection-v2 {后处理优化}//参考

本文对模型输出的 anchor 特征点(车道线坐标)采用三次多项式拟合进行后处理。在 polyfit_draw 函数中,degree 参数表示拟合多项式的次数。

具体来说,使用 np.polyfit 对数据点进行多项式拟合时需要指定次数,它决定了多项式的复杂度。注意代码中调用的是 np.polyfit(y, x, degree),即以 y 为自变量、x 为因变量:车道线在图像中大致沿竖直方向延伸,每个 y 只对应一个 x。例如:

  • degree = 1:线性拟合,即最简单的直线拟合,拟合形式为 $x = f(y) = ay + b$。

  • degree = 2:二次多项式拟合,拟合形式为 $x = f(y) = ay^2 + by + c$。

  • degree = 3:三次多项式拟合,拟合形式为 $x = f(y) = ay^3 + by^2 + cy + d$。

...以此类推。

次数越高,多项式越灵活,越能贴合数据点,但也越容易过拟合(即曲线去追随检测点中的噪声,在点与点之间或两端出现不合理的弯曲)。另外,degree 次多项式至少需要 degree + 1 个不同的点才能唯一确定,点数不足时拟合结果不可靠。

import torch, os, cv2
import torch, os
import numpy as np
import tqdm
import torchvision.transforms as transforms

from utils.dist_utils import dist_print
from utils.common import merge_config, get_model
from data.dataset import LaneTestDataset

def pred2coords(pred, row_anchor, col_anchor, local_width = 1, original_image_width = 1640, original_image_height = 590):
    """Decode network outputs into per-lane pixel coordinates.

    Only batch element 0 is decoded. Lanes 1 and 2 are read from the
    row-anchor head, and lanes 0 and 3 from the column-anchor head. Row lanes
    come first in the result.

    An anchor is used when the argmax of the matching ``exist_*`` tensor over
    dim 1 is non-zero. Its location is refined as the softmax-weighted mean
    cell index over a window of ``local_width`` cells on each side of the
    argmax cell (clipped to the grid). Half a cell is then added, and the
    value is scaled to the output image size.

    A lane is emitted only when more than half of its row anchors (row head)
    or more than a quarter of its column anchors (column head) are valid.

    Side effect: ``pred['loc_row']`` and ``pred['loc_col']`` are replaced in
    the caller's dict by their ``.cpu()`` copies.

    :param pred: dict with 'loc_row', 'loc_col', 'exist_row' and 'exist_col'
        tensors. loc_* are unpacked as (batch, grid, cls, lane); exist_* are
        presumably (batch, 2, cls, lane) -- TODO confirm.
    :param row_anchor: per-row-anchor y values, multiplied by
        original_image_height (presumably normalized to [0, 1]).
    :param col_anchor: per-column-anchor x values, multiplied by
        original_image_width.
    :param local_width: half-width of the refinement window, in grid cells.
    :param original_image_width: target width in pixels.
    :param original_image_height: target height in pixels.
    :return: list of lanes, each a list of integer (x, y) tuples.
    """
    _, grid_rows, cls_rows, _ = pred['loc_row'].shape
    _, grid_cols, cls_cols, _ = pred['loc_col'].shape

    # Winning grid cell and existence decision per (anchor, lane).
    row_argmax = pred['loc_row'].argmax(1).cpu()
    row_valid = pred['exist_row'].argmax(1).cpu()
    col_argmax = pred['loc_col'].argmax(1).cpu()
    col_valid = pred['exist_col'].argmax(1).cpu()

    pred['loc_row'] = pred['loc_row'].cpu()
    pred['loc_col'] = pred['loc_col'].cpu()

    def refine(logits, argmax, cls_idx, lane_idx, num_grid):
        # Window of cells around the argmax, clipped to [0, num_grid - 1].
        center = argmax[0, cls_idx, lane_idx]
        lo = max(0, center - local_width)
        hi = min(num_grid - 1, center + local_width)
        window = torch.tensor(list(range(lo, hi + 1)))
        weights = logits[0, window, cls_idx, lane_idx].softmax(0)
        # Expected cell index, shifted by half a cell.
        return (weights * window.float()).sum() + 0.5

    lanes = []

    for lane_idx in (1, 2):
        if not row_valid[0, :, lane_idx].sum() > cls_rows / 2:
            continue
        points = []
        for cls_idx in range(row_valid.shape[1]):
            if not row_valid[0, cls_idx, lane_idx]:
                continue
            pos = refine(pred['loc_row'], row_argmax, cls_idx, lane_idx, grid_rows)
            x_px = pos / (grid_rows - 1) * original_image_width
            points.append((int(x_px), int(row_anchor[cls_idx] * original_image_height)))
        lanes.append(points)

    for lane_idx in (0, 3):
        if not col_valid[0, :, lane_idx].sum() > cls_cols / 4:
            continue
        points = []
        for cls_idx in range(col_valid.shape[1]):
            if not col_valid[0, cls_idx, lane_idx]:
                continue
            pos = refine(pred['loc_col'], col_argmax, cls_idx, lane_idx, grid_cols)
            y_px = pos / (grid_cols - 1) * original_image_height
            points.append((int(col_anchor[cls_idx] * original_image_width), int(y_px)))
        lanes.append(points)

    return lanes

def _fit_lane_curve(coords, degree=3, num_samples=100):
    """Fit x = f(y) to lane points and sample the curve over their y-range.

    The degree is lowered to ``(number of distinct y values) - 1`` when there
    are too few points for the requested degree. Otherwise ``np.polyfit``
    would be rank-deficient: it emits a RankWarning and returns a poorly
    conditioned fit.

    :param coords: sequence of (x, y) points.
    :param degree: requested polynomial degree. Negative values are rejected
        by ``np.polyfit`` with ValueError.
    :param num_samples: number of evenly spaced y positions between min(y)
        and max(y).
    :return: float array of shape (num_samples, 2) with (x, y) samples, or
        None when ``coords`` is empty.
    """
    points = np.asarray(coords, dtype=np.float64).reshape(-1, 2)
    if points.shape[0] == 0:
        return None
    x, y = points[:, 0], points[:, 1]
    # A degree-d fit needs at least d + 1 distinct abscissae (here: y values).
    effective_degree = min(degree, np.unique(y).size - 1)
    coefficients = np.polyfit(y, x, effective_degree)
    ys = np.linspace(y.min(), y.max(), num_samples)
    xs = np.polyval(coefficients, ys)
    return np.column_stack((xs, ys))


def polyfit_draw(img, coords, degree=3, color=(144, 238, 144), thickness=2):
    """Fit a polynomial to one lane's points and draw the curve on the image.

    The fit is x = f(y): x is modelled as a polynomial in y. The curve is
    drawn only over the y-range covered by the points.

    :param img: image to draw on (modified in place via cv2).
    :param coords: list of (x, y) points for a single lane.
    :param degree: polynomial degree; reduced automatically when the lane has
        too few distinct y values to support it.
    :param color: curve color, passed to cv2.
    :param thickness: curve thickness in pixels.
    :return: the image with the curve drawn (the same object as ``img``).
    """
    curve = _fit_lane_curve(coords, degree)
    if curve is None:
        return img
    # cv2.polylines expects int32 points shaped (N, 1, 2); astype truncates
    # toward zero, matching int() on each coordinate.
    pts = curve.astype(np.int32).reshape(-1, 1, 2)
    cv2.polylines(img, [pts], False, color, thickness)
    return img

# Demo entry point: run the model over a test split, draw the polyfit lane
# curves on each image and write the frames to an MJPG .avi video.
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True

    args, cfg = merge_config()
    cfg.batch_size = 1
    print('setting batch_size to 1 for demo generation')

    dist_print('start testing...')
    supported_backbones = ['18','34','50','101','152','50next','101next','50wide','101wide']
    # Explicit check instead of assert so it is not stripped under `python -O`.
    if cfg.backbone not in supported_backbones:
        raise ValueError('unsupported backbone: %s' % cfg.backbone)

    # Fail fast on unsupported datasets before loading the model.
    if cfg.dataset not in ('CULane', 'Tusimple'):
        raise NotImplementedError(cfg.dataset)

    net = get_model(cfg)

    state_dict = torch.load(cfg.test_model, map_location='cpu')['model']
    # Strip the DataParallel/DDP 'module.' prefix; only a true prefix is removed.
    prefix = 'module.'
    compatible_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    # NOTE(review): strict=False silently ignores missing/unexpected keys.
    net.load_state_dict(compatible_state_dict, strict=False)
    net.eval()

    img_transforms = transforms.Compose([
        transforms.Resize((int(cfg.train_height / cfg.crop_ratio), cfg.train_width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    if cfg.dataset == 'CULane':
        splits = ['test0_normal.txt']
        datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, 'list/test_split/'+split),img_transform = img_transforms, crop_size = cfg.train_height) for split in splits]
        img_w, img_h = 1570, 660
    else:  # 'Tusimple' (validated above)
        splits = ['test.txt']
        datasets = [LaneTestDataset(cfg.data_root,os.path.join(cfg.data_root, split),img_transform = img_transforms, crop_size = cfg.train_height) for split in splits]
        img_w, img_h = 1280, 720

    for split, dataset in zip(splits, datasets):
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle = False, num_workers=1)
        fourcc = cv2.VideoWriter_fourcc(*'MJPG')
        # One video per split, named after the split list file (e.g. test.avi),
        # so the printed name matches the file actually written.
        video_path = split[:-3] + 'avi'
        print(video_path)
        vout = cv2.VideoWriter(video_path, fourcc, 30.0, (img_w, img_h))
        try:
            for imgs, names in tqdm.tqdm(loader):
                imgs = imgs.cuda()
                with torch.no_grad():
                    pred = net(imgs)

                vis = cv2.imread(os.path.join(cfg.data_root, names[0]))
                if vis is None:
                    print('skipping unreadable image: %s' % names[0])
                    continue
                # VideoWriter drops frames whose size differs from the declared
                # size, and coords are scaled to (img_w, img_h); resize first so
                # both agree.
                if (vis.shape[1], vis.shape[0]) != (img_w, img_h):
                    vis = cv2.resize(vis, (img_w, img_h))

                coords = pred2coords(pred, cfg.row_anchor, cfg.col_anchor, original_image_width = img_w, original_image_height = img_h)
                for lane in coords:
                    vis = polyfit_draw(vis, lane)  # fit and draw each lane
                vout.write(vis)
        finally:
            vout.release()

    

 ps:

优化前

Ultra-Fast-Lane-Detection-v2 {后处理优化}//参考_第1张图片

优化后

Ultra-Fast-Lane-Detection-v2 {后处理优化}//参考_第2张图片

显存利用情况

Ultra-Fast-Lane-Detection-v2 {后处理优化}//参考_第3张图片 

你可能感兴趣的:(算法,机器学习,人工智能)