如何使用Ultra-fast-lane-detection训练自己的数据(1)

如何使用Ultra-fast-lane-detection训练自己的数据-1

  • Ultra-Fast-Lane-Detection代码简读
    • 修改网络输入大小
    • 修改backbone

Ultra-Fast-Lane-Detection代码简读

Ultra-Fast-Lane-Detection源码链接
该代码中的配置如下:

  1. configs/culane.py or tusimple.py:
    1.1. 两个文件基本配置一样 ,区分在dataset=‘CULane’ or ‘Tusimple’。这两个参数的不同直接影响的是选取的行坐标的长度和数值(row_anchor,该参数在./data/constant.py中定义。Tusimple数据的分辨率较大,选取的点较多,为56点,CULane数据为18个点); griding_num = 200 代表将800宽分成200份。
    1.2. data_root赋值为数据的path(如./CULane/)
    1.3 其他参数就取决于个人训练的设置了。
  2. data/constant.py:记载了网络训练时筛选出的固定行(该数据的数值是针对网络输入shape确定的,如原网络中的输入是(1,3,288,800), row_anchor的数值就是针对288设定的)
  3. data/dataloader.py:加载图片并完成图片初始化的过程。
  4. model/backbone.py: 定义网络的backbone函数,如原码使用的时resnet系列,具体使用哪几层网络,如果联系都是在这里定义的。
  5. model/model.py:车道线检测网络定义,包含实例分割部分(parsingNet)。
    需要注意的文件基本上就以后几个文件。
    如果只想检测的训练一下作者的源码网络,只需要选定1中的数据类别,及具体的数据地址就行,其他的不需要修改直接运行
    提醒:平台配置一定要按照requirments.txt安装,低版本的环境可能运行不了。

修改网络输入大小

源码800*288的大小对于嵌入式平台而言,还是有点大,如何将shape修改成自己想要的大小呢?步骤如下:

  1. 确定输入shape ,以(1,3,128,384) 为例;
  2. 修改data/constan.py中_row_anchor的数值* ,_row_anchor数值是根据网络输入图片的高度确定的,但我们设置h=128时,如果你不修改_row_anchors的长度(长度也可以修改,看个人决定,但过少会影像检测的性能),只需要让原有的数值都×128/288即可,也可以自行设置,但不得大于127;
  3. config/culane.py中修改griding_num:可以根据需要修改,也可以按照原来的代码等比例缩减,如800分200份,384分96份(384/4),以下假设修改为griding_num=96;
  4. 修改 data/dataloader.py中get_train_loader(),get_test_loader()函数 ,将原有的(288,800)全部改为(128,384);segment_transform中的(36,100)改为(128/8, 384/8)=(16,48),这个数据对应网络哦中aux部分的输出维度;其中cls_num_per_lane代表的就是*_row_anchor的个数,根据自己的修改填写。
  5. 修改 **data/dataset.py中函数def _get_index(self, label)**中所有的288为128(共两处);
  6. 修改model/model.py中的class parsingNet(torch.nn.Module):部分首先修改forward()函数里的fea = self.pool(fea).view(-1, 1800)为
# An highlighted block
fea_tmp = self.pool(fea)  #fea的shape为(?,8, 4, 12)
b,c,h,w = fea_tmp.shape  #fea_tmp shape为(?,8, 4, 12)
fea = fea_tmp.view(-1, c*w*h)

然后修改def__init__()函数中的

self.cls = torch.nn.Sequential(
            torch.nn.Linear(1800, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, self.total_dim),
        )

self.cls = torch.nn.Sequential(
            torch.nn.Linear(384, 2048),  #384 = 8 * 4 * 12
            torch.nn.ReLU(),
            torch.nn.Linear(2048, self.total_dim),
        )

以上,在不改变model仅改变图片shape的情况下,所需要的修改就只有上面的部分。

修改backbone

  1. 在model/backbone.py中定义自己的网络 ,在class resnet()中调用自己定义的接口,并根据该接口设置调用对应的feature层;且需要详细记录forward(self,x)中输出的x2, x3, x4的shape
  2. 修改model/model.py ,根据x2,x3,x4的shape修改如下数据
self.aux_header2 = torch.nn.Sequential(
                conv_bn_relu(***[1]***, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1),
                conv_bn_relu(128,128,3,padding=1),
                conv_bn_relu(128,128,3,padding=1),
                conv_bn_relu(128,128,3,padding=1),
            )
            self.aux_header3 = torch.nn.Sequential(
                conv_bn_relu(***[2],*** 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(1024, 128, kernel_size=3, stride=1, padding=1),
                conv_bn_relu(128,128,3,padding=1),
                conv_bn_relu(128,128,3,padding=1),
            )
            self.aux_header4 = torch.nn.Sequential(
                conv_bn_relu(***[3]***, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(2048, 128, kernel_size=3, stride=1, padding=1),
                conv_bn_relu(128,128,3,padding=1),
            )

修改***[1]***为x2, x3, x4对应的c通道数值即可,
2. 注意修改**data/dataloader.py中get_train_loader()**中的segment_transform部分的shape为aux_seg的(h,w)哟!

以上!

Android端移植下次分享!

你可能感兴趣的:(深度学习,网络)