YOLOv7: stop training automatically (set patience) and output the PR curve of the best model (testing best.pt)

Step 1: Add the following class to torch_utils.py in the utils folder

class EarlyStopping:
    # YOLOv5 simple early stopper
    def __init__(self, patience=30):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            logger.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return stop
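
To see how the patience counter behaves, here is a minimal standalone sketch. It repeats the class logic so it can run outside the repo (the log message is shortened), and the fitness values are made up for illustration; they do not come from a real training run.

```python
import logging

logger = logging.getLogger(__name__)


class EarlyStopping:
    # Same logic as the class in Step 1, repeated so this sketch runs on its own
    def __init__(self, patience=30):
        self.best_fitness = 0.0
        self.best_epoch = 0
        self.patience = patience or float('inf')  # patience=0 disables stopping
        self.possible_stop = False

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)
        stop = delta >= self.patience
        if stop:
            logger.info(f'Stopping early, best epoch was {self.best_epoch}.')
        return stop


# Invented fitness curve: improves until epoch 2, then never beats 0.50 again
fake_fitness = [0.10, 0.30, 0.50, 0.48, 0.49, 0.47, 0.46]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, fi in enumerate(fake_fitness):
    if stopper(epoch=epoch, fitness=fi):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch, stopper.best_fitness)  # 5 2 0.5
```

With patience=3 the best epoch is 2, so the stop fires at epoch 5, the third consecutive epoch without improvement. Note that the comparison is `>=`, so an epoch that only ties the best fitness also resets the counter.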

Step 2: Add the import near the top of train.py

from utils.torch_utils import EarlyStopping

Step 3: In train.py, at around line 300, in the "Start training" section, add the following two lines below the scaler variable

    stopper: EarlyStopping
    stopper, stop = EarlyStopping(patience=opt.patience), False

Step 4: In train.py, at around line 450, add the following code above # Save model (and delete the original # Update best mAP section)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
            if fi > best_fitness:
                best_fitness = fi
            wandb_logger.end_epoch(best_result=best_fitness == fi)
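
As a rough illustration of the value that gets passed to the stopper, the sketch below computes a fitness score from a results tuple. The weights [0.0, 0.0, 0.1, 0.9] are an assumption meant to mirror the fitness function in the repo's utils/metrics.py, so check them against your copy; the numbers in the results tuple are invented.

```python
import numpy as np


def fitness(x):
    # Assumed weights for [P, R, mAP@.5, mAP@.5-.95]; verify against utils/metrics.py
    w = [0.0, 0.0, 0.1, 0.9]
    return (x[:, :4] * w).sum(1)


# Invented per-epoch results: P, R, mAP@.5, mAP@.5-.95, then three loss terms
results = (0.80, 0.70, 0.60, 0.40, 0.05, 0.03, 0.01)

# reshape(1, -1) turns the flat tuple into one row so the column slice [:, :4] works
fi = fitness(np.array(results).reshape(1, -1))

print(fi.shape, round(float(fi[0]), 4))  # (1,) 0.42
```

Under these assumed weights, precision and recall do not affect early stopping at all; the score is dominated by mAP@.5-.95.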

Step 5: In train.py, at around line 490, add the following code above # end epoch

        # EarlyStopping
        if rank != -1:  # if DDP training
            broadcast_list = [stop if rank == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            if rank != 0:
                stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks

Step 6: In train.py, at around line 580, add one line alongside the other parser.add_argument calls

    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
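
The following small sketch shows how the flag value turns into the effective patience, including the `--patience 0` case that disables early stopping. The helper name `effective_patience` is hypothetical and exists only for this illustration; in the actual setup the `patience or float('inf')` expression lives inside the EarlyStopping constructor.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--patience', type=int, default=100,
                    help='EarlyStopping patience (epochs without improvement)')


def effective_patience(argv):
    # Hypothetical helper: parse the flag, then apply the same
    # `patience or float('inf')` fallback used in EarlyStopping.__init__
    opt = parser.parse_args(argv)
    return opt.patience or float('inf')


print(effective_patience([]))                    # 100
print(effective_patience(['--patience', '30']))  # 30
print(effective_patience(['--patience', '0']))   # inf
```

Because 0 is falsy in Python, `--patience 0` becomes an infinite patience, so the stop condition can never be met.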

Step 7: In train.py, at around line 425, change the original # Calculate mAP section to

            if not opt.notest or final_epoch:  # Calculate mAP
                wandb_logger.current_epoch = epoch + 1
                results, maps, times = test.test(data_dict,
                                                 batch_size=batch_size * 2,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 verbose=nc < 50,
                                                 plots=False,
                                                 wandb_logger=wandb_logger,
                                                 compute_loss=compute_loss,
                                                 is_coco=is_coco)

Step 8: In train.py, at around line 500, change the original # Test best.pt section to

        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        results, _, _ = test.test(opt.data,
                                  batch_size=batch_size * 2,
                                  imgsz=imgsz_test,
                                  conf_thres=0.001,
                                  iou_thres=0.7,
                                  model=attempt_load(best, device).half(),
                                  single_cls=opt.single_cls,
                                  dataloader=testloader,
                                  verbose=nc < 50,
                                  save_dir=save_dir,
                                  save_json=True,
                                  plots=plots,
                                  wandb_logger=wandb_logger,
                                  compute_loss=compute_loss,
                                  is_coco=is_coco)

Step 9: In test.py (note the change of file!), at around line 293, change the default value of iou-thres to 0.7

    parser.add_argument('--iou-thres', type=float, default=0.7, help='IOU threshold for NMS')

At line 26, change the value of iou_thres to 0.7

         iou_thres=0.7,  # for NMS
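
To give a feel for what the NMS IoU threshold changes, here is a minimal greedy NMS in pure Python. It is a simplified sketch, not the repo's non_max_suppression; the helper names `iou` and `greedy_nms` and the boxes are invented for the example.

```python
def iou(a, b):
    # Boxes are (x1, y1, x2, y2); returns intersection over union
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def greedy_nms(boxes, scores, iou_thres):
    # Visit boxes from highest to lowest score; keep a box only if it does not
    # overlap any already-kept box by more than iou_thres
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_thres for j in keep):
            keep.append(i)
    return keep


# Boxes 0 and 1 overlap with IoU = 80 / 120, about 0.667; box 2 is far away
boxes = [(0, 0, 10, 10), (2, 0, 12, 10), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]

print(greedy_nms(boxes, scores, 0.6))  # [0, 2]
print(greedy_nms(boxes, scores, 0.7))  # [0, 1, 2]
```

A higher threshold keeps more overlapping detections, which is why the same model can produce different PR curves under different iou_thres values. Setting 0.7 in both places in test.py and in the Step 8 call keeps the evaluations consistent with each other.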

Then run training as usual, and you're done.
