# 批量预测 (batch prediction)
import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model import resnet34
def main():
    """Classify every ``.jpg`` image in a folder, ``batch_size`` images at a time.

    Reads the index -> class-name mapping from ``./class_indices.json``, loads
    ResNet-34 weights from ``./resNet34.pth`` and prints one line per image
    with its path, predicted class name and softmax probability.

    Raises:
        AssertionError: if the image folder, the class-index JSON, the weights
            file, or an individual image path does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Folder whose images will be predicted.
    imgs_root = "/data/imgs"
    assert os.path.exists(imgs_root), f"file: '{imgs_root}' dose not exist."
    # Collect the paths of all .jpg files directly inside the folder.
    img_path_list = [os.path.join(imgs_root, i) for i in os.listdir(imgs_root) if i.endswith(".jpg")]

    # Read the class-index mapping; the context manager closes the file handle.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), f"file: '{json_path}' dose not exist."
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # Create the model.
    model = resnet34(num_classes=5).to(device)

    # Load the model weights.
    weights_path = "./resNet34.pth"
    assert os.path.exists(weights_path), f"file: '{weights_path}' dose not exist."
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # Prediction.
    model.eval()
    batch_size = 8  # number of images packed into one batch per forward pass
    with torch.no_grad():
        # Step through the list in strides of batch_size so that the trailing
        # partial batch is predicted too (floor division by batch_size would
        # drop it, and would skip everything when there are fewer than
        # batch_size images).
        for start in range(0, len(img_path_list), batch_size):
            batch_paths = img_path_list[start:start + batch_size]
            img_list = []
            for img_path in batch_paths:
                assert os.path.exists(img_path), f"file: '{img_path}' dose not exist."
                # Force three channels: Normalize above is configured for
                # 3-channel input, and JPEG files may be grayscale or CMYK.
                img = Image.open(img_path).convert("RGB")
                img_list.append(data_transform(img))

            # Pack all images of this batch into a single tensor.
            batch_img = torch.stack(img_list, dim=0)
            # Predict the classes.
            output = model(batch_img.to(device)).cpu()
            predict = torch.softmax(output, dim=1)
            probs, classes = torch.max(predict, dim=1)

            for img_path, pro, cla in zip(batch_paths, probs, classes):
                print("image: {} class: {} prob: {:.3}".format(img_path,
                                                                 class_indict[str(cla.item())],
                                                                 pro.item()))


if __name__ == '__main__':
    main()
#批量预测,但只显示五张图片
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet50
def plot_class_preds(net,
                     images_dir: str,
                     transform,
                     num_plot: int = 5,
                     device="cpu"):
    """Predict a few labelled images with ``net`` and plot them in one row.

    ``images_dir`` must contain a ``label.txt`` file whose non-empty lines
    have the form ``file_name class_name``. Class names are resolved through
    ``./class_indices.json`` (e.g. ``{"0": "daisy"}``). Entries whose image
    file is missing or whose class name is unknown are skipped with a message.

    Args:
        net: model that is called on the stacked image batch.
        images_dir: directory holding the images and ``label.txt``.
        transform: callable applied to each RGB PIL image; its results are
            combined with ``torch.stack``.
        num_plot: maximum number of images to predict and plot.
        device: device the image batch is moved to before inference.

    Returns:
        The matplotlib figure, or ``None`` when the directory or label file is
        missing or when there is nothing to plot.

    Raises:
        AssertionError: if ``./class_indices.json`` does not exist or a label
            line does not consist of exactly two fields.
    """
    if not os.path.exists(images_dir):
        print("not found {} path, ignore add figure.".format(images_dir))
        return None

    label_path = os.path.join(images_dir, "label.txt")
    if not os.path.exists(label_path):
        print("not found {} file, ignore add figure".format(label_path))
        return None

    # Read the class-index mapping; the context manager closes the file handle.
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "not found {}".format(json_label_path)
    with open(json_label_path, 'r') as json_file:
        # {"0": "daisy"}
        flower_class = json.load(json_file)
    # {"daisy": "0"}
    class_indices = {v: k for k, v in flower_class.items()}

    # Parse label.txt into [image_path, class_name] pairs.
    label_info = []
    with open(label_path, "r") as rd:
        for line in rd:
            # split() without arguments tolerates tabs and repeated blanks and
            # returns an empty list for blank lines.
            split_info = line.split()
            if not split_info:
                continue
            assert len(split_info) == 2, "label format error, expect file_name and class_name"
            image_name, class_name = split_info
            image_path = os.path.join(images_dir, image_name)
            # Skip entries whose image file does not exist.
            if not os.path.exists(image_path):
                print("not found {}, skip.".format(image_path))
                continue
            # Skip entries whose class is not one of the known classes.
            if class_name not in class_indices:
                print("unrecognized category {}, skip".format(class_name))
                continue
            label_info.append([image_path, class_name])

    # Keep only the first num_plot entries. A non-positive num_plot leaves
    # nothing to plot and is reported as None rather than failing when an
    # empty list is stacked.
    label_info = label_info[:max(num_plot, 0)]
    if not label_info:
        return None
    num_imgs = len(label_info)

    images = []
    labels = []
    for img_path, class_name in label_info:
        # Read the image and preprocess it.
        img = Image.open(img_path).convert("RGB")
        images.append(transform(img))
        labels.append(int(class_indices[class_name]))

    # Batch the images.
    images = torch.stack(images, dim=0).to(device)

    # Inference.
    with torch.no_grad():
        output = net(images)
        probs, preds = torch.max(torch.softmax(output, dim=1), dim=1)
        probs = probs.cpu().numpy()
        preds = preds.cpu().numpy()

    # Values used to undo normalisation before display.
    # NOTE(review): assumes `transform` normalised with exactly these
    # mean/std values -- confirm against the caller.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # figsize is (width, height).
    fig = plt.figure(figsize=(num_imgs * 2.5, 3), dpi=100)
    for i in range(num_imgs):
        # One row, num_imgs columns; draw into the (i + 1)-th subplot.
        ax = fig.add_subplot(1, num_imgs, i + 1, xticks=[], yticks=[])

        # CHW -> HWC
        npimg = images[i].cpu().numpy().transpose(1, 2, 0)

        # Undo the normalisation and clip to the valid range so the uint8
        # cast cannot wrap around on out-of-range values.
        npimg = ((npimg * std + mean) * 255).clip(0, 255)
        ax.imshow(npimg.astype('uint8'))

        pred_index = int(preds[i])
        title = "{}, {:.2f}%\n(label: {})".format(
            flower_class[str(pred_index)],  # predicted class
            probs[i] * 100,  # predicted probability
            flower_class[str(labels[i])]  # true class
        )
        ax.set_title(title, color=("green" if pred_index == labels[i] else "red"))

    return fig
#单张
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet34
def main():
    """Classify a single image, print per-class probabilities and show it.

    Loads ``../tulip.jpg``, the index -> class-name mapping from
    ``./class_indices.json`` and ResNet-34 weights from ``./resNet34.pth``.
    The top prediction becomes the plot title, and the probability of every
    class is printed.

    Raises:
        AssertionError: if the image, the class-index JSON or the weights
            file does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # Load the image.
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    # Force three channels: Normalize above is configured for 3-channel input,
    # and an image file may be grayscale or CMYK.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # Add the batch dimension -> [N, C, H, W].
    img = torch.unsqueeze(img, dim=0)

    # Read the class-index mapping.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # Create the model.
    model = resnet34(num_classes=5).to(device)

    # Load the model weights.
    weights_path = "./resNet34.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # Prediction.
    model.eval()
    with torch.no_grad():
        # Drop the batch dimension so `predict` is one probability per class.
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                prob.item()))
    plt.show()


if __name__ == '__main__':
    main()
# Source: deep-learning-for-image-processing/predict.py at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub