rock带你读CornerNet-lite系列源码(一)

文章目录

    • 前言
    • 项目架构
    • 训练
    • configs参数解析:
    • core 文件:
    • models 核心解读:
      • 数据流 training:

前言

CornerNet-Lite系列模型分为CornerNet、CornerNet-Saccade、CornerNet-Squeeze三个网络,后两个网络是CornerNet的改进版本。虽然现在Anchor-free系列的FCOS、CenterNet等网络性能优于CornerNet,但学习CornerNet的源码还是很有必要的,况且CornerNet的代码风格很清新。
论文地址:https://arxiv.org/abs/1904.08900
代码:https://github.com/princeton-vl/CornerNet-Lite

项目架构

CornerNet-Lite-master
-----configs   #数据集参数配置文件
     cornernet-saccade.json
-----core  
------------dbs  #数据预处理,COCO,voc格式
------------external  #对检测的box处理 ,nms
------------models  # 模型架构定义,loss,返回模型得到的 Out
------------nnet  #py_factory.py 模型启动器 ,类似于solver
------------sample  # ground truth 数据图像encode,汇总成Target
------------test  #测试
------------utils # 部分组件单元
-----data  #需要手动创建,放voc 和coco数据集
-----demo.py
-----evaluate.py  #测试
-----train.py   

训练

项目训练没有太多需要注意的地方,参考README即可。注意一点:当只有一个GPU时,需要把configs下json文件中的参数设置为 batch_size=chunk_sizes=X(chunk_sizes是列表,例如 [15]),X不要太大,一般选15。

configs参数解析:

以 cornernet-saccade.json为例:

{
    "system": {
        "dataset": "VOC",  #数据集格式
        "batch_size": 15,  #batch
        "sampling_function": "cornernet_saccade",

        "train_split": "trainval",  #划分数据集
        "val_split": "minival",

        "learning_rate": 0.00025,  #学习率
        "decay_rate": 10,  #学习率衰减系数:每经过stepsize次迭代,学习率除以decay_rate
                 # 恢复训练时参考train.py: learning_rate /= (decay_rate ** (start_iter // stepsize))
        "val_iter": 100,    

        "opt_algo": "adam",
        "prefetch_size": 5,   #训练数据预取队列的容量,见train.py中的 Queue(system_config.prefetch_size)

        "max_iter": 500000,
        "stepsize": 450000,
        "snapshot": 5000,

        "chunk_sizes": [15]   #一块gpu处理的图像数
    },
    
    "db": {
        "rand_scale_min": 0.5,  #数据预处理的参数
        "rand_scale_max": 1.1,
        "rand_scale_step": 0.1,
        "rand_scales": null,

        "rand_full_crop": true,
        "gaussian_bump": true,
        "gaussian_iou": 0.5,

        "min_scale": 16,
        "view_sizes": [],

        "height_mult": 31,  #代码里没用到
        "width_mult": 31,

        "input_size": [255, 255],  #输入图像统一后尺寸
        "output_sizes": [[64, 64]], #输出特征图尺寸

        "att_max_crops": 30,  
        "att_scales": [[1, 2, 4]],
        "att_thresholds": [0.3],

        "top_k": 12,  #对检测出的box(或dets)选取top_k个
        "num_dets": 12,
        "categories": 1,  #类别
        "ae_threshold": 0.3,
        "nms_threshold": 0.5,

        "max_per_image": 100
    }
}

core 文件:

dbs:
dbs下只给出了COCO数据集的训练格式,VOC的训练格式博主放到了GitHub:https://github.com/huangzicheng/CornerNet-Lite 。COCO文件主要返回detections和eval_ids:detections保存每张图像中物体的box和category,eval_ids保存每张图像的id。这里不细读了,如何修改可参考博主给出的VOC版本来加深理解。

import os
import json
import numpy as np

from .detection import DETECTION
from ..paths import get_file_path

# COCO bounding boxes are 0-indexed

class COCO(DETECTION):
    """MS COCO detection database.

    Keeps lookup tables between contiguous 1-based class indices and the
    sparse COCO category ids, loads ground-truth boxes from an
    ``instances_*.json`` annotation file, and converts / evaluates detections
    through ``pycocotools``.
    """

    def __init__(self, db_config, split=None, sys_config=None):
        """Build the category tables and, if ``split`` is given, load annotations.

        Args:
            db_config: database configuration forwarded to ``DETECTION``.
            split: ``"trainval"``, ``"minival"`` or ``"testdev"``; ``None``
                skips loading any annotation file.
            sys_config: system configuration whose ``data_dir`` attribute
                locates the ``coco`` folder. Required when ``split`` is set.

        Raises:
            KeyError: if ``split`` is not one of the known split names.
        """
        assert split is None or sys_config is not None
        super(COCO, self).__init__(db_config)

        # NOTE(review): presumably per-channel normalization statistics and
        # PCA parameters for colour augmentation; confirm against the
        # sampling code that consumes them.
        self._mean    = np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32)
        self._std     = np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32)
        self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32)
        self._eig_vec = np.array([
            [-0.58752847, -0.69563484, 0.41340352],
            [-0.5832747, 0.00994535, -0.81221408],
            [-0.56089297, 0.71832671, 0.41158938]
        ], dtype=np.float32)

        # COCO category ids are not contiguous (e.g. 12 and 26 are skipped).
        self._coco_cls_ids = [
            1,
            2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13,
            14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 25, 27, 28, 31, 32, 33, 34, 35, 36,
            37, 38, 39, 40, 41, 42, 43, 44, 46, 47,
            48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
            58, 59, 60, 61, 62, 63, 64, 65, 67, 70,
            72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
            82, 84, 85, 86, 87, 88, 89, 90
        ]

        # Names listed in the same order as ``_coco_cls_ids``.
        self._coco_cls_names = [
            'person',
            'bicycle', 'car', 'motorcycle', 'airplane',
            'bus', 'train', 'truck', 'boat', 'traffic light',
            'fire hydrant', 'stop sign', 'parking meter', 'bench',
            'bird', 'cat', 'dog', 'horse','sheep', 'cow', 'elephant',
            'bear', 'zebra','giraffe', 'backpack', 'umbrella',
            'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
            'snowboard','sports ball', 'kite', 'baseball bat',
            'baseball glove', 'skateboard', 'surfboard',
            'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
            'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
            'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
            'donut', 'cake', 'chair', 'couch', 'potted plant',
            'bed', 'dining table', 'toilet', 'tv', 'laptop',
            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
            'oven', 'toaster', 'sink', 'refrigerator', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
            'toothbrush'
        ]

        # contiguous class index (1-based) <-> COCO category id
        self._cls2coco  = {ind + 1: coco_id for ind, coco_id in enumerate(self._coco_cls_ids)}
        self._coco2cls  = {coco_id: cls_id for cls_id, coco_id in self._cls2coco.items()}
        # COCO category id <-> category name
        self._coco2name = dict(zip(self._coco_cls_ids, self._coco_cls_names))
        # Invert (coco_id -> name) pairs; the previous unpacking order
        # rebuilt ``_coco2name`` instead of its inverse.
        self._name2coco = {cls_name: coco_id for coco_id, cls_name in self._coco2name.items()}

        if split is not None:
            coco_dir = os.path.join(sys_config.data_dir, "coco")
            # Map the logical split name to the image folder / annotation suffix.
            self._split     = {
                "trainval": "train2017",
                "minival":  "val2017",
                "testdev":  "val2017"
            }[split]
            self._data_dir  = os.path.join(coco_dir, self._split)
            self._anno_file = os.path.join(coco_dir, "annotations", "instances_{}.json".format(self._split))

            self._detections, self._eval_ids = self._load_coco_annos()
            self._image_ids = list(self._detections.keys())
            self._db_inds   = np.arange(len(self._image_ids))

    def _load_coco_annos(self):
        """Read the annotation file into per-image ground-truth arrays.

        Returns:
            A pair ``(detections, eval_ids)``. ``detections`` maps an image
            file name to a float32 array with one ``[x1, y1, x2, y2, class]``
            row per annotation (shape ``(0, 5)`` when the image has none);
            ``eval_ids`` maps the file name to the COCO image id.
        """
        # Aliased so the pycocotools class does not shadow this class's name.
        from pycocotools.coco import COCO as CocoApi

        coco = CocoApi(self._anno_file)
        self._coco = coco

        class_ids = coco.getCatIds()
        image_ids = coco.getImgIds()

        eval_ids   = {}
        detections = {}
        for image_id in image_ids:
            image     = coco.loadImgs(image_id)[0]
            file_name = image["file_name"]
            dets      = []

            eval_ids[file_name] = image_id
            for class_id in class_ids:
                annotation_ids = coco.getAnnIds(imgIds=image["id"], catIds=class_id)
                annotations    = coco.loadAnns(annotation_ids)
                category       = self._coco2cls[class_id]
                for annotation in annotations:
                    # bbox is [x, y, width, height]; turn it into corners.
                    det     = annotation["bbox"] + [category]
                    det[2] += det[0]
                    det[3] += det[1]
                    dets.append(det)

            if len(dets) == 0:
                detections[file_name] = np.zeros((0, 5), dtype=np.float32)
            else:
                detections[file_name] = np.array(dets, dtype=np.float32)
        return detections, eval_ids

    def image_path(self, ind):
        """Return the path of the image at position ``ind``.

        Raises:
            ValueError: if no data directory has been set.
        """
        # ``_data_dir`` is only assigned in ``__init__`` when a split was
        # given, so use getattr to reach the intended ValueError rather than
        # an AttributeError.
        if getattr(self, "_data_dir", None) is None:
            raise ValueError("Data directory is not set")

        db_ind    = self._db_inds[ind]
        file_name = self._image_ids[db_ind]
        return os.path.join(self._data_dir, file_name)

    def detections(self, ind):
        """Return a copy of the ground-truth array of the image at ``ind``."""
        db_ind    = self._db_inds[ind]
        file_name = self._image_ids[db_ind]
        return self._detections[file_name].copy()

    def cls2name(self, cls):
        """Return the category name for the contiguous class index ``cls``."""
        coco = self._cls2coco[cls]
        return self._coco2name[coco]

    def _to_float(self, x):
        """Round ``x`` to two decimals and return it as a Python float."""
        return float("{:.2f}".format(x))

    def convert_to_coco(self, all_bboxes):
        """Convert detections to the COCO result-file record format.

        Args:
            all_bboxes: mapping ``image key -> class index -> boxes`` where
                each box is ``[x1, y1, x2, y2, score]``. The input is left
                unmodified.

        Returns:
            A list of dicts with ``image_id``, ``category_id``, ``bbox``
            (``[x, y, width, height]``) and ``score``, rounded to two decimals.
        """
        detections = []
        for image_id in all_bboxes:
            coco_id = self._eval_ids[image_id]
            for cls_ind in all_bboxes[image_id]:
                category_id = self._cls2coco[cls_ind]
                for bbox in all_bboxes[image_id][cls_ind]:
                    # Compute width/height without writing back into the
                    # caller's array, so repeated calls stay correct.
                    x1, y1, x2, y2 = bbox[0:4]
                    coco_bbox = list(map(self._to_float, (x1, y1, x2 - x1, y2 - y1)))

                    detections.append({
                        "image_id": coco_id,
                        "category_id": category_id,
                        "bbox": coco_bbox,
                        "score": self._to_float(bbox[4])
                    })
        return detections

    def evaluate(self, result_json, cls_ids, image_ids):
        """Run the COCO bbox evaluation on a result file.

        Args:
            result_json: path of the detection results passed to ``loadRes``.
            cls_ids: contiguous class indices to evaluate.
            image_ids: image keys (as used in ``_eval_ids``) to evaluate.

        Returns:
            ``None`` when the split is ``"testdev"``, otherwise the pair
            ``(stats[0], stats[12:])`` taken from ``COCOeval.stats``.
        """
        from pycocotools.cocoeval import COCOeval

        # NOTE(review): ``__init__`` stores the mapped folder name (e.g.
        # "val2017") in ``_split``, so this comparison does not match any
        # value assigned there; confirm the intended behaviour.
        if self._split == "testdev":
            return None

        coco = self._coco

        eval_ids = [self._eval_ids[image_id] for image_id in image_ids]
        cat_ids  = [self._cls2coco[cls_id] for cls_id in cls_ids]

        coco_dets = coco.loadRes(result_json)
        coco_eval = COCOeval(coco, coco_dets, "bbox")
        coco_eval.params.imgIds = eval_ids
        coco_eval.params.catIds = cat_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        # NOTE(review): the bbox summary presumably has 12 entries, which
        # would make ``stats[12:]`` empty; verify against pycocotools.
        return coco_eval.stats[0], coco_eval.stats[12:]

models 核心解读:

models文件夹下有CornerNet的三个网络架构文件,py_utils下存放一些modules组件。这里的model很多,组件复杂,为了理清model的构成和数据流向,我先从模型训练讲起,沿着数据流的走向,逐渐过渡到网络模型的每个组件,这样可以比较清楚地掌握整个项目的代码。

数据流 training:

见 train.py 文件 ,训练CornerNet_Squeeze命令:

 python  /homexxx/CornerNet-Lite-master/train.py    CornerNet_Squeeze

总的调用流程是:

rock带你读CornerNet-lite系列源码(一)_第1张图片

#!/usr/bin/env python
import os
import json
import torch
import numpy as np
import queue
import pprint
import random 
import argparse
import importlib
import threading
import traceback
import torch.distributed as dist
import torch.multiprocessing as mp

from tqdm import tqdm
from torch.multiprocessing import Process, Queue, Pool

from core.dbs import datasets
from core.utils import stdout_to_tqdm
from core.config import SystemConfig
from core.sample import data_sampling_func
from core.nnet.py_factory import NetworkFactory


# Enable cuDNN and its benchmark mode, in which cuDNN times several
# convolution algorithms and keeps the fastest one.
torch.backends.cudnn.enabled   = True
torch.backends.cudnn.benchmark = True

def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` with ``cfg_file``, ``start_iter``,
        ``workers``, ``initialize``, ``distributed``, ``world_size``,
        ``rank``, ``dist_url`` and ``dist_backend``.
    """
    cli = argparse.ArgumentParser(description="Training Script")

    # which configuration to train and where to resume from
    cli.add_argument("cfg_file", type=str, help="config file")
    cli.add_argument("--iter", type=int, default=0, dest="start_iter",
                     help="train at iteration i")
    cli.add_argument("--workers", type=int, default=4)
    cli.add_argument("--initialize", action="store_true")

    # distributed-training options; a store_true flag is False unless the
    # option is present on the command line
    cli.add_argument("--distributed", action="store_true")
    cli.add_argument("--world-size", type=int, default=-1,
                     help="number of nodes of distributed training")
    cli.add_argument("--rank", type=int, default=0,
                     help="node rank for distributed training")
    cli.add_argument("--dist-url", type=str, default=None,
                     help="url used to set up distributed training")
    cli.add_argument("--dist-backend", type=str, default="nccl")

    return cli.parse_args()

def prefetch_data(system_config, db, queue, sample_data, data_aug):
    """Endlessly sample batches from ``db`` and push them onto ``queue``.

    Args:
        system_config: configuration forwarded to ``sample_data``.
        db: database forwarded to ``sample_data``.
        queue: object with a ``put`` method that receives each batch.
        sample_data: callable ``(system_config, db, ind, data_aug=...)``
            returning ``(batch, next_ind)``.
        data_aug: flag forwarded to ``sample_data``.

    Raises:
        Exception: any error from sampling or ``queue.put`` is printed with
            its traceback and re-raised; the function never returns normally.
    """
    cursor = 0
    print("start prefetching data...")
    # Seed NumPy's global RNG with the process id, presumably so that each
    # prefetch process draws a different random stream.
    np.random.seed(os.getpid())
    try:
        while True:
            batch, cursor = sample_data(system_config, db, cursor, data_aug=data_aug)
            queue.put(batch)
    except Exception as err:
        traceback.print_exc()
        raise err

def _pin_memory(ts):
    """Return ``ts`` pinned via its ``pin_memory()`` method.

    Args:
        ts: an object exposing ``pin_memory()``, or a list (including list
            subclasses) of such objects.

    Returns:
        A new list of the pinned elements when ``ts`` is a list, otherwise
        the result of ``ts.pin_memory()``.
    """
    # isinstance also accepts list subclasses, which the exact-type check
    # used to send down the single-object path.
    if isinstance(ts, list):
        return [t.pin_memory() for t in ts]
    return ts.pin_memory()

def pin_memory(data_queue, pinned_data_queue, sema):
    """Move batches from ``data_queue`` to ``pinned_data_queue`` after pinning.

    Each batch is a dict whose ``"xs"`` and ``"ys"`` entries are replaced by
    lists of pinned items. After every forwarded batch the function tries a
    non-blocking acquire of ``sema`` and returns once that acquire succeeds.

    Args:
        data_queue: source queue; ``get()`` blocks until a batch is available.
        pinned_data_queue: destination queue for the pinned batches.
        sema: semaphore polled as the stop signal.
    """
    stop_requested = False
    while not stop_requested:
        batch = data_queue.get()

        for key in ("xs", "ys"):
            batch[key] = [_pin_memory(item) for item in batch[key]]

        pinned_data_queue.put(batch)

        # a successful non-blocking acquire ends the loop
        stop_requested = sema.acquire(blocking=False)

def init_parallel_jobs(system_config, dbs, queue, fn, data_aug):
    """Start one daemon ``prefetch_data`` process per database in ``dbs``.

    Args:
        system_config: configuration passed to every worker.
        dbs: iterable of databases; each gets its own process.
        queue: queue shared by all workers to deliver batches.
        fn: sampling function handed to ``prefetch_data``.
        data_aug: data-augmentation flag handed to ``prefetch_data``.

    Returns:
        The list of started ``Process`` objects.
    """
    workers = []
    for db in dbs:
        worker_args = (system_config, db, queue, fn, data_aug)
        workers.append(Process(target=prefetch_data, args=worker_args))

    # Every process is created before the first one is started.
    for worker in workers:
        worker.daemon = True
        worker.start()
    return workers

def terminate_tasks(tasks):
    """Call ``terminate()`` on every task in ``tasks``, in order."""
    for worker in tasks:
        worker.terminate()

def train(training_dbs, validation_db, system_config, model, args):
    """Train ``model`` with batches prefetched from ``training_dbs``.

    Starts one prefetch process per training database (plus one for
    validation when ``system_config.val_iter`` is non-zero) and two daemon
    threads that pin the prefetched batches, then iterates from
    ``args.start_iter + 1`` up to ``system_config.max_iter``.

    Args:
        training_dbs: list of database objects, one per prefetch process.
        validation_db: database used for the periodic validation loss.
        system_config: system settings; ``learning_rate``, ``max_iter``,
            ``pretrain``, ``stepsize``, ``snapshot``, ``val_iter``,
            ``display``, ``decay_rate`` and ``prefetch_size`` are read.
        model: network definition handed to ``NetworkFactory``.
        args: parsed command-line namespace; ``start_iter``, ``distributed``,
            ``initialize``, ``gpu`` and ``rank`` are read.

    Raises:
        ValueError: if a pretrained model path is configured but missing.
        SystemExit: after saving the initial parameters when
            ``args.initialize`` is set.
    """
    # reading arguments from command
    start_iter  = args.start_iter
    distributed = args.distributed
    initialize  = args.initialize
    gpu         = args.gpu
    rank        = args.rank

    # reading arguments from json file
    learning_rate    = system_config.learning_rate
    max_iteration    = system_config.max_iter
    pretrained_model = system_config.pretrain
    stepsize         = system_config.stepsize
    snapshot         = system_config.snapshot
    val_iter         = system_config.val_iter
    display          = system_config.display
    decay_rate       = system_config.decay_rate

    print("Process {}: building model...".format(rank))
    nnet = NetworkFactory(system_config, model, distributed=distributed, gpu=gpu)
    if initialize:
        # --initialize only writes the freshly built parameters and stops.
        nnet.save_params(0)
        raise SystemExit(0)

    # queues storing data for training
    training_queue   = Queue(system_config.prefetch_size)
    validation_queue = Queue(5)

    # queues storing pinned data for training
    pinned_training_queue   = queue.Queue(system_config.prefetch_size)
    pinned_validation_queue = queue.Queue(5)

    # allocating resources for parallel reading (one process per database)
    training_tasks = init_parallel_jobs(system_config, training_dbs, training_queue, data_sampling_func, True)
    # Bound unconditionally so the cleanup at the end also works when
    # validation is disabled (val_iter == 0).
    validation_tasks = []
    if val_iter:
        validation_tasks = init_parallel_jobs(system_config, [validation_db], validation_queue, data_sampling_func, False)

    # The semaphores start at 1 and are acquired here, so the pinning threads
    # keep running until they are released after the training loop.
    training_pin_semaphore   = threading.Semaphore()
    validation_pin_semaphore = threading.Semaphore()
    training_pin_semaphore.acquire()
    validation_pin_semaphore.acquire()

    training_pin_args   = (training_queue, pinned_training_queue, training_pin_semaphore)
    training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args)
    training_pin_thread.daemon = True
    training_pin_thread.start()

    validation_pin_args   = (validation_queue, pinned_validation_queue, validation_pin_semaphore)
    validation_pin_thread = threading.Thread(target=pin_memory, args=validation_pin_args)
    validation_pin_thread.daemon = True
    validation_pin_thread.start()

    # NOTE(review): per the original author's note the default value of
    # ``pretrain`` in the project's config module is None; confirm there.
    if pretrained_model is not None:
        if not os.path.exists(pretrained_model):
            raise ValueError("pretrained model does not exist")
        print("Process {}: loading from pretrained model".format(rank))
        nnet.load_pretrained_params(pretrained_model)

    if start_iter:
        # Resume: reload the snapshot and replay the decay steps that would
        # already have happened by iteration ``start_iter``.
        nnet.load_params(start_iter)
        learning_rate /= (decay_rate ** (start_iter // stepsize))
        nnet.set_lr(learning_rate)
        print("Process {}: training starts from iteration {} with learning_rate {}".format(rank, start_iter + 1, learning_rate))
    else:
        nnet.set_lr(learning_rate)

    if rank == 0:
        print("training start...")
    nnet.cuda()
    nnet.train_mode()
    with stdout_to_tqdm() as save_stdout:
        for iteration in tqdm(range(start_iter + 1, max_iteration + 1), file=save_stdout, ncols=80):
            training = pinned_training_queue.get(block=True)
            training_loss = nnet.train(**training)

            if display and iteration % display == 0:
                print("Process {}: training loss at iteration {}: {}".format(rank, iteration, training_loss.item()))
            del training_loss

            if val_iter and validation_db.db_inds.size and iteration % val_iter == 0:
                nnet.eval_mode()
                validation = pinned_validation_queue.get(block=True)
                validation_loss = nnet.validate(**validation)
                print("Process {}: validation loss at iteration {}: {}".format(rank, iteration, validation_loss.item()))
                nnet.train_mode()

            # only rank 0 writes snapshots
            if iteration % snapshot == 0 and rank == 0:
                nnet.save_params(iteration)

            # step decay of the learning rate every ``stepsize`` iterations
            if iteration % stepsize == 0:
                learning_rate /= decay_rate
                nnet.set_lr(learning_rate)

    # sending signal to kill the thread
    training_pin_semaphore.release()
    validation_pin_semaphore.release()

    # terminating data fetching processes
    terminate_tasks(training_tasks)
    terminate_tasks(validation_tasks)

def main(gpu, ngpus_per_node, args):
    """Set up one training process: config, model, databases, then ``train``.

    Args:
        gpu: local GPU index of this process, or ``None`` when called
            directly without distributed training.
        ngpus_per_node: number of GPUs on this node.
        args: parsed command-line namespace; ``gpu`` is stored on it and
            ``rank`` is rewritten to the global rank in distributed mode.
    """
    args.gpu = gpu
    if args.distributed:
        # global rank = node rank * GPUs per node + local GPU index
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    rank = args.rank
    cfg_name = args.cfg_file

    # The JSON config lives at ./configs/<cfg_file>.json.
    cfg_path = os.path.join("./configs", cfg_name + ".json")
    with open(cfg_path, "r") as cfg_handle:
        config = json.load(cfg_handle)

    # The config name is also stored as the snapshot name.
    config["system"]["snapshot_name"] = cfg_name
    system_config = SystemConfig().update_config(config["system"])

    # Import core.models.<cfg_file> and build the network it defines.
    model_module = importlib.import_module("core.models.{}".format(cfg_name))
    model = model_module.model()

    train_split = system_config.train_split
    val_split = system_config.val_split

    print("Process {}: loading all datasets...".format(rank))
    dataset = system_config.dataset
    workers = args.workers
    print("Process {}: using {} workers".format(rank, workers))

    # One training database per worker, plus a single validation database.
    db_class = datasets[dataset]
    db_config = config["db"]
    training_dbs = []
    for _ in range(workers):
        training_dbs.append(db_class(db_config, split=train_split, sys_config=system_config))
    validation_db = db_class(db_config, split=val_split, sys_config=system_config)

    if rank == 0:
        first_db = training_dbs[0]

        print("system config...")
        pprint.pprint(system_config.full)

        print("db config...")
        pprint.pprint(first_db.configs)

        print("len of db: {}".format(len(first_db.db_inds)))
        print("distributed: {}".format(args.distributed))

    train(training_dbs, validation_db, system_config, model, args)

if __name__ == "__main__":
    # Script entry point: parse the options, then either spawn one process
    # per local GPU (distributed) or run ``main`` in this process.
    args = parse_args()

    distributed = args.distributed
    world_size  = args.world_size

    # --world-size defaults to -1 and must be given explicitly for
    # distributed runs. Zero is rejected as well, matching the error message;
    # it would otherwise yield a total world size of 0 below.
    if distributed and world_size <= 0:
        raise ValueError("world size must be greater than 0 in distributed training")

    ngpus_per_node  = torch.cuda.device_count()
    if distributed:
        # total world size = number of nodes * GPUs per node
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # non-distributed run: no GPU index is assigned (gpu=None)
        main(None, ngpus_per_node, args)

你可能感兴趣的:(Pytorch学习)