onnx模型推理单张图片,网上的教程非常多,我自己以前也写了很多这些内容,但如何推理整个数据集来验证精度呢?
如果你只是为了验证导出的onnx模型精度如何,可以参考这篇文章。
为了保证模型前后处理完全一致,前后处理都直接复用原本的代码,输入输出数据涉及到tensor和numpy转换时直接用torch.from_numpy和.numpy实现。
到嵌入式开发板上跑的话,前后处理都是需要自己写的,而且无法依赖torch。
imagenet 验证集val,内部有1000个文件夹,每个文件夹下对应有50张图片。
PyTorch默认使用PIL读取图片,刚读取时像素顺序为RGB,layout为HWC(组成batch后为NHWC)。经过transforms.ToTensor()后,像素顺序仍为RGB,layout变为NCHW。此外,transforms.ToTensor()还会把像素值除以255,归一化到[0, 1],具体细节可参考另一篇博客《不使用torchvision.transforms 对图片预处理python实现》。
主程序如下,主要修改该代码即可:
import torch
import torch.nn as nn
import sys
import os
import time
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import utils.common as utils # 下面给出代码
from tqdm import tqdm
class Data:
    """ImageNet validation data loader.

    Expects ``data_path/val/<class_name>/*`` in the standard ``ImageFolder``
    layout (1000 class folders, 50 images each).  Preprocessing follows the
    usual torchvision eval pipeline: Resize(256) -> CenterCrop(224) ->
    ToTensor (RGB, NCHW, pixel values scaled to [0, 1]) -> per-channel
    ImageNet mean/std normalization.
    """

    def __init__(self, data_path, batch_size=1, num_workers=2):
        # Args:
        #   data_path: dataset root containing the 'val' directory.
        #   batch_size / num_workers: DataLoader knobs (defaults keep the
        #   original single-image behavior for ONNX validation).
        crop_size = 224
        valdir = os.path.join(data_path, 'val')
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        testset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(crop_size),
                # NOTE: the original pipeline had an extra Resize(224) here;
                # it was a no-op after CenterCrop(224) and has been removed.
                transforms.ToTensor(),
                normalize,
            ]))
        # shuffle=False keeps evaluation order deterministic.
        self.loader_test = DataLoader(
            testset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True)
def test_onnxruntime(ort_session, testLoader, logger, topk=(1, 5)):
    """Run the ONNX model over the whole loader and log top-1/top-5 accuracy.

    Args:
        ort_session: an onnxruntime.InferenceSession for the exported model.
        testLoader: DataLoader yielding (images, labels) tensor batches.
        logger: logger used for the final accuracy summary.
        topk: ks passed to utils.accuracy.  Bug fix: the old default (1,)
            made `predicted[1]` below raise IndexError; since this function
            unconditionally reports top-1 AND top-5, the default is now (1, 5).

    Returns:
        (top5_avg, top1_avg) running averages over the whole set.
    """
    accuracy = utils.AverageMeter('Acc@1', ':6.2f')
    top5_accuracy = utils.AverageMeter('Acc@5', ':6.2f')
    start_time = time.time()
    testLoader = tqdm(testLoader, file=sys.stdout)
    # no autograd needed: inference happens inside onnxruntime
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testLoader):
            batch_size = inputs.size(0)
            # onnxruntime consumes numpy arrays; "input1" must match the
            # input name used when the model was exported.
            ort_inputs = {"input1": inputs.numpy()}
            outputs = ort_session.run(None, ort_inputs)
            # back to torch so utils.accuracy can be reused unchanged
            outputs = torch.from_numpy(outputs[0])
            predicted = utils.accuracy(outputs, targets, topk=topk)
            accuracy.update(predicted[0], batch_size)
            top5_accuracy.update(predicted[1], batch_size)
    current_time = time.time()
    logger.info(
        'Test Top1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'
        .format(float(accuracy.avg), float(top5_accuracy.avg), (current_time - start_time))
    )
    return top5_accuracy.avg, accuracy.avg
def onnx_inference_imagenet():
    """Evaluate an exported ONNX classifier on the ImageNet validation split."""
    job_dir = './experiment'
    # Bug fix: os.path.join(job_dir + 'logger.log') concatenated the two
    # strings into './experimentlogger.log'.  Join them as path components,
    # and create the directory so logging.FileHandler can open the file.
    os.makedirs(job_dir, exist_ok=True)
    logger = utils.get_logger(os.path.join(job_dir, 'logger.log'))
    # Data
    print('==> Preparing data..')
    data_path = '/home/users/dataset/imagenet/'
    # data_path = '/data/horizon_j5/data/imagenet/'
    loader = Data(data_path)
    testLoader = loader.loader_test
    onnx_path = "./weights/resnet50/resnet50_pruned.onnx"
    # ---------------------------------------------------------------------
    # Build the onnxruntime session (imported lazily so the rest of the
    # script works on machines without onnxruntime installed).
    # ---------------------------------------------------------------------
    import onnxruntime
    ort_session = onnxruntime.InferenceSession(onnx_path)
    # ---------------------------------------------------------------------
    # Run the full-dataset evaluation.
    # ---------------------------------------------------------------------
    test_onnxruntime(ort_session, testLoader, logger, topk=(1, 5))


if __name__ == '__main__':
    onnx_inference_imagenet()
在utils文件夹下有common.py文件,其中代码如下:
import os
import sys
import shutil
import time, datetime
import logging
import numpy as np
from PIL import Image
from pathlib import Path
import torch
import torch.nn as nn
import torch.utils
'''record configurations'''
'''record configurations'''
class record_config():
    """Write the run's argparse configuration to ``<job_dir>/config.txt``.

    On a fresh run the file is overwritten; when resuming (``args.resume``)
    the new configuration is appended, preserving the history of runs.
    """

    def __init__(self, args):
        # Args:
        #   args: an argparse.Namespace (or anything vars() accepts) that
        #         has at least .job_dir and .resume attributes.
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
        self.args = args
        self.job_dir = Path(args.job_dir)
        os.makedirs(self.job_dir, exist_ok=True)

        config_dir = self.job_dir / 'config.txt'
        # Bug fix (deduplication): the original if/else duplicated the whole
        # write loop; only the open mode differed between the two branches.
        mode = 'a' if args.resume else 'w'
        with open(config_dir, mode) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')
def get_logger(file_path):
    """Return the shared 'gal' logger, writing to ``file_path`` and stderr.

    Bug fix: logging.getLogger('gal') returns the SAME logger object on
    every call, so unconditionally adding handlers made repeated calls
    stack duplicates and log every message multiple times.  Handlers are
    now attached only once.
    """
    logger = logging.getLogger('gal')
    if not logger.handlers:
        log_format = '%(asctime)s | %(message)s'
        formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
        file_handler = logging.FileHandler(file_path)
        file_handler.setFormatter(formatter)
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)
    return logger
#label smooth
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with label smoothing.

    The one-hot target is mixed with a uniform distribution:
    (1 - epsilon) on the true class plus epsilon / num_classes everywhere,
    then the usual negative log-likelihood is averaged over the batch.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # one-hot encode the integer class labels
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        # blend with the uniform distribution (label smoothing)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        return (-smoothed * log_probs).mean(dim=0).sum()
class AverageMeter(object):
    """Tracks the most recent value, running sum, count, and mean of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # zero out all accumulated statistics
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # fold in `n` observations with (mean) value `val`
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Prints a '[batch/total]' header followed by each meter's summary line."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # width of the counter field, e.g. 250 -> 3 digits -> '{:3d}'
        width = len(str(num_batches // 1))
        field = '{:' + str(width) + 'd}'
        return '[' + field + '/' + field.format(num_batches) + ']'
def save_checkpoint(state, is_best, save):
    """Serialize `state` to <save>/checkpoint.pth.tar.

    When `is_best` is true the checkpoint is additionally copied to
    <save>/model_best.pth.tar.  The target directory is created if missing.
    """
    os.makedirs(save, exist_ok=True)
    checkpoint_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(save, 'model_best.pth.tar'))
def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    decayed = args.lr * 0.1 ** (epoch // 30)
    for group in optimizer.param_groups:
        group['lr'] = decayed
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages (one tensor per k in `topk`).

    `output` is (batch, classes) scores, `target` is (batch,) class indices.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch = target.size(0)
        # (batch, maxk) indices of the highest-scoring classes, best first
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # hits[i, j] is True when sample j's i-th guess equals its label
        hits = pred.eq(target.view(1, -1).expand_as(pred))
        results = []
        for k in topk:
            k_correct = hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            results.append(k_correct.mul_(100.0 / batch))
        return results
def progress_bar(current, total, msg=None):
    """Render an in-place text progress bar on stdout.

    Args:
        current: zero-based index of the batch just finished.
        total: total number of batches.
        msg: optional extra text appended after the Step/Tot timings.

    Fixes vs. the original:
    - terminal width comes from shutil.get_terminal_size(); the old
      os.popen('stty size') crashed when stdout was not a tty (e.g. piped);
    - begin/last timestamps persist across calls via function attributes;
      before they were locals reset on every call, so Step/Tot always
      printed ~0s.
    """
    term_width = shutil.get_terminal_size().columns
    TOTAL_BAR_LENGTH = 65.
    now = time.time()
    if current == 0 or not hasattr(progress_bar, '_begin_time'):
        progress_bar._begin_time = now  # reset for a new bar
        progress_bar._last_time = now
    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    sys.stdout.write(' [')
    sys.stdout.write('=' * cur_len)
    sys.stdout.write('>')
    sys.stdout.write('.' * rest_len)
    sys.stdout.write(']')
    step_time = now - progress_bar._last_time
    progress_bar._last_time = now
    tot_time = now - progress_bar._begin_time
    parts = [' Step: %s' % format_time(step_time),
             ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)
    msg = ''.join(parts)
    sys.stdout.write(msg)
    # pad to the right edge ('' when the padding count is negative)
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))
    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    sys.stdout.write(' %d/%d ' % (current + 1, total))
    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
def format_time(seconds):
    """Format a duration in seconds as at most two units.

    Examples: 3661 -> '1h1m', 1.5 -> '1s500ms', 0 -> '0ms'.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining -= days * 3600 * 24
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_seconds = int(remaining)
    millis = int((remaining - whole_seconds) * 1000)

    # keep only the two most significant non-zero units
    pieces = []
    for value, unit in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                        (whole_seconds, 's'), (millis, 'ms')):
        if value > 0 and len(pieces) < 2:
            pieces.append('%d%s' % (value, unit))
    return ''.join(pieces) if pieces else '0ms'