道路线检测 lanenet_pytorch

最近在学习道路线检测,在网上找了很多资源,在github上找到的开源的LaneNet项目数目较少,只有基于tensoflow 1.x,但是在配置环境的时候过于麻烦,同时由于tensorflow 2.x的缘故,在源代码上修改的时候也挺麻烦的,动不动就报错,最后也没跑起来,且相关作者也已不再维护。最后发现了一个基于pytorch的LaneNet的源码,并在其基础上进行了修改。

https://github.com/IrohXu/lanenet-lane-detection-pytorch

从github上获取到源代码后,通过python test.py --img ./data/tusimple_test_image/0.jpg 进行测试。

可以读取文件夹下的图片进行预测。

import argparse
import time
import os
import sys

import torch
from dataloader.transformers import Rescale
from model.lanenet.LaneNet import LaneNet
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms
import numpy as np
from PIL import Image
import pandas as pd
import cv2

# GPU or CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_test_data(img_path, transform):
    img = Image.open(img_path)
    img = transform(img)
    return img


def test():
    # 创建test_output文件夹
    if os.path.exists('test_output') == False:
        os.mkdir('test_output')
    args = parse_args()
    # input图片地址
    img_path = args.img

    # # resize后的图片大小
    # resize_height = args.height
    # resize_width = args.width

    # 图像处理
    data_transform = transforms.Compose([
        # transforms.Resize((resize_height, resize_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # 模型参数
    model_path = args.model
    # 模型结构
    model = LaneNet(arch=args.model_type)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)
    print('模型加载成功,开始对图片进行预测')
    # 不计算参数梯度,防止出现因图片size问题出现CUDA out of memory
    with torch.no_grad():
        imgs = os.listdir(img_path)
        imgs.sort(key=lambda x: int(x.split('.')[0]))
        for img in imgs:
            dummy_input = load_test_data(img_path + '/' + img, data_transform).to(DEVICE)
            dummy_input = torch.unsqueeze(dummy_input, dim=0)
            outputs = model(dummy_input)

            # input = Image.open(img_path)
            # input = input.resize((resize_width, resize_height))
            # input = np.array(input)

            instance_pred = torch.squeeze(outputs['instance_seg_logits'].detach().to('cpu')).numpy() * 255
            binary_pred = torch.squeeze(outputs['binary_seg_pred']).to('cpu').numpy() * 255
            # # 保存输入图片
            # cv2.imwrite(os.path.join('test_output', 'input.jpg'), input)

            # cv2.imwrite(os.path.join('test_output', img.split('.')[0] + '_instance.jpg'),
            #             instance_pred.transpose((1, 2, 0)))
            # 保存二值图
            cv2.imwrite(os.path.join('test_output', img.split('.')[0] + '_binary.jpg'), binary_pred)
            print(img.split('.'))

    print('over')


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img", default="./video2imgs", help="Img path")
    parser.add_argument("--model_type", help="Model type", default='ENet')
    parser.add_argument("--model", help="Model path", default='./log/best_model.pth')
    parser.add_argument("--width", required=False, type=int, help="Resize width", default=512)
    parser.add_argument("--height", required=False, type=int, help="Resize height", default=256)
    parser.add_argument("--save", help="Directory to save output", default="./test_output")
    return parser.parse_args()


# def img2video(input_root, output_root):
#     img_root = input_root  # 读取图片目录
#     fps = 30  # 保存视频的FPS,可以适当调整
# 
#     # 编码器 可以用(*'DVIX')或(*'X264'),如果都不行先装ffmepg: sudo apt-get install ffmepg
#     fourcc = cv2.VideoWriter_fourcc(*'XVID')
# 
#     videoWriter = cv2.VideoWriter(output_root + '/predict.mp4', fourcc, fps,
#                                   (1280, 720))  # 视频写入;编码器;fps;图片的尺寸,根据自己的图片决定
#     # 遍历文件夹下所有图片,listdir为随机排序
#     imgnames = os.listdir(img_root)
#     # 将图片顺序排序
#     imgnames.sort(key=lambda x: int(x[:-4]))
#     for imgname in imgnames:
#         print(imgname)
#         # 读取图片
#         frame = cv2.imread(img_root + '/' + imgname)
#         videoWriter.write(frame)
#     videoWriter.release()
#     print("已经转为视频")


if __name__ == "__main__":
    test()
    # img2video('D:/lanenet-lane-detection-pytorch-main/test_output', 'D:/lanenet-lane-detection-pytorch-main/imgs2video')

默认input路径下全为图片,没有对路径下的文件进行判断是否为图片,有待完善 

你可能感兴趣的:(深度学习,人工智能)