# train.py
from __future__ import print_function
import os
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import argparse
import torch.utils.data as data
from data import WiderFaceDetection, detection_collate, preproc, cfg_mnet, cfg_re50
from layers.modules import MultiBoxLoss
from layers.functions.prior_box import PriorBox
import time
import datetime
import math
from models.retinaface import RetinaFace
# argparse is Python's standard command-line parser: the program declares the
# options it accepts, argparse extracts them from sys.argv, generates the
# help/usage text and reports an error when invalid arguments are passed.
# Create the parser.
parser = argparse.ArgumentParser(description='Retinaface Training')
# Register the options.
# 1: training-set label file
parser.add_argument('--training_dataset', default='./data/widerface/train/label.txt', help='Training dataset directory')
# 2: backbone network; `choices` makes argparse reject any other name at
#    parse time instead of letting the script continue without a config.
parser.add_argument('--network', default='mobile0.25', choices=['mobile0.25', 'resnet50'], help='Backbone network mobile0.25 or resnet50')
# 3: number of workers used for data loading
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
# 4: initial learning rate
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
# 5: momentum
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
# 6: checkpoint to resume from (None means start from scratch)
parser.add_argument('--resume_net', default=None, help='resume net for retraining')
# 7: epoch to resume from (the value is an epoch count, not an iteration)
parser.add_argument('--resume_epoch', default=0, type=int, help='resume epoch for retraining')
# Author's note: SGD applies one update at a time, so there is no redundant
# computation; it is fast and new samples can be added.
# 8: weight decay for SGD
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
# 9: gamma (learning-rate decay factor) for SGD
parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
# 10: directory where model checkpoints are saved
parser.add_argument('--save_folder', default='./weights/', help='Location to save checkpoint models')
# parse_args() assigns the values of the options declared above and returns
# them as a namespace.
args = parser.parse_args()
# Make sure the checkpoint directory (./weights/ by default) exists.
# makedirs(..., exist_ok=True) also creates missing parent directories and
# does not fail if the directory already exists or is created concurrently,
# which the previous exists()-then-mkdir() sequence could not guarantee.
os.makedirs(args.save_folder, exist_ok=True)
# Pick the configuration dictionary that matches the requested backbone.
if args.network == "mobile0.25":
    cfg = cfg_mnet
elif args.network == "resnet50":
    cfg = cfg_re50
else:
    # Fail here with a readable message; otherwise cfg would stay None and the
    # first cfg[...] lookup below would raise an opaque TypeError.
    raise ValueError(
        "Unsupported --network {!r}; expected 'mobile0.25' or 'resnet50'".format(args.network)
    )

# Hyper-parameters
# ----------------------
rgb_mean = (104, 117, 123)  # bgr order (per-channel mean)
num_classes = 2  # total number of label classes
# ----------------------
# Values taken from the selected configuration.
img_dim = cfg['image_size']
num_gpu = cfg['ngpu']
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
gpu_train = cfg['gpu_train']

# Values taken from the command line.
num_workers = args.num_workers
momentum = args.momentum
weight_decay = args.weight_decay
initial_lr = args.lr
gamma = args.gamma
training_dataset = args.training_dataset
save_folder = args.save_folder
# Build the network.
net = RetinaFace(cfg=cfg)
print("Printing net...")
print(net)

if args.resume_net is not None:
    print('Loading resume network...')
    # Load the checkpoint to resume from.
    # NOTE(security): torch.load unpickles the file; only resume from
    # checkpoints you trust.
    # map_location='cpu' keeps the tensors off the devices they were saved
    # from, so a checkpoint written on another GPU layout still loads; the
    # model itself is moved to the GPU later in this script.
    state_dict = torch.load(args.resume_net, map_location='cpu')

    # Create a new OrderedDict whose keys do not contain the `module.` prefix
    # (this script saves net.state_dict() of a possibly DataParallel-wrapped
    # model, whose keys carry that prefix).
    from collections import OrderedDict
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

    # Copy the parameters into the network.
    net.load_state_dict(new_state_dict)
# Multi-GPU run: wrap the model in DataParallel when the configuration asks
# for more than one GPU; in every case the model is moved to the GPU.
if num_gpu > 1 and gpu_train:
    net = torch.nn.DataParallel(net)
net = net.cuda()

# Enable cuDNN benchmark mode (author's note: improves run-time efficiency).
cudnn.benchmark = True

# Optimizer: SGD configured from the command-line hyper-parameters.
optimizer = optim.SGD(
    net.parameters(),
    lr=initial_lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Loss function. The positional arguments follow MultiBoxLoss's signature,
# which is defined in layers.modules and not visible here — check it before
# changing any of these values.
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

# Generate the prior boxes once for the square training resolution and keep
# them on the GPU. no_grad() switches autograd off for this computation
# (author's note: this speeds things up and saves GPU memory).
priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
with torch.no_grad():
    priors = priorbox.forward().cuda()
def train():
    """Run the training loop and write checkpoints into ``save_folder``.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``priors``
    and hyper-parameters.  Training runs from ``args.resume_epoch`` up to
    ``max_epoch``; one iteration processes one batch.  Intermediate
    checkpoints are written at the start of selected epochs and a final
    ``<name>_Final.pth`` checkpoint is written when the loop ends.
    """
    net.train()  # put the network into training mode
    epoch = 0 + args.resume_epoch
    print('Loading Dataset...')

    dataset = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))

    epoch_size = math.ceil(len(dataset) / batch_size)  # iterations per epoch
    max_iter = max_epoch * epoch_size  # total number of iterations
    # Iterations at which the learning rate is decayed.
    stepvalues = (cfg['decay1'] * epoch_size, cfg['decay2'] * epoch_size)

    # First iteration of this run.
    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0

    # Number of decay steps already applied.  When resuming, every decay
    # boundary strictly before start_iter has been passed and must be counted,
    # otherwise training would restart at the undecayed learning rate.  A
    # boundary equal to start_iter is handled by the loop below.  set() mirrors
    # the `iteration in stepvalues` check, which fires once per distinct value.
    step_index = sum(1 for boundary in set(stepvalues) if boundary < start_iter)

    # Main iteration loop.
    for iteration in range(start_iter, max_iter):
        if iteration % epoch_size == 0:
            # New epoch: create a fresh (reshuffled) batch iterator; the
            # next(batch_iterator) call below then yields one batch at a time.
            batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate))
            if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > cfg['decay1']):
                # Save an intermediate checkpoint.  os.path.join keeps the
                # path valid even if save_folder has no trailing separator.
                torch.save(net.state_dict(), os.path.join(save_folder, cfg['name'] + '_epoch_' + str(epoch) + '.pth'))
            epoch += 1

        load_t0 = time.time()
        if iteration in stepvalues:
            step_index += 1
        lr = adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size)

        # Load one batch of training data and move it to the GPU.
        images, targets = next(batch_iterator)
        images = images.cuda()
        targets = [anno.cuda() for anno in targets]

        # Forward pass.
        out = net(images)

        # Backward pass: clear old gradients, compute the three loss terms,
        # combine them, back-propagate and update the parameters.
        optimizer.zero_grad()
        loss_l, loss_c, loss_landm = criterion(out, priors, targets)
        loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
        loss.backward()
        optimizer.step()

        load_t1 = time.time()
        batch_time = load_t1 - load_t0
        # Estimated remaining time, extrapolated from the latest batch.
        eta = int(batch_time * (max_iter - iteration))
        print('Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'
              .format(epoch, max_epoch, (iteration % epoch_size) + 1,
              epoch_size, iteration + 1, max_iter, loss_l.item(), loss_c.item(), loss_landm.item(), lr, batch_time, str(datetime.timedelta(seconds=eta))))

    # Final checkpoint after the last iteration.
    torch.save(net.state_dict(), os.path.join(save_folder, cfg['name'] + '_Final.pth'))
# Recompute and apply the learning rate.
def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size,
                         warmup_epoch=-1, base_lr=None):
    """Set the learning rate of every parameter group and return it.

    Adapted from PyTorch Imagenet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts); each
            group's ``'lr'`` entry is overwritten.
        gamma: multiplicative decay factor applied once per decay step.
        epoch: current epoch number.
        step_index: number of decay steps applied so far.
        iteration: current global iteration (used only during warm-up).
        epoch_size: iterations per epoch (used only during warm-up).
        warmup_epoch: number of warm-up epochs; values <= 0 (the default,
            matching the previous hard-coded -1) disable warm-up.
        base_lr: initial learning rate; defaults to the module-level
            ``initial_lr`` when omitted.

    Returns:
        The learning rate that was written into the optimizer.
    """
    if base_lr is None:
        base_lr = initial_lr

    # The `warmup_epoch > 0` guard keeps the division below well defined.
    if warmup_epoch > 0 and epoch <= warmup_epoch:
        # Linear ramp from 1e-6 towards base_lr over the warm-up iterations.
        lr = 1e-6 + (base_lr - 1e-6) * iteration / (epoch_size * warmup_epoch)
    else:
        # Step schedule: base_lr scaled by gamma once per decay step.
        lr = base_lr * (gamma ** step_index)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()