All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.
The FCOS paper (https://arxiv.org/pdf/1904.01355.pdf) describes its training recipe on the COCO dataset, which is identical to RetinaNet's: an SGD optimizer with momentum=0.9 and weight_decay=0.0001, batch_size=16, and cross-GPU synchronized BN. Training runs for 90000 iterations in total (about 12 epochs), with an initial learning rate of 0.01 that is divided by 10 at iterations 60000 and 80000.
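For reference, the paper's schedule can be written in PyTorch as follows. This is only a sketch of the recipe described above, not code from my repo; the placeholder module just makes the snippet runnable, and the scheduler is stepped once per iteration rather than per epoch:

```python
import torch
import torch.nn as nn

# Sketch of the FCOS/RetinaNet paper recipe (my training code below uses
# AdamW instead). The placeholder module stands in for the detector.
model = nn.Conv2d(3, 9, kernel_size=3)  # placeholder, not a real detector

optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.01,
                            momentum=0.9,
                            weight_decay=0.0001)
# 90000 iterations total (~12 epochs at batch_size=16); the learning rate
# is divided by 10 at iterations 60000 and 80000.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[60000, 80000],
                                                 gamma=0.1)
```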
In my training code, I still use an Adam-family optimizer (AdamW in the code below) rather than SGD. Experiments show that FCOS converges noticeably more slowly than RetinaNet, so FCOS needs 24 epochs of training.
Why does FCOS converge more slowly than RetinaNet?
For object detection, a training sample means an anchor. FCOS can be viewed as a special case with a single anchor per feature-map location, whereas RetinaNet places 9 anchors at each location. After removing the anchors that RetinaNet ignores, for the same input image RetinaNet has roughly 5 to 6 times as many anchor samples as FCOS. Fewer samples means less supervision signal, so FCOS converges more slowly. The same phenomenon can be observed in the DETR paper: DETR produces only 100 samples per image, far fewer than the number of anchor samples in Faster R-CNN, so DETR needs 500 epochs of training to match the performance Faster R-CNN reaches with the 9x schedule (108 epochs).
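To make the sample-count gap concrete, here is a rough back-of-the-envelope count, assuming a 667x667 input and the standard FPN strides 8/16/32/64/128. Exact counts depend on padding and on the positive/ignore assignment rules, so treat this as an upper-bound sketch:

```python
# Back-of-the-envelope sample counts for a 667x667 input (illustrative).
# FCOS: 1 sample per FPN location; RetinaNet: 9 anchors per location.
strides = [8, 16, 32, 64, 128]
input_size = 667

fcos_samples = sum((input_size // s)**2 for s in strides)
retinanet_anchors = 9 * fcos_samples

print(f"FCOS locations:    {fcos_samples}")       # ~9095
print(f"RetinaNet anchors: {retinanet_anchors}")  # ~81855
# After RetinaNet's ignore rules drop some anchors, the effective ratio
# falls to roughly 5-6x rather than 9x (see the discussion above).
```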
Sharing the centerness head with the regression branch:
My current implementation already does this.
Adding GN to the classification and regression heads:
That is, add Group Normalization after each of the first four conv layers in the heads. The code is as follows:
```python
layers.append(nn.GroupNorm(32, inplanes))
```
I have verified this on my own FCOS implementation: GN gives a stable mAP gain. However, since neither TNN nor NCNN supports the GN operator, the current FCOS implementation does not include GN.
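For reference, here is a minimal sketch of what a head conv stack looks like with GN inserted. The class name and arguments below are mine for illustration, not the repo's actual code:

```python
import torch
import torch.nn as nn


class FCOSHeadStack(nn.Module):
    """Minimal sketch of an FCOS head conv stack with GroupNorm.

    Illustrative only; the class and argument names are not from the repo.
    """

    def __init__(self, inplanes=256, num_layers=4, use_gn=True):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            # dropping the conv bias when a norm layer follows is a common
            # choice, not something the original snippet specifies
            layers.append(
                nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1,
                          padding=1, bias=not use_gn))
            if use_gn:
                # 32 groups over 256 channels, as in the FCOS paper
                layers.append(nn.GroupNorm(32, inplanes))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
```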
GIoU:
Already used in the regression loss.
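As a reference for what using GIoU as the regression loss means, here is a minimal standalone sketch for boxes in [x_min, y_min, x_max, y_max] format. The repo computes this inside FCOSLoss; this is not the exact implementation:

```python
import torch


def giou_loss(pred_boxes, gt_boxes, eps=1e-7):
    """GIoU loss for matched box pairs, both [N, 4] in
    [x_min, y_min, x_max, y_max] format. Sketch only."""
    # intersection area
    inter_mins = torch.max(pred_boxes[:, :2], gt_boxes[:, :2])
    inter_maxs = torch.min(pred_boxes[:, 2:], gt_boxes[:, 2:])
    inter_wh = (inter_maxs - inter_mins).clamp(min=0)
    inter_area = inter_wh[:, 0] * inter_wh[:, 1]

    # union area and IoU
    pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * \
                (pred_boxes[:, 3] - pred_boxes[:, 1])
    gt_area = (gt_boxes[:, 2] - gt_boxes[:, 0]) * \
              (gt_boxes[:, 3] - gt_boxes[:, 1])
    union_area = pred_area + gt_area - inter_area
    iou = inter_area / union_area.clamp(min=eps)

    # smallest enclosing box
    enclose_mins = torch.min(pred_boxes[:, :2], gt_boxes[:, :2])
    enclose_maxs = torch.max(pred_boxes[:, 2:], gt_boxes[:, 2:])
    enclose_wh = (enclose_maxs - enclose_mins).clamp(min=0)
    enclose_area = enclose_wh[:, 0] * enclose_wh[:, 1]

    # GIoU = IoU - (C - U) / C, loss = 1 - GIoU
    giou = iou - (enclose_area - union_area) / enclose_area.clamp(min=eps)
    return (1.0 - giou).mean()
```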
center sampling:
Inside a ground-truth box, the object usually does not cover 100% of the box area, so some background is always included in the box. When FCOS assigns ground truths, some points may lie just inside the border of a box yet actually fall on background; such points are hard to converge during training. Center sampling marks a point as positive only when it lies within a smaller central region of the box, which excludes some points that are actually on background. The drawback is one extra hyperparameter controlling the region size, which has to be retuned whenever the dataset changes; this is inconvenient. In my view, if the dataset also provides segmentation labels, it would be better to directly test whether a point lies inside the region enclosed by the segmentation mask, so that no hyperparameter needs tuning for any dataset. A sketch of the rule follows.
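Here is a minimal sketch of the center-sampling rule described above. The radius below is exactly the extra hyperparameter I criticized (1.5 is a commonly used setting), and all names here are illustrative rather than from the repo:

```python
import torch


def center_sampling_mask(points, gt_boxes, strides, radius=1.5):
    """Mark a point positive for a gt box only if it falls inside a small
    region around the box center, clipped to the box itself.

    points:   [N, 2] (x, y) locations on the input image
    gt_boxes: [M, 4] in [x_min, y_min, x_max, y_max]
    strides:  [N] FPN stride of each point
    returns:  [N, M] bool mask. Sketch only.
    """
    centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # [M, 2]
    r = radius * strides[:, None]                      # [N, 1]

    # shrunken center region of each gt box, clipped to the gt box,
    # so "inside the region" also implies "inside the box"
    x_min = torch.max(centers[None, :, 0] - r, gt_boxes[None, :, 0])
    y_min = torch.max(centers[None, :, 1] - r, gt_boxes[None, :, 1])
    x_max = torch.min(centers[None, :, 0] + r, gt_boxes[None, :, 2])
    y_max = torch.min(centers[None, :, 1] + r, gt_boxes[None, :, 3])

    x, y = points[:, 0:1], points[:, 1:2]              # [N, 1] each
    return (x > x_min) & (x < x_max) & (y > y_min) & (y < y_max)
```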
The FCOS training and testing code is essentially the same as RetinaNet's; the only addition is the centerness loss. The RetinaNet training code can be reused with minor modifications.
The training script (train.py, which reads its hyperparameters from config.py) is as follows:
```python
import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
import apex
from apex import amp
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import FCOSLoss
from public.detection.models.decode import FCOSDecoder
from public.detection.models import fcos
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Distributed Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--per_node_batch_size',
                        type=int,
                        default=Config.per_node_batch_size,
                        help='per_node batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of workers to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    # print_interval is used as a modulo divisor below, so it must be an int
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')
    parser.add_argument('--sync_bn',
                        type=bool,
                        default=Config.sync_bn,
                        help='use sync bn or not')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='LOCAL_PROCESS_RANK')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, center_heads, batch_positions = model(
            data['img'].cuda().permute(2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, center_heads,
                                         batch_positions)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape:[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval, we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            # class -1 marks padded (empty) detection slots
            if object_class == -1:
                break

            image_result = {
                'image_id': val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score': object_score,
                'bbox': object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    with open('{}_bbox_results.json'.format(val_dataset.set_name), 'w') as f:
        json.dump(results, f, indent=4)

    # load results into the COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True)
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collater,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = fcos.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        if local_rank == 0:
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = FCOSLoss(image_w=args.input_image_size,
                         image_h=args.input_image_size).cuda()
    decoder = FCOSDecoder(image_w=args.input_image_size,
                          image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
        if args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
    else:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.evaluate))
            sys.exit(-1)

        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            logger.info(f"start eval.")
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:.3f}, cls_loss: {checkpoint['cls_loss']:.2f}, reg_loss: {checkpoint['reg_loss']:.2f}, center_ness_loss: {checkpoint['center_ness_loss']:.2f}"
            )

    if local_rank == 0:
        if not os.path.exists(args.checkpoints):
            os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        cls_losses, reg_losses, center_ness_losses, losses = train(
            train_loader, model, criterion, optimizer, scheduler, epoch, args)
        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, center_ness_loss: {center_ness_losses:.2f}, loss: {losses:.2f}"
            )

        if epoch % 5 == 0 or epoch == args.epochs:
            if local_rank == 0:
                logger.info(f"start eval.")
                all_eval_result = validate(Config.val_dataset, model, decoder)
                logger.info(f"eval done.")
                if all_eval_result is not None:
                    logger.info(
                        f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                    )
                    if all_eval_result[0] > best_map:
                        torch.save(model.module.state_dict(),
                                   os.path.join(args.checkpoints, "best.pth"))
                        best_map = all_eval_result[0]

        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'cls_loss': cls_losses,
                    'reg_loss': reg_losses,
                    'center_ness_loss': center_ness_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")


def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    cls_losses, reg_losses, center_ness_losses, losses = [], [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, center_heads, batch_positions = model(images)
        cls_loss, reg_loss, center_ness_loss = criterion(
            cls_heads, reg_heads, center_heads, batch_positions, annotations)
        loss = cls_loss + reg_loss + center_ness_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            optimizer.zero_grad()
            # advance to the next batch before skipping, so the loop does
            # not stall on the same batch
            images, annotations = prefetcher.next()
            iter_index += 1
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        center_ness_losses.append(center_ness_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, center_ness_loss: {center_ness_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(
        center_ness_losses), np.mean(losses)


if __name__ == '__main__':
    main()
```
All results below use an input size of 667, which is equivalent to resize=400 in the paper. For why the two are equivalent, see the explanation in the article 【庖丁解牛】从零实现RetinaNet(终). mAP is COCOeval stats[0]; mAR is COCOeval stats[8].
Network | batch | gpu-num | apex | syncbn | epoch5-mAP-mAR-loss | epoch10-mAP-mAR-loss | epoch12-mAP-mAR-loss | epoch15-mAP-mAR-loss | epoch20-mAP-mAR-loss | epoch24-mAP-mAR-loss |
--- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
ResNet50-FCOS-myresize667-fastdecode | 32 | 2 | yes | no | 0.162,0.289,1.31 | 0.226,0.342,1.21 | 0.248,0.370,1.20 | 0.217,0.343,1.17 | 0.282,0.409,1.14 | 0.286,0.409,1.12 |
ResNet101-FCOS-myresize667-fastdecode | 24 | 2 | yes | no | 0.206,0.325,1.29 | 0.237,0.359,1.20 | 0.263,0.380,1.18 | 0.277,0.400,1.15 | 0.260,0.385,1.13 | 0.291,0.416,1.10 |
The ResNet50-FCOS model I trained reaches an mAP slightly lower (by 0.7 points) than ResNet50-RetinaNet at the same input size, probably because Group Normalization and center sampling were not used. However, FCOS's mAR is higher than RetinaNet's, which suggests the centerness branch is effective at improving mAR.