demo.py 是 YOLOX 的推理脚本,完整演示了如何从头到尾对图像进行推理。
首先看 make_parser() 方法,借此熟悉一下命令行参数解析器:
def make_parser():
    """Build the command-line parser for the YOLOX demo script.

    ``argparse.ArgumentParser`` creates the parser; each ``add_argument`` call
    registers one option (short name, optional long name, ``default``,
    ``help`` text and ``type``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("YOLOX Demo!")
    # Demo input type: image, video or webcam. nargs="?" makes this positional
    # optional so the declared default ("image") is actually used -- a required
    # positional argument never falls back to its default.
    parser.add_argument(
        "demo", nargs="?", default="image", help="demo type, eg. image, video and webcam"
    )
    # Experiment name; main() falls back to exp.exp_name when it is not given.
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    # Model name, e.g. yolox-s / yolox-m / yolox-l / yolox-x.
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    # Path to the input image(s) or video.
    parser.add_argument("--path", default="./assets/dog.jpg", help="path to images or video")
    # Index of the webcam to open for the webcam demo.
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    # Whether to save the visualized results.
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    # exp file: experiment description file (takes precedence over --name).
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your expriment description file",
    )
    # Path to the checkpoint (weights) file.
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    # Device to run the model on: "cpu" or "gpu".
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    # Confidence threshold; overrides exp.test_conf when given.
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    # NMS IoU threshold; overrides exp.nmsthre when given.
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    # Square network input size. It should be a multiple of 32, because YOLOX
    # downsamples by up to 32x.
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    # Mixed-precision (fp16) evaluation flag; fp32 is the usual default.
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    # Fuse conv and batch-norm layers to speed up inference.
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    # Use a TensorRT-converted model to speed up inference.
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser
运行 demo.py 时,首先会执行下面两行代码:
# Parse the command-line arguments.
args = make_parser().parse_args()
# Resolve the experiment object: from -f/--exp_file if given, otherwise from -n/--name.
exp = get_exp(args.exp_file, args.name)
get_exp() 方法位于 yolox\exp\build.py 文件中。它首先执行断言:如果 exp_file 和 exp_name 两个参数都为空就会报错,因为程序无法知道要加载哪个实验。若提供了 exp_file,则直接调用 get_exp_by_file();否则调用 get_exp_by_name()。这里以按名称加载为例:
def get_exp(exp_file, exp_name):
    """Get an Exp object by file or by name.

    If exp_file and exp_name are both provided, the Exp is taken from
    exp_file.

    Args:
        exp_file (str): file path of the experiment description.
        exp_name (str): name of a built-in experiment, e.g. "yolox-s".

    Returns:
        The experiment object produced by ``get_exp_by_file`` or
        ``get_exp_by_name``.

    Raises:
        AssertionError: if neither exp_file nor exp_name is given.
    """
    assert (exp_file is not None or exp_name is not None), "plz provide exp file or exp name."
    # An explicit file wins over a model name.
    if exp_file is not None:
        return get_exp_by_file(exp_file)
    else:
        return get_exp_by_name(exp_name)
接下来看同一文件(build.py)中的 get_exp_by_name() 方法:
def get_exp_by_name(exp_name):
    """Return the Exp object for a built-in model name such as "yolox-s".

    The name is mapped to a file under ``<yolox root>/exps/default/``, and
    that file is loaded with ``get_exp_by_file``.

    Raises:
        KeyError: if ``exp_name`` is not one of the built-in names; the
            message lists the accepted names.
    """
    # Built-in model name -> experiment file in exps/default/.
    filedict = {
        "yolox-s": "yolox_s.py",
        "yolox-m": "yolox_m.py",
        "yolox-l": "yolox_l.py",
        "yolox-x": "yolox_x.py",
        "yolox-tiny": "yolox_tiny.py",
        "yolox-nano": "nano.py",
        "yolov3": "yolov3.py",
    }
    # Validate the name before importing anything, so a typo yields a clear error.
    try:
        filename = filedict[exp_name]
    except KeyError:
        raise KeyError(
            "unknown exp name {!r}; expected one of: {}".format(
                exp_name, ", ".join(filedict)
            )
        ) from None

    import yolox

    # yolox.__file__ is the package's __init__.py; two dirname() calls give the
    # directory that contains the yolox package (the repository root).
    yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
    # <root>/exps/default/<filename>
    exp_path = os.path.join(yolox_path, "exps", "default", filename)
    return get_exp_by_file(exp_path)
get_exp_by_name() 最终会调用 get_exp_by_file() 方法:
def get_exp_by_file(exp_file):
    """Import the Python file ``exp_file`` and return an instance of its ``Exp`` class.

    The file's directory is added to ``sys.path`` (once), so the file can be
    imported by its base name, e.g. ``yolox_s`` for ``.../yolox_s.py``.

    Raises:
        ImportError: if the file cannot be imported, has no ``Exp`` class, or
            ``Exp()`` fails. The original exception is chained as
            ``__cause__`` so the real reason stays visible.
    """
    try:
        exp_dir = os.path.dirname(exp_file)
        # Avoid growing sys.path with the same directory on every call.
        if exp_dir not in sys.path:
            sys.path.append(exp_dir)
        module_name = os.path.basename(exp_file).split(".")[0]
        current_exp = importlib.import_module(module_name)
        exp = current_exp.Exp()
    except Exception as e:
        raise ImportError("{} doesn't contains class named 'Exp'".format(exp_file)) from e
    return exp
以 yolox-s 为例,对应的文件是 exps\default\yolox_s.py,其父类 Exp 定义在 yolox\exp\yolox_base.py 中:
import os
from yolox.exp import Exp as MyExp
class Exp(MyExp):
    """YOLOX-S experiment: the base experiment with smaller depth/width factors."""

    def __init__(self):
        super(Exp, self).__init__()
        # Depth factor (adjusts model depth).
        self.depth = 0.33
        # Width factor (adjusts model width).
        self.width = 0.50
        # os.path.realpath(__file__) is this module's absolute path; the name is
        # the file name without its extension, i.e. "yolox_s" for yolox_s.py.
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
接着执行 demo.py 的 main() 方法。整个过程并不复杂,最关键的是其中创建的 Predictor 对象:
def main(exp, args):
    """Prepare output folders, build and load the model, and create a Predictor.

    Args:
        exp: experiment object from ``get_exp``. Reads ``exp_name``,
            ``output_dir`` and ``test_size``, calls ``get_model()``, and may
            overwrite ``test_conf``, ``nmsthre`` and ``test_size``.
        args: namespace from ``make_parser().parse_args()``. May be modified
            in place (``experiment_name``, ``device``).

    Note:
        ``predictor`` is created but not used in the code shown here;
        presumably the full script continues with the image/video demo.
    """
    # Default the experiment name to the one defined by the exp file (e.g. "yolox_s").
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    # Results go under <exp.output_dir>/<experiment_name>; create it if missing.
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    if args.save_result:
        # Sub-folder for visualized results (not used in the code shown here).
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    # TensorRT mode forces the GPU device.
    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    # Command-line thresholds and input size override the experiment defaults.
    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    # Build the network described by the experiment.
    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    # Switch to evaluation mode (changes layers such as BatchNorm/Dropout).
    # This does not disable gradient tracking; inference uses torch.no_grad()
    # for that.
    model.eval()

    if not args.trt:
        # Use the given checkpoint, or fall back to best_ckpt.pth in the output folder.
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        # Read the checkpoint into CPU memory first.
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        # The head returns undecoded outputs; decoding is presumably applied
        # later through the `decoder` passed to the Predictor.
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device)
进入 Predictor 类,只需重点理解下面两个方法(inference 和 visual),其余部分可以先不管:
def inference(self, img):
    """Run the detector on one image.

    Args:
        img: a file path (str), read with ``cv2.imread``, or an already-loaded
            image array (assumed to be an OpenCV-style HxWxC array -- confirm).

    Returns:
        tuple: ``(outputs, img_info)``.

        * ``outputs`` is the result of ``postprocess``. Given the thresholds
          passed to it, this is presumably confidence-filtered, NMS-ed detections.
        * ``img_info`` is a dict with keys "id", "file_name", "height",
          "width", "raw_img" and "ratio".

    Raises:
        FileNotFoundError: if ``img`` is a path that ``cv2.imread`` cannot read.
    """
    img_info = {"id": 0}
    if isinstance(img, str):
        img_info["file_name"] = os.path.basename(img)
        image_path = img
        img = cv2.imread(image_path)
        # cv2.imread returns None instead of raising on a missing/unreadable file,
        # which would otherwise surface as a cryptic error on img.shape below.
        if img is None:
            raise FileNotFoundError("Cannot read image: {}".format(image_path))
    else:
        img_info["file_name"] = None

    # Original image size and pixels, kept for drawing results later.
    height, width = img.shape[:2]
    img_info["height"] = height
    img_info["width"] = width
    img_info["raw_img"] = img

    # Resize to self.test_size and normalize. Per the original notes, the resize
    # keeps the aspect ratio and pads with gray. `ratio` is the resize scale,
    # needed to map boxes back to the original image.
    img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
    img_info["ratio"] = ratio

    # numpy array -> tensor, with a leading batch dimension of 1.
    img = torch.from_numpy(img).unsqueeze(0)
    if self.device == "gpu":
        img = img.cuda()

    # No gradients are needed at test time.
    with torch.no_grad():
        t0 = time.time()
        # Raw predictions (the original notes mention a few thousand candidates).
        outputs = self.model(img)
        # main() passes a decoder only in TensorRT mode.
        if self.decoder is not None:
            outputs = self.decoder(outputs, dtype=outputs.type())
        # Filter with the confidence threshold and apply NMS.
        outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
        logger.info("Infer time: {:.4f}s".format(time.time() - t0))
    return outputs, img_info
def visual(self, output, img_info, cls_conf=0.35):
    """Draw the detections of one image onto the original image.

    Args:
        output: detections for one image (one row per box), or None.
            Columns used: 0-3 box, 4 and 5 score factors, 6 class.
        img_info: dict from ``inference``; "ratio" and "raw_img" are read.
        cls_conf: score threshold forwarded to ``vis`` (which presumably skips
            boxes scoring below it).

    Returns:
        The image returned by ``vis``, or the raw image when ``output`` is None.
    """
    ratio = img_info["ratio"]
    img = img_info["raw_img"]
    # No detections: nothing to draw.
    if output is None:
        return img
    output = output.cpu()

    # Map boxes from network-input coordinates back to the original image.
    # Divide out of place: .cpu() returns the very same tensor when it is
    # already on the CPU, so an in-place "/=" on this slice would rescale the
    # caller's detections (and compound on repeated calls).
    bboxes = output[:, 0:4] / ratio
    cls = output[:, 6]
    # Score = column 4 * column 5 (object conf x class conf in YOLOX's
    # postprocess layout -- confirm).
    scores = output[:, 4] * output[:, 5]
    return vis(img, bboxes, scores, cls, cls_conf, self.cls_names)