import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnet_b0 as create_model

# File extensions treated as images when scanning the input folder; anything
# else (sub-directories, stray text files, ...) is skipped.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")


def _build_transform(img_size):
    """Return the resize / center-crop / to-tensor / normalize pipeline for one image."""
    return transforms.Compose(
        [transforms.Resize(img_size),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


def _load_class_indict(json_path):
    """Load the class-index JSON; keys are stringified class indices (e.g. "0")."""
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r") as json_file:
        return json.load(json_file)


def _load_model(weights_path, num_classes, device):
    """Build the EfficientNet model on `device`, load its weights and switch to eval mode."""
    if not os.path.exists(weights_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(weights_path))
    model = create_model(num_classes=num_classes).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def main(path=r'C:\Users\sun\Desktop\b_classification\classification-pytorch-main\color_datasets\train\No-Anomaly',
         target_class="No-Anomaly",
         num_model="B0",
         json_path='./class_indices.json',
         model_weight_path="./weights/model-29.pth",
         num_classes=5,
         show=False):
    """Classify every image in `path` and print the fraction predicted as `target_class`.

    The folder is expected to contain images of a single class, so the printed
    ratio is the model's accuracy on that folder. To evaluate another folder,
    change `path` and the matching `target_class`.

    Args:
        path: folder whose image files are classified.
        target_class: class name (a value in the class-index JSON) counted as correct.
        num_model: EfficientNet variant key ("B0".."B7"); selects the input resolution.
        json_path: JSON file mapping stringified class indices to class names.
        model_weight_path: state-dict file loaded into the model.
        num_classes: number of output classes the weights were trained with.
        show: if True, display each image with its predicted class and probability.

    Returns:
        The fraction of images predicted as `target_class`, or None if the
        folder contains no image files.

    Raises:
        FileNotFoundError: if `json_path` or `model_weight_path` is missing.
        KeyError: if `num_model` is not a known variant, or a predicted index is
            missing from the class-index JSON.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = {"B0": 224, "B1": 240, "B2": 260, "B3": 300,
                "B4": 380, "B5": 456, "B6": 528, "B7": 600}
    data_transform = _build_transform(img_size[num_model])

    # Load the label map and the model once, not once per image.
    class_indict = _load_class_indict(json_path)
    model = _load_model(model_weight_path, num_classes, device)

    # Collect the image files under the target folder (regular files only).
    image_names = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
        and name.lower().endswith(IMAGE_EXTENSIONS))
    if not image_names:
        print("no images found in '{}'".format(path))
        return None

    correct = 0
    with torch.no_grad():
        for name in image_names:
            img_path = os.path.join(path, name)
            # The `with` closes the file; convert("RGB") gives a loaded copy with
            # the 3 channels that Normalize expects (grayscale/RGBA/palette safe).
            with Image.open(img_path) as raw_img:
                rgb_img = raw_img.convert("RGB")

            # [C, H, W] -> [1, C, H, W]: add the batch dimension.
            img = torch.unsqueeze(data_transform(rgb_img), dim=0)

            output = torch.squeeze(model(img.to(device))).cpu()
            predict = torch.softmax(output, dim=0)
            predict_cla = int(torch.argmax(predict))

            # Compare the predicted class NAME (not the raw index) with the target.
            predicted_name = class_indict[str(predict_cla)]
            if predicted_name == target_class:
                correct += 1

            if show:
                plt.imshow(rgb_img)
                plt.title("class: {}   prob: {:.3}".format(
                    predicted_name, predict[predict_cla].item()))
                plt.show()

    # Accuracy on this folder: correct predictions over images actually processed.
    result = correct / len(image_names)
    print(result)
    return result


if __name__ == '__main__':
    main()