RESA Lane Detection: Reading the Code

Taking model training on the TuSimple dataset as the example, we start with dataset preprocessing.
Whenever an image is read, only the lower portion of both the image and its label is kept (everything above `cut_height` is cropped away), and data augmentation is then applied: random rotation, random horizontal flip, resizing to the configured size, and normalization.

    def __getitem__(self, idx):
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        img = img[self.cfg.cut_height:, :, :]     # crop off the top cut_height rows; only the lower part is used

        if self.is_training:
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()                # label shape: (720, 1280), one class id per pixel
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()   # labels must be long tensors for the loss
        else:
            img, = self.transform((img,))

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()   # image tensor is float, CHW layout
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),                                     # random rotation
            tf.GroupRandomHorizontalFlip(),                               # random horizontal flip
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),   # resize to the configured size
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )),    # normalization
                              std=(self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform
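
For intuition, here is a self-contained sketch of the same preprocessing applied to one frame using plain OpenCV/NumPy, with values assumed from the TuSimple config (cut_height=160, 640x368 target size, ImageNet BGR mean, std of 1). It is an illustration, not the repo's code:

import cv2
import numpy as np

# dummy 720x1280 BGR frame standing in for a TuSimple image
img = np.random.randint(0, 256, (720, 1280, 3), np.uint8).astype(np.float32)
img = img[160:, :, :]                                      # cut_height: drop the sky region
img = cv2.resize(img, (640, 368))                          # cv2.resize takes (W, H)
img -= np.array([103.939, 116.779, 123.68], np.float32)    # subtract per-channel BGR mean
chw = img.transpose(2, 0, 1)                               # HWC -> CHW, as __getitem__ does
print(chw.shape)                                           # (3, 368, 640)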

Network model code

The core RESA module takes the backbone feature map and repeatedly shifts it in four directions (down, up, right, left); at each step the shifted features pass through a thin convolution and a ReLU, are scaled by alpha, and added back in place, so every pixel gradually aggregates information from its entire row and column.

import torch.nn as nn
import torch
import torch.nn.functional as F

from models.registry import NET
from .resnet import ResNetWrapper
from .decoder import BUSD, PlainDecoder


class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride

        for i in range(self.iter):
            # 1 x k convolutions used by the vertical (down/up) passes
            conv_vert1 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)
            conv_vert2 = nn.Conv2d(
                chan, chan, (1, conv_stride),
                padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            setattr(self, 'conv_u'+str(i), conv_vert2)

            # k x 1 convolutions used by the horizontal (right/left) passes
            conv_hori1 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)
            conv_hori2 = nn.Conv2d(
                chan, chan, (conv_stride, 1),
                padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            # circularly shifted row/column indices; the shift is
            # height // 2**(iter - i) (resp. width), growing each iteration
            idx_d = (torch.arange(self.height) + self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)

            idx_u = (torch.arange(self.height) - self.height //
                     2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)

            idx_r = (torch.arange(self.width) + self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)

            idx_l = (torch.arange(self.width) - self.width //
                     2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)

    def forward(self, x):
        x = x.clone()   # avoid modifying the caller's feature map in place

        # vertical passes: every row is updated with convolved features from a
        # circularly shifted row (down 'd', then up 'u')
        for direction in ['d', 'u']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        # horizontal passes: same idea along the width (right 'r', then left 'l')
        for direction in ['r', 'l']:
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x
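
The `idx_*` buffers implement RESA's key trick: at iteration i the feature map is gathered with a circular shift of height // 2**(iter - i) rows (or columns), so each row is updated with features from a row that distance away, and information can cross the whole map in `iter` steps. A quick sketch with assumed values height=8, iter=3:

import torch

height, n_iter = 8, 3
for i in range(n_iter):
    stride = height // 2 ** (n_iter - i)              # 1, 2, 4: doubles every iteration here
    idx_d = (torch.arange(height) + stride) % height  # circular shift of the row indices
    print(i, stride, idx_d.tolist())
# i=0 -> [1, 2, 3, 4, 5, 6, 7, 0]: each row is paired with the row one step away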



class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)  # channel dropout before the 1x1 classification conv
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2  # the avg_pool2d in forward() halves the map again
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        self.fc10 = nn.Linear(128, cfg.num_classes-1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)   # halve spatial dims to shrink the FC input
        x = x.view(-1, x.numel() // x.shape[0])       # flatten per sample
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x
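
ExistHead turns the shared feature map into per-lane existence probabilities. A hedged shape walk-through, assuming the TuSimple config values (640x368 input, fea_stride=8, num_classes=7, i.e. 6 lanes plus background):

import torch
import torch.nn as nn
import torch.nn.functional as F

N = 2
x = torch.randn(N, 128, 368 // 8, 640 // 8)     # feature after backbone + RESA: (N, 128, 46, 80)
x = nn.Conv2d(128, 7, 1)(x)                     # conv8 -> per-class score maps (N, 7, 46, 80)
x = F.softmax(x, dim=1)
x = F.avg_pool2d(x, 2, stride=2, padding=0)     # (N, 7, 23, 40); hence stride = fea_stride * 2
x = x.view(N, -1)
print(x.shape)                                  # (N, 6440) = 7 * 23 * 40, the fc9 input size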


@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        self.backbone = ResNetWrapper(cfg)
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)  # 'BUSD' or 'PlainDecoder', selected by the config
        self.heads = ExistHead(cfg)

    def forward(self, batch):
        fea = self.backbone(batch)
        fea = self.resa(fea)
        seg = self.decoder(fea)
        exist = self.heads(fea)

        output = {'seg': seg, 'exist': exist}

        return output
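
For reference, a hedged shape trace through the whole network for one TuSimple-sized batch (values assumed from the config: 640x368 input, fea_stride=8, 128 RESA channels, num_classes=7):

# input:    (N, 3, 368, 640)
# backbone: (N, 128, 46, 80)    ResNet features reduced to resa.input_channel
# resa:     (N, 128, 46, 80)    same shape, spatially aggregated
# decoder:  (N, 7, 368, 640)    per-pixel class logits ('seg')
# heads:    (N, 6)              per-lane existence scores ('exist')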

Loss functions

Either cross-entropy or dice loss can be used for the segmentation head. The dice loss below computes, per sample, L = 1 - 2·Σ(p·t) / (Σp² + Σt² + 2ε) with ε = 0.001, then averages over the batch.

import torch.nn as nn
import torch
import torch.nn.functional as F

from runner.registry import TRAINER

def dice_loss(input, target):
    # flatten both tensors to (N, -1)
    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1).float()

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001   # eps keeps the denominator non-zero when |X| and |Y| are both 0, and mildly smooths the loss
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    return (1 - d).mean()
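
A quick numeric check of the dice_loss defined above (a sketch): a perfect prediction should give a loss near 0, since 2·Σ(p·t) / (Σp² + Σt² + 2ε) approaches 1:

p = torch.tensor([[0., 1., 1., 0.]])
t = torch.tensor([[0., 1., 1., 0.]])
print(dice_loss(p, t))   # ~0.0005: loss = 1 - 4/4.002, pulled off zero only by the eps terms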

@TRAINER.register_module
class RESA(nn.Module):
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type == 'cross_entropy':
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight  # down-weight the dominant background class
            weights = weights.cuda()
            # NLLLoss expects log-probabilities; combined with F.log_softmax in
            # forward(), this is a weighted cross-entropy.
            self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label,
                                              weight=weights).cuda()

        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        output = net(batch['img'])

        loss_stats = {}
        loss = 0.

        if self.loss_type == 'dice_loss':
            target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(F.softmax(
                output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:
            seg_loss = self.criterion(F.log_softmax(
                output['seg'], dim=1), batch['label'].long())

        loss += seg_loss * self.cfg.seg_loss_weight

        loss_stats.update({'seg_loss': seg_loss})

        if 'exist' in output:
            exist_loss = 0.1 * \
                self.criterion_exist(output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})

        ret = {'loss': loss, 'loss_stats': loss_stats}

        return ret
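
As a sanity check on the cross-entropy branch (a standalone sketch, independent of the repo): NLLLoss applied to log_softmax output is exactly F.cross_entropy on raw logits:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 7, 4, 4)             # (N, num_classes, H, W) segmentation logits
target = torch.randint(0, 7, (2, 4, 4))      # per-pixel class ids
a = F.nll_loss(F.log_softmax(logits, dim=1), target)
b = F.cross_entropy(logits, target)
print(torch.allclose(a, b))                  # True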

To start training, run:

python main.py configs/tusimple.py --gpus 0

Testing the trained model on your own video

import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader
import utils.transforms as tf   # used by transform_val() below

color_list = [
    (255, 0, 0),
    (255, 225, 0),
    (255, 0, 255),
    (125, 125, 125),
    (255, 125, 125),
    (0, 125, 0),
]
def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view

    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type

    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    runner.net.eval()
    val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
    def to_cuda(batch):
        for k in batch:
            if k == 'meta':
                continue
            batch[k] = batch[k].cuda()
        return batch
    def is_short(lane):
        # a lane with no positive x coordinate is treated as empty
        return not any(x > 0 for x in lane)
    def probmap2lane(seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
        ----------
        seg_pred:      np.array of size (num_classes, h, w)
        exist:         list of lane-existence flags, e.g. [0, 1, 1, 0]
                       (kept for interface parity; not used to filter lanes here)
        resize_shape:  target shape (H, W) for the output coordinates
        smooth:        whether to blur the probability map first
        y_px_gap:      y pixel gap between sampled rows
        pts:           how many points per lane
        thresh:        probability threshold

        Return:
        ----------
        coordinates: [x, y] list per lane, e.g. [[[9, 569], [50, 549]], [[630, 569], [647, 549]]]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # (h, w) of the probability maps
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []
        for i in range(cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]  # seg_pred[0] is the background channel
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)

            coords = get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        return coordinates
    def fix_gap(coordinate):
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate
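    # For example (assumed toy input): fix_gap([-1, 100, -1, -1, 160, -1])
    # returns [-1, 100, 120, 140, 160, -1] -- interior -1 gaps between the first
    # and last positive x are filled by linear interpolation, while leading and
    # trailing -1s are left untouched.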
    def get_lane(prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
        ----------
        prob_map:      prob map for a single lane, np array of size (h, w)
        resize_shape:  target shape (H, W) for the output coordinates

        Return:
        ----------
        coords: x coordinates sampled bottom-up every y_px_gap px, in the
                resized shape; -1 (or all zeros) marks rows with no detection
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= cfg.cut_height

        coords = np.full(pts, -1.0)
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            idx = np.argmax(line)
            if line[idx] > thresh:
                coords[i] = int(idx / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        fix_gap(coords)
        return coords
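    # get_lane samples rows bottom-up every y_px_gap pixels (y = H-10, H-20, ...
    # in resized coordinates), takes the argmax over x in each row, and keeps it
    # only when its probability exceeds `thresh`; with fewer than 2 hits the
    # whole lane is discarded.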
    def view(img, coords, file_path=None):
        for i, coord in enumerate(coords):
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                cv2.circle(img, (int(x), int(y)), 4, color_list[i % len(color_list)], 2)

        if file_path is not None:
            if not os.path.exists(osp.dirname(file_path)):
                os.makedirs(osp.dirname(file_path))
            cv2.imwrite(file_path, img)
    time_start = time.perf_counter()   # time.clock() was removed in Python 3.8
    fps = 0.0
    capture = cv2.VideoCapture("/media/gooddz/新加卷/检测视频/极弯场景.mp4")

    def transform_val():
        # same resize + normalization as validation, without augmentation
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((640, 368)),
            tf.GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )), std=(
                [1., 1., 1.], (1, ))),
        ])
        return val_transform
    while True:
        t1 = time.time()
        ret, frame = capture.read()
        if not ret:   # end of the video stream
            break
        frame = cv2.resize(frame, (1280, 720))
        frame_copy = frame.copy()
        frame = frame[cfg.cut_height:, :, :].astype(np.float32)  # same crop as training (160 px for TuSimple)

        transform = transform_val()
        frame = transform((frame,))
        frame = torch.from_numpy(frame[0]).permute(2, 0, 1).contiguous().float()
        frame = frame.unsqueeze(0).cuda()
        with torch.no_grad():
            output = runner.net(frame)
            seg_pred, exist_pred = output['seg'], output['exist']
            seg_pred = F.softmax(seg_pred, dim=1)
            seg_pred = seg_pred.detach().cpu().numpy()
            exist_pred = exist_pred.detach().cpu().numpy()

            for b in range(len(seg_pred)):   # batch size is 1 here
                seg = seg_pred[b]
                exist_1 = [1 if exist_pred[b, i] > 0.5 else 0
                           for i in range(cfg.num_classes - 1)]
                lane_coords = probmap2lane(seg, exist_1, thresh=0.6)
                for i in range(len(lane_coords)):
                    lane_coords[i] = sorted(
                        lane_coords[i], key=lambda pair: pair[1])
                view(frame_copy, lane_coords)

        fps = (fps + (1. / (time.time() - t1))) / 2   # running average of per-frame FPS
        cv2.imshow('imshow', frame_copy)
        cv2.waitKey(1)
        print("fps:", fps)
    cv2.destroyAllWindows()
    time_end = time.perf_counter()
    print("total time:", time_end - time_start)
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from', default='/home/llgj/桌面/ldz/resa-main_原/work_dirs/TuSimple/20220120_083126_lr_2e-02_b_4/ckpt/best.pth')
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    main()


Create a new file test.py, paste the code above into it, and run the following command to test:

python test.py configs/tusimple.py --validate --load_from /media/gooddz/学习/tusimple.pth --gpus 0 --view
