YOLOv5 code walkthrough: train.py

YOLOv5 code walkthrough

  • Preface
  • The train() function
  • Summary


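Preface

The previous post (YOLO v5 code walkthrough plus hands-on training and testing) gave an overview of the preparation work and the training parameters for yolov5. This post goes through train.py in the project in detail, as a study aid.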


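The train() function

train.py is fairly long, so to keep this post readable only some of its core parts are explained here. For the complete code, see the net-disk link (extraction code: wbqu).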

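① Hyperparameter and training-option configuration

# Get the epochs, batch size, total batch size (relevant for distributed training), weights, and process rank (mainly used for distributed training)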
epochs, batch_size, total_batch_size, weights, rank = \
  opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

# Save run settings
# hyp: hyperparameters; opt: training options
with open(log_dir / 'hyp.yaml', 'w') as f:
    yaml.dump(hyp, f, sort_keys=False)
with open(log_dir / 'opt.yaml', 'w') as f:
    yaml.dump(vars(opt), f, sort_keys=False)

# Configure
cuda = device.type != 'cpu'
init_seeds(2 + rank)  # set the random seed

# Load the dataset configuration
with open(opt.data) as f:
    data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
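
The init_seeds(2 + rank) call above offsets the seed by the process rank, so each process in a distributed run gets its own random stream while any single process stays reproducible. Below is a minimal sketch of that idea using only Python's random module. The helper name seeded_draws is hypothetical, and the project's init_seeds is not shown in this post, so this is not its implementation.

```python
import random


def seeded_draws(rank, base_seed=2, n=3):
    """Return n pseudo-random integers from a generator seeded with base_seed + rank."""
    rng = random.Random(base_seed + rank)  # independent generator, seed offset by rank
    return [rng.randint(0, 999) for _ in range(n)]


if __name__ == "__main__":
    # Assumption: rank is -1 for a single-process run and 0..N-1 for distributed processes.
    for rank in (-1, 0, 1):
        print(rank, seeded_draws(rank))
```

Running the same rank twice gives the same numbers, which is the property a fixed seed is meant to provide.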

② Read the number of classes and the class names from the yaml config file; if opt.single_cls is set, nc = 1

nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names'])  # number classes, names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
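
The one-liner above can be easier to follow when unpacked into a small function. The sketch below mirrors the same logic on a plain dict; the name resolve_classes is hypothetical, and the assert is replaced by an exception purely for illustration.

```python
def resolve_classes(data_dict, single_cls=False):
    """Return (nc, names) following the logic of the excerpt above."""
    if single_cls:
        # Single-class mode ignores the yaml contents entirely.
        nc, names = 1, ['item']
    else:
        nc, names = int(data_dict['nc']), data_dict['names']
    if len(names) != nc:
        raise ValueError(f'{len(names)} names found for nc={nc}')
    return nc, names


if __name__ == "__main__":
    data = {'nc': 2, 'names': ['cat', 'dog']}
    print(resolve_classes(data))                   # both classes kept
    print(resolve_classes(data, single_cls=True))  # collapsed to one class
```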

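③ Model creation

# The model can be built from one of two sources: opt.cfg or ckpt['model'].yaml. Which one is used depends on whether training is resumed (restarted after an interruption): on resume, opt.cfg is set to empty, so the model is built from ckpt['model'].yaml.
# This also affects the anchor loading below. When opt.cfg is given, keys containing 'anchor' are excluded from the checkpoint weights, so anchors defined in the cfg are not overwritten; on resume (opt.cfg empty) nothing is excluded and the anchors come from the checkpoint.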
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)  # create
exclude = ['anchor'] if opt.cfg else []  # exclude keys
state_dict = ckpt['model'].float().state_dict()  # to FP32
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
model.load_state_dict(state_dict, strict=False)  # load
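
To make the effect of exclude concrete, here is a minimal sketch of the kind of filtering intersect_dicts performs: keep a checkpoint entry only if its key exists in the new model, matches no excluded substring, and has the same shape. The body of intersect_dicts is not shown in this post, so the function below is an assumption about its behavior. FakeTensor is a hypothetical stand-in for a real tensor, and the key names are made up.

```python
from collections import namedtuple

# Stand-in for a tensor: only the shape matters for this sketch.
FakeTensor = namedtuple('FakeTensor', ['shape'])


def intersect_dicts_sketch(da, db, exclude=()):
    """Keep entries of da whose key is in db, matches no excluded substring, and has the same shape."""
    return {k: v for k, v in da.items()
            if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


ckpt_state = {
    'conv.weight': FakeTensor((16, 3, 3, 3)),
    'detect.anchors': FakeTensor((3, 3, 2)),
    'head.weight': FakeTensor((255, 128)),  # e.g. a checkpoint trained with 80 classes
}
model_state = {
    'conv.weight': FakeTensor((16, 3, 3, 3)),
    'detect.anchors': FakeTensor((3, 3, 2)),
    'head.weight': FakeTensor((18, 128)),   # e.g. a new model with 1 class
}

with_cfg = intersect_dicts_sketch(ckpt_state, model_state, exclude=['anchor'])
resumed = intersect_dicts_sketch(ckpt_state, model_state, exclude=[])
print(sorted(with_cfg))  # anchors skipped, mismatched head skipped
print(sorted(resumed))   # anchors kept, mismatched head still skipped
```

Because load_state_dict is called with strict=False, the entries that were filtered out simply keep the values the freshly created model already has.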

④ Getting the number of warmup iterations

# Number of warmup iterations
nw = max(3 * nb, 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
maps = np.zeros(nc)  # mAP per class; initializes maps (and results on the next line)
results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
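
The excerpt only computes nw; how it is consumed is not shown here. A minimal sketch, assuming the usual pattern of ramping the learning rate linearly over the first nw iterations, could look like this. The function names and the linear ramp are assumptions, not code from train.py.

```python
def warmup_iterations(nb, warmup_epochs=3, min_iters=1000):
    """Warmup length: warmup_epochs worth of batches, but at least min_iters iterations."""
    return max(warmup_epochs * nb, min_iters)


def warmup_lr(ni, nw, target_lr, start_lr=0.0):
    """Linearly ramp from start_lr to target_lr over the first nw iterations (assumed schedule)."""
    if ni >= nw:
        return target_lr
    return start_lr + (target_lr - start_lr) * ni / nw


if __name__ == "__main__":
    nb = 100                    # batches per epoch (hypothetical value)
    nw = warmup_iterations(nb)  # small dataset: the 1000-iteration floor applies
    for ni in (0, nw // 2, nw):
        print(ni, warmup_lr(ni, nw, target_lr=0.01))
```

With few batches per epoch the 1000-iteration floor dominates; with many batches per epoch the three-epoch term does.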

⑤ Multi-scale training

# Multi-scale
# Multi-scale training: pick a random size between imgsz * 0.5 and imgsz * 1.5 + gs, rounded down to a multiple of gs
if opt.multi_scale:
    sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
    sf = sz / max(imgs.shape[2:])  # scale factor
    if sf != 1:
        ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
        imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
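
The size arithmetic in the multi-scale step can be tried out without any tensors. The sketch below reproduces only the shape computation. The function name pick_multiscale_shape is hypothetical, and the bounds are converted to int as an assumption, because newer Python versions reject non-integer arguments to random.randrange.

```python
import math
import random


def pick_multiscale_shape(imgsz, gs, height, width, rng=random):
    """Pick a random gs-multiple target size and the resized (h, w) for one batch."""
    # Integer bounds (assumption; the excerpt passes floats to randrange).
    low, high = int(imgsz * 0.5), int(imgsz * 1.5) + gs
    sz = rng.randrange(low, high) // gs * gs  # random size, rounded down to a multiple of gs
    sf = sz / max(height, width)              # scale factor relative to the longer side
    if sf == 1:
        return sz, (height, width)            # nothing to resize
    new_shape = tuple(math.ceil(x * sf / gs) * gs for x in (height, width))
    return sz, new_shape


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(3):
        print(pick_multiscale_shape(640, 32, 640, 640, rng))
```

Rounding to a multiple of gs (the grid size, i.e. the model's maximum stride) keeps every feature map dimension an integer after downsampling.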

⑥ Learning-rate decay

# Scheduler: learning-rate decay
lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
scheduler.step()
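
In the snippet above, lr is read before scheduler.step() so that the logged value is the rate actually used during the epoch that just finished. The excerpt does not show which decay rule the scheduler applies, so the sketch below assumes a cosine decay from the base rate down to a fraction of it, a common choice. The function name and the final_factor value are hypothetical.

```python
import math


def cosine_lr_factor(epoch, epochs, final_factor=0.2):
    """Multiplier that decays from 1.0 at epoch 0 to final_factor at epoch == epochs."""
    cosine = (1 + math.cos(epoch * math.pi / epochs)) / 2  # goes from 1.0 down to 0.0
    return cosine * (1 - final_factor) + final_factor


base_lr = 0.01
epochs = 10
schedule = [base_lr * cosine_lr_factor(e, epochs) for e in range(epochs)]

if __name__ == "__main__":
    for e, lr in enumerate(schedule):
        print(e, round(lr, 5))
```

A function of this shape is what one would pass as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which multiplies each parameter group's initial learning rate by the returned factor.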

⑦ Speeding up inference

if rank in [-1, 0]:
    # Strip optimizers
    """
    After training, strip_optimizer removes the optimizer from the ckpt
    and calls model.half() to convert the model from Float32 to Float16,
    which reduces the model size and speeds up inference.
    """
    n = opt.name if opt.name.isnumeric() else ''
    fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
    for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
        if os.path.exists(f1):
            os.rename(f1, f2)  # rename
            if str(f2).endswith('.pt'):  # is *.pt
                strip_optimizer(f2)  # strip optimizer
                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None  # upload
    # Finish
    if not opt.evolve:
        plot_results(save_dir=log_dir)  # save as results.png
    logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
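
To see why stripping helps, here is a toy stand-in that uses only the standard library. It drops the optimizer state from a dict-shaped checkpoint and packs the weights as 16-bit floats instead of 32-bit. This is not the project's strip_optimizer, which works on a saved PyTorch checkpoint; the dict layout and the name strip_checkpoint_sketch are hypothetical.

```python
import struct


def strip_checkpoint_sketch(ckpt):
    """Return a copy without optimizer state and with weights packed as float16."""
    weights = ckpt['weights']
    return {
        'weights_fp16': struct.pack(f'<{len(weights)}e', *weights),  # 2 bytes per value
        'optimizer': None,                                           # training-only state dropped
    }


ckpt = {
    'weights': [0.5, -1.25, 3.0, 0.125],
    'optimizer': {'momentum_buffer': [0.0, 0.0, 0.0, 0.0]},
}
fp32_size = len(struct.pack(f'<{len(ckpt["weights"])}f', *ckpt['weights']))  # 4 bytes per value
stripped = strip_checkpoint_sketch(ckpt)
fp16_size = len(stripped['weights_fp16'])

if __name__ == "__main__":
    print(fp32_size, fp16_size)  # weight storage is halved
```

The optimizer state is only needed to continue training, so removing it costs nothing at inference time; halving the precision is what trades a little numeric accuracy for size and speed.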

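Summary

Because of network problems, the yolov5lsxm.pt pretrained models at the official address are hard to download, so I have attached Baidu Cloud links for anyone who needs them:
yolov5lsxm.pt pretrained models (extraction code: asox)
Full version of train.py (extraction code: wbqu)

ps: for security reasons these are not permanent links. If a link has expired, please leave a comment, and I will reply as soon as I see it.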
