Yolo系列因为其灵活性,一直是目标检测领域的热门算法。近期旷视的研究者提出了高性能目标检测器Yolox,将Anchor free引入了Yolo算法,是继YOLOv1之后,第二个将Anchor free研究思路用到Yolo的算法。
Yolox的创新点除了Anchor free思想之外,个人认为最重要的是解耦头(Decoupled head)的使用,在一定程度上缓解了分类任务与回归任务之间的冲突问题;当然,除此之外,标签分配策略(SimOTA)也是其一大亮点。
1. 文章链接 :https://arxiv.org/abs/2107.08430;
2. 源码:https://github.com/Megvii-BaseDetection/YOLOX;
文章作者使用COCO数据集格式训练模型,在这里使用另一种格式VOC数据集来训练自己的模型
2.1 VOC数据集格式
Annotations文件夹放置.xml标签文件;JPEGImages文件夹放置训练原图;ImageSets文件夹下的Main文件夹存放训练、验证、测试数据集的.txt文件。 具体格式如下:
test.txt,train.txt,trainval.txt,val.txt文件内容如下图所示,存放的是各个图片的名称。
利用VOC格式训练Yolox模型这些.txt文件是必须的;相应的生成代码如下:
import os
import random

# Folder holding the VOC .xml annotation files; change it to your own dataset path.
xmlfilepath = 'VOC数据集Annotations文件夹路径'
# Folder (ImageSets\Main) that receives test.txt, train.txt, trainval.txt and val.txt.
saveBasePath = "VOC数据集ImageSets\\Main文件夹路径"
# ----------------------------------------------------------------------#
#   Adjust trainval_percent and train_percent to suit your needs.
#   trainval_percent: share of all samples that goes to train + val
#   (the rest goes to test); train_percent: share of train + val that
#   goes to train (the rest goes to val).
# ----------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 1


def split_dataset(xml_dir, save_dir, trainval_ratio, train_ratio, seed=0):
    """Split the annotated samples into trainval/train/val/test name lists.

    Every ``*.xml`` file in ``xml_dir`` counts as one sample. The sample
    names (file names without the ``.xml`` suffix) are written, one per
    line, to ``trainval.txt``, ``train.txt``, ``val.txt`` and ``test.txt``
    inside ``save_dir``.

    Args:
        xml_dir: Directory containing the ``.xml`` annotation files.
        save_dir: Existing directory that receives the four ``.txt`` files.
        trainval_ratio: Fraction of all samples assigned to train + val.
        train_ratio: Fraction of the train + val samples assigned to train.
        seed: Seed for ``random`` so that the split is reproducible.

    Returns:
        A dict mapping ``"trainval"``, ``"train"``, ``"val"`` and ``"test"``
        to the list of sample names written to the corresponding file.
    """
    random.seed(seed)

    # os.listdir returns entries in arbitrary order; sorting makes the split
    # depend only on the seed and the set of files, not on the file system.
    stems = sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(xml_dir)
        if entry.endswith(".xml")
    )

    total = len(stems)
    num_trainval = int(total * trainval_ratio)
    num_train = int(num_trainval * train_ratio)

    trainval_indices = random.sample(range(total), num_trainval)
    train_indices = random.sample(trainval_indices, num_train)
    # Sets give O(1) membership tests in the loop below.
    trainval_set = set(trainval_indices)
    train_set = set(train_indices)

    print("train and val size", num_trainval)
    print("train size", num_train)

    splits = {"trainval": [], "train": [], "val": [], "test": []}
    for index, stem in enumerate(stems):
        if index in trainval_set:
            splits["trainval"].append(stem)
            splits["train" if index in train_set else "val"].append(stem)
        else:
            splits["test"].append(stem)

    for split_name, names in splits.items():
        # The context manager closes the file even if a write fails.
        with open(os.path.join(save_dir, split_name + ".txt"), "w") as handle:
            handle.writelines(name + "\n" for name in names)

    return splits


if __name__ == "__main__":
    split_dataset(xmlfilepath, saveBasePath, trainval_percent, train_percent)
3.1 修改检测类别
根据自己的需求修改检测类别名称,我做的项目是人脸检测,检测类别只有face,所以VOC_CLASSES内容只写“face”。检测类别名称对应源码文件:D:\YOLOX-main\yolox\data\datasets\voc_classes.py,只需修改该文件中的检测类别;各个类别间必须用逗号隔开,最后一个类别后面也必须加逗号(否则只有一个类别时它不会被当作元组),如下图所示:
3.2 修改训练参数
根据需求修改D:\YOLOX-main\yolox\exp\yolox_base.py文件下的相关参数,在这我只修改了self.num_classes = 1,其他参数可根据自己需求进行修改。
import os
import random
import torch
import torch.distributed as dist
import torch.nn as nn
from .base_exp import BaseExp
class Exp(BaseExp):
    """Base YOLOX experiment: model, data loaders, optimizer and evaluator.

    All hyper-parameters are plain attributes set in ``__init__``; the
    ``get_*`` methods build the corresponding objects from them. The data
    loaders and evaluator built here use the COCO dataset classes.
    """

    def __init__(self):
        super().__init__()

        # ---------------- model config ---------------- #
        # Number of detection classes (1 here: the tutorial detects "face" only).
        self.num_classes = 1
        # depth / width are passed to YOLOPAFPN and YOLOXHead in get_model().
        self.depth = 1.00
        self.width = 1.00

        # ---------------- dataloader config ---------------- #
        # set worker to 4 for shorter dataloader init time
        self.data_num_workers = 4
        self.input_size = (640, 640)
        # Range of the random multiplier drawn in random_resize(); the drawn
        # value is multiplied by 32 to get the new input size.
        self.random_size = (14, 26)
        self.data_dir = None
        self.train_ann = "instances_train2017.json"
        self.val_ann = "instances_val2017.json"

        # --------------- transform config ----------------- #
        # Forwarded unchanged to MosaicDetection in get_data_loader().
        self.degrees = 10.0
        self.translate = 0.1
        self.scale = (0.1, 2)
        self.mscale = (0.8, 1.6)
        self.shear = 2.0
        self.perspective = 0.0
        self.enable_mixup = True

        # --------------  training config --------------------- #
        self.warmup_epochs = 5
        self.max_epoch = 200
        self.warmup_lr = 0
        # Multiplied by the batch size in get_optimizer() when there is no warm-up.
        self.basic_lr_per_img = 0.01 / 64.0
        self.scheduler = "yoloxwarmcos"
        self.no_aug_epochs = 15
        self.min_lr_ratio = 0.05
        self.ema = True

        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.print_interval = 10
        self.eval_interval = 10
        # Experiment name = this file's name without its extension.
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # -----------------  testing config ------------------ #
        self.test_size = (640, 640)
        self.test_conf = 0.01
        self.nmsthre = 0.65

    def get_model(self):
        """Build the YOLOX model on first use, then (re)initialise and return it.

        The model is created only when ``self.model`` is missing or ``None``;
        the BatchNorm settings and head biases are applied on every call.
        """
        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            # Override eps/momentum on every BatchNorm2d layer of the model.
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the training data loader over the COCO training annotations.

        Args:
            batch_size: Total batch size; divided by the world size when
                ``is_distributed`` is true.
            is_distributed: Whether training runs under torch.distributed.
            no_aug: When true, mosaic is disabled in both the dataset wrapper
                and the batch sampler.

        Returns:
            The training ``DataLoader``; the wrapped dataset is also stored
            on ``self.dataset``.
        """
        from yolox.data import (
            COCODataset,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        dataset = COCODataset(
            data_dir=self.data_dir,
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=50,
            ),
        )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=120,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        # self.seed is expected to come from the base class; a falsy seed
        # (None or 0) falls back to 0.
        sampler = InfiniteSampler(
            len(self.dataset), seed=self.seed if self.seed else 0
        )

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def random_resize(self, data_loader, epoch, rank, is_distributed):
        """Pick a new random training input size and apply it to the loader.

        Rank 0 draws a multiplier from ``self.random_size`` and derives a
        (height, width) pair in multiples of 32 that keeps the aspect ratio
        of ``self.input_size``. Under distributed training the pair is
        broadcast from rank 0 so every process uses the same size.

        Returns:
            Whatever ``data_loader.change_input_dim`` returns.
        """
        # Requires CUDA: the size is exchanged through a GPU tensor.
        tensor = torch.LongTensor(2).cuda()

        if rank == 0:
            size_factor = self.input_size[1] * 1. / self.input_size[0]
            size = random.randint(*self.random_size)
            size = (int(32 * size), 32 * int(size * size_factor))
            tensor[0] = size[0]
            tensor[1] = size[1]

        if is_distributed:
            dist.barrier()
            dist.broadcast(tensor, 0)

        input_size = data_loader.change_input_dim(
            multiple=(tensor[0].item(), tensor[1].item()), random_range=None
        )
        return input_size

    def get_optimizer(self, batch_size):
        """Create (once) and return the SGD optimizer for ``self.model``.

        Parameters are split into three groups: BatchNorm weights (no weight
        decay), other weights (with ``self.weight_decay``) and biases (no
        weight decay). ``get_model()`` must have been called beforehand.
        """
        if "optimizer" not in self.__dict__:
            # With warm-up the optimizer starts at warmup_lr; otherwise the
            # learning rate is scaled linearly with the batch size.
            if self.warmup_epochs > 0:
                lr = self.warmup_lr
            else:
                lr = self.basic_lr_per_img * batch_size

            pg0, pg1, pg2 = [], [], []  # optimizer parameter groups

            for k, v in self.model.named_modules():
                if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                    pg2.append(v.bias)  # biases
                if isinstance(v, nn.BatchNorm2d) or "bn" in k:
                    pg0.append(v.weight)  # no decay
                elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                    pg1.append(v.weight)  # apply decay

            optimizer = torch.optim.SGD(
                pg0, lr=lr, momentum=self.momentum, nesterov=True
            )
            optimizer.add_param_group(
                {"params": pg1, "weight_decay": self.weight_decay}
            )  # add pg1 with weight_decay
            optimizer.add_param_group({"params": pg2})
            self.optimizer = optimizer

        return self.optimizer

    def get_lr_scheduler(self, lr, iters_per_epoch):
        """Build the learning-rate scheduler named by ``self.scheduler``."""
        from yolox.utils import LRScheduler

        scheduler = LRScheduler(
            self.scheduler,
            lr,
            iters_per_epoch,
            self.max_epoch,
            warmup_epochs=self.warmup_epochs,
            warmup_lr_start=self.warmup_lr,
            no_aug_epochs=self.no_aug_epochs,
            min_lr_ratio=self.min_lr_ratio,
        )
        return scheduler

    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
        """Build the evaluation data loader over the COCO val (or test-dev) set.

        Args:
            batch_size: Total batch size; divided by the world size when
                ``is_distributed`` is true.
            is_distributed: Whether evaluation runs under torch.distributed.
            testdev: When true, use the test-dev annotation file and the
                ``test2017`` image set instead of the validation set.
        """
        from yolox.data import COCODataset, ValTransform

        valdataset = COCODataset(
            data_dir=self.data_dir,
            json_file=self.val_ann if not testdev else "image_info_test-dev2017.json",
            name="val2017" if not testdev else "test2017",
            img_size=self.test_size,
            preproc=ValTransform(
                rgb_means=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            ),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        """Build a ``COCOEvaluator`` on top of ``get_eval_loader``."""
        from yolox.evaluators import COCOEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
        evaluator = COCOEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )
        return evaluator

    def eval(self, model, evaluator, is_distributed, half=False):
        """Run ``evaluator`` on ``model`` and return its result."""
        return evaluator.evaluate(model, is_distributed, half)
3.3 修改训练数据集路径
Yolox作者准备了VOC数据集文件,D:\YOLOX-main\exps\example\yolox_voc\yolox_voc_s.py,但是需要对其进行修改,这个文件需要修改的有:
(1)self.num_classes = 1;
(2)get_data_loader 下的 VOCDetection 中的 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],将其修改为 image_sets=[('2007', 'train')];
第一点修改根据个人需求决定;第二点是我根据自己的使用习惯进行修改的,**如果准备的VOC数据集里面有trainval.txt文件,那么可以不对源码进行修改。**
==有的人在训练的过程中会出现AP=0的现象,这是因为函数get_eval_loader下的image_sets=[(‘2007’, ‘test’)]下test文件对应的test.txt里面内容为空,在第二步数据准备的时候一定要使得test.txt文件内容存在。==修改代码如下:
# encoding: utf-8
import os
import torch
import torch.distributed as dist
from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp
class Exp(MyExp):
    """YOLOX-s experiment trained and evaluated on a VOC-format dataset.

    Overrides the model size and class count of the base experiment and
    replaces the data loaders and evaluator with their VOC counterparts.
    The dataset is read from ``<yolox datadir>/VOCdevkit``.
    """

    def __init__(self):
        super(Exp, self).__init__()
        # Number of detection classes (1 here: the tutorial detects "face" only).
        self.num_classes = 1
        self.depth = 0.33
        self.width = 0.50
        # Experiment name = this file's name without its extension.
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the training data loader over the VOC2007 ``train`` image set.

        Args:
            batch_size: Total batch size; divided by the world size when
                ``is_distributed`` is true.
            is_distributed: Whether training runs under torch.distributed.
            no_aug: When true, mosaic is disabled in both the dataset wrapper
                and the batch sampler.

        Returns:
            The training ``DataLoader``; the wrapped dataset is also stored
            on ``self.dataset``.
        """
        from yolox.data import (
            VOCDetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        dataset = VOCDetection(
            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
            # Names the image-set file to read (train.txt of VOC2007).
            image_sets=[('2007', 'train')],
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=50,
            ),
        )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=120,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        # A falsy seed (None or 0) falls back to 0.
        sampler = InfiniteSampler(
            len(self.dataset), seed=self.seed if self.seed else 0
        )

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
        """Build the evaluation data loader over the VOC2007 ``test`` image set.

        ``testdev`` is accepted for interface compatibility and is not used.
        The ``test`` image-set file must list at least one image, otherwise
        there is nothing to evaluate.
        """
        from yolox.data import VOCDetection, ValTransform

        valdataset = VOCDetection(
            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
            image_sets=[('2007', 'test')],
            img_size=self.test_size,
            preproc=ValTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        """Build a ``VOCEvaluator`` on top of ``get_eval_loader``."""
        from yolox.evaluators import VOCEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator
3.4 修改Voc.py文件中Annotations文件的读取格式
修改源码 D:\YOLOX-main\yolox\data\datasets\voc.py文件下的_do_python_eval函数,将annopath = os.path.join(rootpath, "Annotations", "{:s}.xml")
修改为:annopath = os.path.join(rootpath, "Annotations", "{}.xml")
有两种训练方式:
(1)直接在train.py文件下修改超参数;
(2)在终端训练,源码中的ReadMe有说明;
选用第一种方式训练,下面代码是我作的修改,在要修改的地方做了注释。
import argparse
import random
import warnings
from loguru import logger
import torch
import torch.backends.cudnn as cudnn
from yolox.core import Trainer, launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, configure_omp
def make_parser():
    """Build the command-line parser for the training script.

    The defaults are set so that the script can be started without any
    arguments on a single-GPU Windows machine; adapt the paths to your setup.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser("YOLOX train parser")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    # Name of the model to train, e.g. yolox-s; other variants such as
    # yolox-l can be given as well.
    parser.add_argument("-n", "--name", type=str, default='yolox-s', help="model name")

    # distributed
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--dist-url",
        default=None,
        type=str,
        help="url used to set up distributed training",
    )
    # Choose the batch size according to the available GPU memory.
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="batch size")
    # This is the NUMBER of GPUs to use (the caller compares it with
    # torch.cuda.device_count()), not a device index: one GPU -> 1.
    parser.add_argument(
        "-d", "--devices", default=1, type=int, help="device for training"
    )
    parser.add_argument(
        "-f",
        "--exp_file",
        # Absolute path of the experiment file describing the training dataset.
        default='D:\\YOLOX-main\\exps\\example\\yolox_voc\\yolox_voc_s.py',
        type=str,
        help="plz input your expriment description file",
    )
    parser.add_argument(
        "--resume", default=False, action="store_true", help="resume training"
    )
    # Pre-trained weights to start from.
    parser.add_argument("-c", "--ckpt", default='D:\\YOLOX-main\\yolox_s.pth', type=str, help="checkpoint file")
    parser.add_argument(
        "-e",
        "--start_epoch",
        default=0,
        type=int,
        help="resume training start epoch",
    )
    parser.add_argument(
        "--num_machines", default=1, type=int, help="num of node for training"
    )
    parser.add_argument(
        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
    )
    # NOTE: with default=True a store_true flag is always on; it cannot be
    # switched off from the command line.
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mix precision training.",
    )
    parser.add_argument(
        "-o",
        "--occupy",
        dest="occupy",
        default=False,
        action="store_true",
        help="occupy GPU memory first for training.",
    )
    # Everything left on the command line is collected into args.opts.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
@logger.catch
def main(exp, args):
    """Seed the RNGs if requested, configure the environment and train.

    Args:
        exp: Experiment object; ``exp.seed`` is read here and the object is
            handed to ``Trainer``.
        args: Parsed command-line arguments, handed to ``Trainer``.
    """
    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # set environment variables for distributed training
    configure_nccl()
    configure_omp()
    cudnn.benchmark = True

    trainer = Trainer(exp, args)
    trainer.train()
if __name__ == "__main__":
    # Parse the arguments, load the experiment and apply command-line overrides.
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    # args.devices is a GPU count; use every visible GPU when it is not given.
    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count()

    dist_url = "auto" if args.dist_url is None else args.dist_url
    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=dist_url,
        args=(exp, args),
    )
yolox-s预训练权重我放在百度网盘:链接:https://pan.baidu.com/s/1kmNzX0FE555P1t9fzBqIJw 提取码:0r71。
使用demo.py文件来检测,其中要注意引用VOC数据集的类别。
修改1
from yolox.data.datasets import voc_classes#在训练的时候使用voc数据集的类别,所以在检测的时候要引入voc文件对应的类
修改2
predictor = Predictor(model, exp, voc_classes.VOC_CLASSES, trt_file, decoder, args.device)
修改3
cls_names=voc_classes.VOC_CLASSES,
修改后代码
import argparse
import os
import time
from loguru import logger
import cv2
import torch
from yolox.data.data_augment import preproc
from yolox.data.datasets import voc_classes
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis
# File extensions (with leading dot) that get_image_list treats as images.
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
def make_parser():
    """Build the command-line parser for the detection demo.

    The defaults are set so that the demo can be started without arguments;
    replace the placeholder paths with your own image and checkpoint paths.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "--demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default='yolox-s', help="model name")

    # Placeholder default: replace with the image/video path to run on.
    parser.add_argument(
        "--path", default="要检测图像的路径", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        # Defaults to True so that results are saved; as a store_true flag
        # it therefore cannot be switched off from the command line.
        default=True,
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default='D:\\YOLOX-main\\exps\\default\\yolox_s.py',
        type=str,
        help="pls input your expriment description file",
    )
    # Placeholder default: replace with the path of the trained checkpoint.
    parser.add_argument("-c", "--ckpt", default='训练好模型的路径', type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    # ------------------------------------------------------------------#
    #   Tune the thresholds and test size below to your own project.
    # ------------------------------------------------------------------#
    parser.add_argument("--conf", default=0.35, type=float, help="test conf")
    parser.add_argument("--nms", default=0.5, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=640, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser
def get_image_list(path):
    """Recursively collect the image files below ``path``.

    A file counts as an image when its extension, compared
    case-insensitively, is listed in ``IMAGE_EXT`` (so ``photo.JPG`` is
    found as well as ``photo.jpg``).

    Args:
        path: Directory to walk.

    Returns:
        A list of file paths in the order ``os.walk`` yields them.
    """
    image_names = []
    for maindir, _, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            # Cameras and Windows tools often write upper-case extensions.
            if ext.lower() in IMAGE_EXT:
                image_names.append(apath)
    return image_names
class Predictor(object):
    """Runs a YOLOX model on single images and draws the detections."""

    def __init__(
        self,
        model,
        exp,
        cls_names=voc_classes.VOC_CLASSES,
        trt_file=None,
        decoder=None,
        device="gpu",
    ):
        """Store the model and the test settings taken from ``exp``.

        Args:
            model: The detection model to run.
            exp: Experiment object providing ``num_classes``, ``test_conf``,
                ``nmsthre`` and ``test_size``.
            cls_names: Class names used when drawing; defaults to the VOC
                class list the model was trained with.
            trt_file: Optional path to a TensorRT state dict; when given,
                the TensorRT module replaces ``model``.
            decoder: Optional callable applied to the raw model outputs.
            device: ``"gpu"`` moves the input to CUDA before inference.
        """
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            # Run the original model once on a dummy input before swapping
            # in the TensorRT module.
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img):
        """Run detection on one image.

        Args:
            img: Either a path to an image file or an already loaded image
                array.

        Returns:
            ``(outputs, img_info)``: the post-processed model outputs and a
            dict with ``id``, ``file_name``, ``height``, ``width``,
            ``raw_img`` and ``ratio``.

        Raises:
            OSError: If ``img`` is a path and the file cannot be read as an
                image.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            image_path = img
            img_info["file_name"] = os.path.basename(image_path)
            img = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise OSError("cannot read image file: {}".format(image_path))
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0)
        if self.device == "gpu":
            img = img.cuda()

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        """Draw the detections of one image onto its original pixels.

        Args:
            output: Detections for one image, or ``None`` when there are
                none. Columns 0-3 are used as the box, 4 and 5 are
                multiplied to form the score and 6 is the class index.
            img_info: The dict returned by ``inference``.
            cls_conf: Confidence threshold passed on to ``vis``.

        Returns:
            The image with detections drawn, or the untouched original
            image when ``output`` is ``None``.
        """
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        # Undo the resize ratio so the boxes match the original image.
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res
def image_demo(predictor, vis_folder, path, current_time, save_result):
    """Run detection on one image or on every image below a directory.

    Args:
        predictor: Object providing ``inference``, ``visual`` and ``confthre``.
        vis_folder: Base folder for saved results; only used when
            ``save_result`` is true.
        path: An image file, or a directory that is searched recursively.
        current_time: ``time.struct_time`` used to name the output sub-folder.
        save_result: Whether to write the annotated images to disk.
    """
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()

    # The output folder is the same for every image, so build its path once.
    save_folder = None
    if save_result:
        save_folder = os.path.join(
            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
        )

    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
        # Esc, q or Q stops the loop early.
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
def imageflow_demo(predictor, vis_folder, current_time, args):
    """Run detection frame by frame on a video file or a webcam stream.

    Args:
        predictor: Object providing ``inference``, ``visual`` and ``confthre``.
        vis_folder: Base folder for the output video.
        current_time: ``time.struct_time`` used to name the output sub-folder.
        args: Parsed arguments; ``demo``, ``path``, ``camid`` and
            ``save_result`` are read.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        # os.path.basename understands the platform's separators. Splitting
        # on "/" leaves a Windows path such as D:\videos\a.mp4 intact, and
        # os.path.join would then return that absolute path, i.e. the
        # output would be written over the input video.
        save_path = os.path.join(save_folder, os.path.basename(args.path))
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    try:
        while True:
            ret_val, frame = cap.read()
            if not ret_val:
                break
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            # Esc, q or Q stops the loop early.
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
    finally:
        # Release the capture device and finalise the output file even when
        # inference raises.
        cap.release()
        vid_writer.release()
def main(exp, args):
    """Load the model described by ``exp`` and run the requested demo.

    Args:
        exp: Experiment object; its test settings are overridden by
            ``args.conf``, ``args.nms`` and ``args.tsize`` when given.
        args: Parsed command-line arguments of the demo.
    """
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    # Always define vis_folder: it is passed to the demo functions below
    # even when results are not saved. The folder itself is only created
    # here when saving is requested.
    vis_folder = os.path.join(file_name, "vis_res")
    if args.save_result:
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        # With TensorRT the head does not decode; decoding is done by the
        # Predictor through the decoder callable instead.
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    # Pass the VOC class names because the model was trained on VOC classes.
    predictor = Predictor(model, exp, voc_classes.VOC_CLASSES, trt_file, decoder, args.device)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)
if __name__ == "__main__":
    # Parse the arguments, load the experiment file and run the demo.
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)