【庖丁解牛】Implementing FCOS from Scratch (Final): FCOS Training and Test Performance

Table of Contents

  • How FCOS is trained
  • Improvements proposed in the FCOS paper
  • Training and testing FCOS
  • Training results

All code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.

How FCOS is trained

The FCOS paper (https://arxiv.org/pdf/1904.01355.pdf) describes its training setup on the COCO dataset, which is identical to RetinaNet's: an SGD optimizer with momentum=0.9 and weight_decay=0.0001, batch_size=16 with cross-GPU synchronized BN, 90000 iterations in total (roughly 12 epochs), and an initial learning rate of 0.01 that is divided by 10 at iterations 60000 and 80000.
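For reference, that paper-style schedule would look roughly like the sketch below in PyTorch. This is not the training code used in this article, and model here stands for an already-constructed FCOS network:

import torch

# Sketch of the paper's schedule (not this article's training code).
# `model` is assumed to be an already-constructed FCOS network.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            momentum=0.9,
                            weight_decay=0.0001)
# divide the lr by 10 at iterations 60000 and 80000 (90000 iterations total)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[60000, 80000],
                                                 gamma=0.1)
# scheduler.step() is then called once per iteration, not once per epoch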
My training code still optimizes with Adam (AdamW in the code below). In my experiments FCOS converges noticeably more slowly than RetinaNet, so FCOS needs 24 epochs of training.
Why does FCOS converge more slowly than RetinaNet?
For object detection, a "sample" means an anchor. FCOS can be viewed as a special case with exactly one anchor per feature-map location, while RetinaNet places 9 anchors at every location. Even after discarding the anchors that RetinaNet ignores, for the same input image RetinaNet has roughly 5 to 6 times as many anchor samples as FCOS. Fewer samples means less supervision, so FCOS converges more slowly. The same phenomenon can be observed in the DETR paper: DETR produces only 100 samples per image, far fewer than the anchor samples in Faster R-CNN, so DETR needs 500 epochs of training to reach the performance of Faster R-CNN trained on a 9x schedule (108 epochs).

Improvements proposed in the FCOS paper

Sharing the centerness head with the regression head:
The current implementation already does this; a minimal sketch of the idea follows.
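The sketch below uses hypothetical class and attribute names (it is not the repo's actual head module); the point it illustrates is that the centerness prediction branches off the shared regression tower rather than having its own tower:

import torch
import torch.nn as nn

class RegCenterHead(nn.Module):
    # Hypothetical sketch: centerness shares the 4-conv regression tower
    # and only adds one extra 3x3 conv for its own output.
    def __init__(self, inplanes=256, num_layers=4):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Conv2d(inplanes, inplanes, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.reg_tower = nn.Sequential(*layers)
        self.reg_pred = nn.Conv2d(inplanes, 4, 3, padding=1)     # l,t,r,b
        self.center_pred = nn.Conv2d(inplanes, 1, 3, padding=1)  # centerness

    def forward(self, x):
        feat = self.reg_tower(x)
        return self.reg_pred(feat), self.center_pred(feat)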

Adding GN to the classification and regression heads:
That is, adding Group Normalization after each of the first four conv layers in the heads. The key line is:

layers.append(nn.GroupNorm(32, inplanes))

I have verified this on my own FCOS implementation: GN gives a stable mAP gain. However, since neither TNN nor NCNN supports the GN operator, the current FCOS implementation does not include GN. A minimal sketch of a head tower with GN enabled follows.
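This sketch uses assumed helper and argument names (make_head_tower, use_gn), not the repo's actual head code:

import torch.nn as nn

def make_head_tower(inplanes=256, num_layers=4, use_gn=True):
    # Hypothetical helper: a 4-conv FCOS head tower with optional
    # GroupNorm inserted after every conv layer.
    layers = []
    for _ in range(num_layers):
        layers.append(nn.Conv2d(inplanes, inplanes, 3, padding=1))
        if use_gn:
            layers.append(nn.GroupNorm(32, inplanes))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)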

GIoU:
Already used as the regression loss; a standalone sketch of the GIoU computation follows.
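A self-contained sketch of GIoU for matched pairs of [N,4] boxes in [x1,y1,x2,y2] form (not the exact loss code in this repo):

import torch

def giou(boxes1, boxes2):
    # pairwise GIoU of matched boxes; boxes1/boxes2: [N,4], [x1,y1,x2,y2]
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # intersection
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    union = area1 + area2 - inter
    iou = inter / union.clamp(min=1e-6)
    # smallest enclosing box of each pair
    enclose_lt = torch.min(boxes1[:, :2], boxes2[:, :2])
    enclose_rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
    enclose_wh = (enclose_rb - enclose_lt).clamp(min=0)
    enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1]
    return iou - (enclose_area - union) / enclose_area.clamp(min=1e-6)

# the regression loss is then (1 - giou(pred_boxes, target_boxes)).mean()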

center sampling:
An object usually does not fill 100% of its annotated box, so part of the box is always background. When FCOS assigns ground truth, some points may lie near the inner edge of a box while actually sitting on background, and such points are hard to converge during training. Center sampling marks a point as positive only when it falls inside a smaller central region of the box, which filters out some points that actually lie on background. The downside is an extra hyperparameter controlling the region size, which has to be re-tuned whenever the dataset changes, which is inconvenient. In my view, if the dataset also provides segmentation labels, it is better to directly test whether a point lies inside the region enclosed by the segmentation mask; then no hyperparameter needs tuning on any dataset. A sketch of the center-sampling condition follows.
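The sketch below shows the positive-sample condition only, with a hypothetical helper name (center_sampling_mask) and the commonly used radius of 1.5 strides; it is not this repo's assignment code:

import torch

def center_sampling_mask(points, gt_boxes, strides, radius=1.5):
    # points:   [N,2] (x,y) locations on the feature maps
    # gt_boxes: [M,4] ground-truth boxes in [x1,y1,x2,y2] form
    # strides:  [N]   FPN stride of each point
    # A point is positive for a box only if it falls inside a shrunken
    # region of size radius*stride around the box center.
    cx = (gt_boxes[:, 0] + gt_boxes[:, 2]) / 2  # [M]
    cy = (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2  # [M]
    r = radius * strides[:, None]               # [N,1]
    # shrunken sampling region, clipped to the original box
    x1 = torch.max(cx[None, :] - r, gt_boxes[None, :, 0])
    y1 = torch.max(cy[None, :] - r, gt_boxes[None, :, 1])
    x2 = torch.min(cx[None, :] + r, gt_boxes[None, :, 2])
    y2 = torch.min(cy[None, :] + r, gt_boxes[None, :, 3])
    px, py = points[:, 0:1], points[:, 1:2]     # [N,1]
    return (px > x1) & (px < x2) & (py > y1) & (py < y2)  # [N,M] bool mask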

Training and testing FCOS

The training and test code for FCOS is almost identical to RetinaNet's; the only addition is the centerness loss, so the RetinaNet training code needs only minor changes.
The train.py file is as follows (it imports the Config class from a separate config.py):

import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
import apex
from apex import amp
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import FCOSLoss
from public.detection.models.decode import FCOSDecoder
from public.detection.models import fcos
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Distributed Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--per_node_batch_size',
                        type=int,
                        default=Config.per_node_batch_size,
                        help='per_node batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=bool,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')
    parser.add_argument('--sync_bn',
                        type=bool,
                        default=Config.sync_bn,
                        help='use sync bn or not')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='LOCAL_PROCESS_RANK')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, center_heads, batch_positions = model(
            data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, center_heads,
                                         batch_positions)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval,we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True)
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collater,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = fcos.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        if local_rank == 0:
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = FCOSLoss(image_w=args.input_image_size,
                         image_h=args.input_image_size).cuda()
    decoder = FCOSDecoder(image_w=args.input_image_size,
                          image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        if args.sync_bn:
            # apex sync BN must be converted before amp initialization and
            # before wrapping with apex DistributedDataParallel
            model = apex.parallel.convert_syncbn_model(model)
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
    else:
        if args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.evaluate))
            sys.exit(-1)
        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            logger.info(f"start eval.")
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}, center_ness_loss: {checkpoint['center_ness_loss']:2f}"
            )

    if local_rank == 0:
        if not os.path.exists(args.checkpoints):
            os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        cls_losses, reg_losses, center_ness_losses, losses = train(
            train_loader, model, criterion, optimizer, scheduler, epoch, args)

        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, center_ness_loss: {center_ness_losses:.2f}, loss: {losses:.2f}"
            )

        if epoch % 5 == 0 or epoch == args.epochs:
            if local_rank == 0:
                logger.info(f"start eval.")
                all_eval_result = validate(Config.val_dataset, model, decoder)
                logger.info(f"eval done.")
                if all_eval_result is not None:
                    logger.info(
                        f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                    )
                    if all_eval_result[0] > best_map:
                        torch.save(model.module.state_dict(),
                                   os.path.join(args.checkpoints, "best.pth"))
                        best_map = all_eval_result[0]
        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'cls_loss': cls_losses,
                    'reg_loss': reg_losses,
                    'center_ness_loss': center_ness_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")


def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    cls_losses, reg_losses, center_ness_losses, losses = [], [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, center_heads, batch_positions = model(images)
        cls_loss, reg_loss, center_ness_loss = criterion(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)
        loss = cls_loss + reg_loss + center_ness_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            # skip this abnormal batch, but still advance to the next batch
            # so the loop does not spin forever on the same data
            optimizer.zero_grad()
            images, annotations = prefetcher.next()
            iter_index += 1
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        center_ness_losses.append(center_ness_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, center_ness_loss: {center_ness_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(
        center_ness_losses), np.mean(losses)


if __name__ == '__main__':
    main()
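Since this script initializes an NCCL process group and reads --local_rank, it is meant to be launched through torch.distributed.launch. For example, with 2 GPUs (the file name train.py is assumed):

python -m torch.distributed.launch --nproc_per_node=2 train.py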

Training results

All results below use an input size of 667, which is equivalent to resize=400 in the paper; for why these are equivalent, see the explanation in 【庖丁解牛】Implementing RetinaNet from Scratch (Final). mAP is the COCOeval stats[0] value and mAR is the stats[8] value.

| Network | batch | gpu-num | apex | syncbn | epoch5 mAP,mAR,loss | epoch10 mAP,mAR,loss | epoch12 mAP,mAR,loss | epoch15 mAP,mAR,loss | epoch20 mAP,mAR,loss | epoch24 mAP,mAR,loss |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50-FCOS-myresize667-fastdecode | 32 | 2 | yes | no | 0.162,0.289,1.31 | 0.226,0.342,1.21 | 0.248,0.370,1.20 | 0.217,0.343,1.17 | 0.282,0.409,1.14 | 0.286,0.409,1.12 |
| ResNet101-FCOS-myresize667-fastdecode | 24 | 2 | yes | no | 0.206,0.325,1.29 | 0.237,0.359,1.20 | 0.263,0.380,1.18 | 0.277,0.400,1.15 | 0.260,0.385,1.13 | 0.291,0.416,1.10 |

My trained ResNet50-FCOS reaches a slightly lower mAP than a ResNet50-RetinaNet at the same input size (0.7 points lower), probably because Group Normalization and center sampling are not used. However, the FCOS model's mAR is higher than RetinaNet's, which suggests the centerness branch is effective at improving mAR.
