CenterNet代码解析-ctdet目标检测

CenterNet(Objects as points)开源代码:

https://github.com/xingyizhou/CenterNethttps://github.com/xingyizhou/CenterNet

1、环境安装

在GitHub下载centernet代码,并按照CenterNet中的reademe-Installation部分搭建centernet运行环境

需要注意的是:在原版Installation部分需要安装pytorch0.4(因为用到的DCNV2对应了这个pytorch),根据bilibili up主-Guuuuuu老师儿 讲解的使用Pytorch搭建centernet视频,得知可使用1.4版本的pytorch以及下载github上最新的dcnv2,搭配使用可run centernet。但是我的显卡支持1.10版本以上的,所以1.4的pytorch也不能满足我的需求,为了快速复现centernet,没有在搭配环境上下太多功夫,按照bilibili视频安装了pytorch1.4及替换了DCNV2文件夹(替换以后还需要build等操作具体见视频)。在centernet代码中使用cuda加速部分注释掉使用cpu进行训练,这样和pytorch版本也就没什么关系了(这么说按照原版的安装0.4其实也可以。。。)

ps,视频中讲解运行demo.py 使用命令行运行,为了使用vscode进行调试,其他的配置参数可以加在.lanuch.json文件中

   "args": [
                "ctdet",
                "--demo","../images/19064748793_bb942deea1_k.jpg",
                "--load_model","../models/ctdet_coco_dla_2x.pth"
            ]

2、centernet代码说明

(1)配置文件opts.py

配置使用centernet进行哪种任务,使用哪种数据集,数据集格式使用coco | kitti | coco_hp | pascal。配置主干网络模型使用'res_18 | res_101 | resdcn_18 | resdcn_101 |'dlav0_34 | dla_34 | hourglass等配置。

class opts(object):

def __init__(self):

    self.parser = argparse.ArgumentParser()

    # basic experiment setting

    self.parser.add_argument('task', default='ctdet',

    help='ctdet | ddd | multi_pose | exdet')

    self.parser.add_argument('--dataset', default='coco',

    help='coco | kitti | coco_hp | pascal')
    
    self.parser.add_argument('--exp_id', default='default')

    self.parser.add_argument('--test', action='store_true')

  

(2)制作自己数据集的真值

参考从代码角度分析高效优雅检测模型CenterNet - 作业部落 Cmd Markdown 编辑阅读器

整个真值生成过程代码在src/lib/datasets/sample/ctdet.py,其外部采用的是多继承的方式实现dataset,

def get_dataset(dataset, task):
  class Dataset(dataset_factory[dataset], _sample_factory[task]):
    pass
  return Dataset

a、获取数据集及标签路径等基础信息

其中dataset_factory[dataset]是COCO数据解析类,_sample_factory[task]是目标检测的CTDetDataset,这两个类都是继承至pytorch的Dataset类。

对于python多继承而言,是按照先后顺序继承的,COCO类实现获得数据集路径以及标签json文件的路径及名称,对象类别数量,images的id等。而CTDetDataset才是真正实现了getitem方法,两个类的全部方法和属性合并才得到最终的datalayer层。这样写的好处很明显就是解耦,如果数据格式变了或是说是coco格式,但是内部变量数据值变了,此时就可以仅仅额外提供一个和COCO类一样的py文件即可,而不需要重写CTDetDataset类。但是这样写的缺点也很明显:代码可阅读性降低了很多,而且在CTDetDataset里面强制读取COCO类的属性,是没有代码提示的,因为如果不是多继承的写法,实际运行时候肯定是会报错了,加入了多继承后,子类就可以读取到父类里面的任何一个方法和属性。虽然看起来很优雅,但是这种实现方式不推荐,严重违背迪米特法则

COCO这个类仅仅是为后面的CTDetDataset提供一些数据和属性。

b、获取标签的具体信息如annotations里的信息

Dataloader如何加载json中的annotations

参照【PyTorch深度学习实践】学习笔记 数据集的加载Dataset和DataLoader原理_咯吱咯吱咕嘟咕嘟的博客-CSDN博客 

上面的博客写的很详细了,我这里简单写一下流程

b1、搭建dataloader

DataLoader( dataset = dataset , #dataset 是继承了dataset类之后加载数据集提供路径
            batch_size = 32, #选择batch_size的大小
            shuffle = true,  #增强数据集随机性
            num_workers = 2 ) #多进程读数据

在enumerate(Dataloader)中具体读到json文件中的annotations

for i, data in enumerate(train_loader):

若是没有跳转到dataloader.py文件中需要在lanuch.json文件中添加如下配置信息

"justMyCode": false,

b2、跳转到

def __iter__(self):
	if self.num_workers == 0:
		return _SingleProcessDataLoaderIter(self)
	else:
		return _MultiProcessingDataLoaderIter(self) #进程问题

b3、跳转到

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
	def __init__(self,loader):
		super(_SingleProcessDataLoaderIter,self).__init__(loader)
		assert self.timeout == 0
		assert self.num_workers == 0
		self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset,self.auto_collation, self.collate_fn, self.drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility

b4、跳转到

  def _next_index(self):
        return next(self.sampler_iter)  # may raise StopIteration

b5、跳转到

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]  #调用了dataset,通过一系列的data拼接成一个list;
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

b6、跳转到CTDetDataset类的getitem函数

class CTDetDataset(data.Dataset):
     def __getitem__(self, index):

在这个函数中我们对读取到的json进行制作,制作我们想要的真值。

在我的数据集中,原图为1280,标签也为1280。将每一个json合并到数据集的大json文件时,将图片及标签都转为了512大小(如何转?直接将json文件中的坐标*512/1280)

b6-1、取出anns中的前角点,后角点,chock点,edge点以及是否被占等信息存入joint中

b6-2、将这些点即joint再缩小为128,对这些点画gussian。

b6-3、将得到的gussian,写入图片

你可能感兴趣的:(深度学习)