Ultra-Fast-Lane-Detection 复现

源码地址:https://github.com/cfzd/Ultra-Fast-Lane-Detection
论文地址:https://arxiv.org/abs/2004.11757

1 数据集准备

下载地址:https://github.com/TuSimple/tusimple-benchmark/issues/3
Ultra-Fast-Lane-Detection 复现_第1张图片
下载的train_set.zip、test_set.zip和test_label.json,并解压
下载完成后,新建一个文件夹Tusimple,将上面多个放在该文件夹里面
Ultra-Fast-Lane-Detection 复现_第2张图片

1 源码下载

 git clone https://github.com/cfzd/Ultra-Fast-Lane-Detection
  cd Ultra-Fast-Lane-Detection

2 环境搭建

conda create -n lane python=3.7
conda activate ultra_fast_lane  # 激活环境
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio===0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple  # 若存在个别库安装缓慢,则可以选中单独安装
# 特别注意setuptools版本不能过高
pip install setuptools==59.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

3.代码修改

需要修改三大部分,data_root、log_path和test_work_dir

# DATA
dataset='Tusimple'
data_root = '/home/gooddz/coda/论文对比实验/Tusimple'   # 修改读取数据集的地址

# TRAIN
epoch = 100
batch_size = 8
optimizer = 'Adam'    #['SGD','Adam']
# learning_rate = 0.1
learning_rate = 4e-4
weight_decay = 1e-4
momentum = 0.9

scheduler = 'cos'     #['multi', 'cos']
# steps = [50,75]
gamma  = 0.1
warmup = 'linear'
warmup_iters = 100

# NETWORK
backbone = '18'
griding_num = 100
use_aux = True

# LOSS
sim_loss_w = 1.0
shp_loss_w = 0.0

# EXP
note = ''

log_path = './log'      # 日记地址

# FINETUNE or RESUME MODEL PATH
finetune = None
resume = None

# TEST
test_model = None
test_work_dir = './tmp'  #缓存地址

num_lanes = 4

训练与测试

# 训练
python train.py configs/path_to_your_config   # 例如我使用的是tuismple 就是python train.py configs/tuismple.py
 # 测试
 python test.py configs/tuismple.py --test_model path_to_culane_18.pth --test_work_dir ./tmp
 # 例如我的模型是直接放在代码主目录,给的路径为/tusimple_18.pth即可

参考模型的计算量和参数量

以tusimple为例,在eval_wrapper.py中的134行(即net.eval())下面,加入下面的代码即可

from thop import profile, clever_format
x = torch.zeros((1, 3, 288, 800)).cuda() + 1
macs, params = profile(net, inputs=(x,))
macs, params = clever_format([macs, params], "%.3f")
print('MACs: {}'.format(macs))
print('Params: {}'.format(params))

你可能感兴趣的:(论文,深度学习,计算机视觉,目标检测)