All of the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested on PyTorch 1.4 and confirmed to work correctly.
First, distributed training requires one extra variable, local_rank. Its initial value is 0; when training on multiple GPUs, torch.distributed.launch starts one process per GPU and each process's local_rank is updated to 0, 1, 2, and so on.
Second, in distributed training the batch_size passed to the DataLoader is not the total batch size but the batch size assigned to each individual GPU.
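The effective global batch size is therefore the per-GPU batch_size multiplied by the number of processes; a minimal sketch (the numbers are illustrative):

```python
import torch.distributed as dist

# Illustrative numbers: batch_size=15 is passed to every per-GPU DataLoader.
per_gpu_batch_size = 15
# One process per GPU; get_world_size() is valid after dist.init_process_group.
world_size = dist.get_world_size()
# The effective global batch size, e.g. 15 * 2 = 30 when training on 2 GPUs.
global_batch_size = per_gpu_batch_size * world_size
```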
Next, we initialize the process group with dist.init_process_group. I won't go into the details of this here; I only give the simplest initialization for the single-machine multi-GPU case:
```python
dist.init_process_group(backend='nccl', init_method='env://')
```
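Slightly expanded, a minimal single-node sketch (assuming local_rank has already been parsed as described above):

```python
import torch
import torch.distributed as dist

# Bind this process to its own GPU before creating the process group.
torch.cuda.set_device(local_rank)
# With init_method='env://', the master address/port and the rank/world size
# are read from the environment variables set by torch.distributed.launch.
dist.init_process_group(backend='nccl', init_method='env://')
```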
On a single multi-GPU server, if you want to run several distributed training jobs at the same time (for example, with 4 GPUs, two of them running one distributed experiment and the other two running a second one), each experiment's train.sh launch command must use a different master_port, so that every job has its own master_addr:master_port rendezvous address. Otherwise, running multiple distributed training jobs on the same machine will fail with an error.
```
python -m torch.distributed.launch --nproc_per_node=1 --master_addr 127.0.0.1 --master_port 20001 train.py
```
nproc_per_node is the number of GPUs to use (one process is started per GPU).
After defining the model, wrap it with the nn.parallel.DistributedDataParallel API. If you use apex, there is a similar API: apex.parallel.DistributedDataParallel.
```python
if args.sync_bn:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.apex:
    amp.register_float_function(torch, 'sigmoid')
    amp.register_float_function(torch, 'softmax')
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = apex.parallel.DistributedDataParallel(model,
                                                  delay_allreduce=True)
    if args.sync_bn:
        model = apex.parallel.convert_syncbn_model(model)
else:
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[local_rank],
                                                output_device=local_rank)
```
Note that when using apex, if you also want sync BN you must use apex's own apex.parallel.convert_syncbn_model to convert the BN layers in the model into sync BN layers.
Then, all logger.info calls are made only when local_rank == 0; otherwise the logger writes every message once per GPU. For the same reason, validation is run only on the GPU whose local_rank is 0.
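The pattern is simply a guard on local_rank, as train.py below does (a minimal sketch):

```python
# Only the process with local_rank == 0 writes logs and runs validation,
# so messages and evaluation are not duplicated once per GPU.
if local_rank == 0:
    logger.info('start training')

if local_rank == 0:
    all_eval_result = validate(Config.val_dataset, model, decoder)
```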
config.py is as follows:
```python
import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import COCO2017_path
from public.detection.dataset.cocodataset import CocoDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = os.path.join(COCO2017_path, 'images/train2017')
    val_dataset_path = os.path.join(COCO2017_path, 'images/val2017')
    dataset_annotations_path = os.path.join(COCO2017_path, 'annotations')

    network = "resnet50_retinanet"
    pretrained = False
    num_classes = 80
    seed = 0
    input_image_size = 600

    train_dataset = CocoDetection(image_root_dir=train_dataset_path,
                                  annotation_root_dir=dataset_annotations_path,
                                  set="train2017",
                                  transform=transforms.Compose([
                                      RandomFlip(flip_prob=0.5),
                                      RandomCrop(crop_prob=0.5),
                                      RandomTranslate(translate_prob=0.5),
                                      Resize(resize=input_image_size),
                                  ]))
    val_dataset = CocoDetection(image_root_dir=val_dataset_path,
                                annotation_root_dir=dataset_annotations_path,
                                set="val2017",
                                transform=transforms.Compose([
                                    Resize(resize=input_image_size),
                                ]))

    epochs = 12
    per_node_batch_size = 15
    lr = 1e-4
    num_workers = 4
    print_interval = 100
    apex = True
    sync_bn = False
```
train.py is as follows:
```python
import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
import apex
from apex import amp
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from public.detection.models.retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval
def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Distributed Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--per_node_batch_size',
                        type=int,
                        default=Config.per_node_batch_size,
                        help='per_node batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')
    parser.add_argument('--sync_bn',
                        type=bool,
                        default=Config.sync_bn,
                        help='use sync bn or not')
    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='LOCAL_PROCESS_RANK')

    return parser.parse_args()
def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval,we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))
    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    all_eval_result = coco_eval.stats

    return all_eval_result
def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True)
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collater,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        if local_rank == 0:
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()
    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
        if args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
    else:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.resume))
            sys.exit(-1)
        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            all_eval_result = validate(Config.val_dataset, model, decoder)
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
            )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               args)
        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
            )

        if epoch % 5 == 0 or epoch == args.epochs:
            if local_rank == 0:
                all_eval_result = validate(Config.val_dataset, model, decoder)
                logger.info(f"eval done.")
                if all_eval_result is not None:
                    logger.info(
                        f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                    )
                    if all_eval_result[0] > best_map:
                        torch.save(model.module.state_dict(),
                                   os.path.join(args.checkpoints, "best.pth"))
                        best_map = all_eval_result[0]

        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'cls_loss': cls_losses,
                    'reg_loss': reg_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")
def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // (args.per_node_batch_size * gpus_num)
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            optimizer.zero_grad()
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if local_rank == 0 and iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


if __name__ == '__main__':
    main()
```
The train.sh used to launch training:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_addr 127.0.0.1 --master_port 20001 train.py
```
The model's performance on the COCO dataset is shown below (the input resolution is 600, roughly equivalent to the 450 resolution setting in the RetinaNet paper):
Network | batch | gpu-num | apex | syncbn | epoch5-mAP-loss | epoch10-mAP-loss | epoch12-mAP-loss | one-epoch-training-times |
---|---|---|---|---|---|---|---|---|
ResNet50-RetinaNet | 16 | 2 | no | yes | 0.249,0.59 | 0.275,0.47 | 0.279,0.44 | 2h1min |
ResNet50-RetinaNet | 16 | 2 | no | no | 0.251,0.60 | 0.274,0.48 | 0.278,0.45 | 1h56min |
ResNet50-RetinaNet | 15 | 1 | yes | no | 0.255,0.59 | 0.272,0.48 | 0.279,0.45 | 2h28min |
ResNet50-RetinaNet-aug | 15 | 1 | yes | no | 0.251,0.62 | 0.281,0.53 | 0.287,0.51 | 2h32min |
All of the experiments above were trained in DistributedDataParallel mode. If only one GPU is used, training with sync BN is exactly the same as training without it. All experiments used RandomFlip + Resize data augmentation during training and only Resize at test time; the -aug suffix means RandomCrop and RandomTranslate were additionally used during training. All GPUs are RTX 2080 Ti. "0.255,0.59" means an mAP of 0.255 with a total loss of 0.59 at that point, and "2h28min" means one epoch takes 2 hours and 28 minutes.
According to these results, under the same data augmentation the RetinaNet trained with my code (mAP 0.279) is about 3.2 points lower than the paper (interpolating the paper's numbers, resolution 450 should give roughly 0.311). This gap is most likely caused by using the AdamW optimizer instead of SGD, together with issues 1 and 3 raised in the previous article.
COCO annotations include an attribute called iscrowd. iscrowd=1 means the annotation marks a group of objects (for example, a crowd of people), while iscrowd=0 means a single object. In all of the experiments above, the annotations loaded for training included both iscrowd=0 and iscrowd=1 objects (iscrowd=None in self.coco.getAnnIds).
I looked at the COCO loading code in Detectron (https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/json_dataset.py) and Detectron2 (https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/coco.py) and found that both filter out iscrowd=1 objects for detection and segmentation training. I therefore removed all iscrowd=1 annotations and retrained once (passing iscrowd=False to self.coco.getAnnIds).
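Inside the dataset class this change boils down to the iscrowd argument of pycocotools' COCO.getAnnIds; a minimal sketch (the annotation path and image id are illustrative):

```python
from pycocotools.coco import COCO

coco = COCO('annotations/instances_train2017.json')  # illustrative path
img_id = coco.getImgIds()[0]

# Previous experiments: iscrowd=None keeps both iscrowd=0 and iscrowd=1 objects.
all_ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
# Detectron/Detectron2-style loading: iscrowd=False keeps only iscrowd=0 objects.
single_object_ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)

print(len(all_ann_ids), len(single_object_ann_ids))
```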
The training results are as follows:
Network | batch | gpu-num | apex | syncbn | epoch5-mAP-loss | epoch10-mAP-loss | epoch12-mAP-loss | one-epoch-training-times |
---|---|---|---|---|---|---|---|---|
ResNet50-RetinaNet-aug | 15 | 1 | yes | no | 0.251,0.62 | 0.281,0.53 | 0.287,0.51 | 2h32min |
ResNet50-RetinaNet-aug-iscrowd | 15 | 1 | yes | no | 0.254,0.62 | 0.280,0.53 | 0.286,0.50 | 2h31min |
ResNet50-RetinaNet-aug is the last row of the distributed training results above; ResNet50-RetinaNet-aug-iscrowd is the same configuration after switching self.coco.getAnnIds to iscrowd=False. The difference between the two is very small, but to stay aligned with how other frameworks train, I will use ResNet50-RetinaNet-aug-iscrowd as the baseline for the improvement experiments that follow.