pytorch预测结果的精度、召回率、F1值、准确率计算(简单实现)

1. 导入必要的库函数

import os
import json
import torch
from PIL import Image
from torchvision import transforms
# 导入自己的模型
from model_v3 import mobilenet_v3_small

2. 文件存储格式

imgs文件下下是需要预测的图像

label.txt为每幅图对应的标签

### 这个无所谓,能正常读取就行

pytorch预测结果的精度、召回率、F1值、准确率计算(简单实现)_第1张图片

 3. 标签读取辅助函数

targets_path = "test/label.txt"
    with open(targets_path, 'r') as file:
        targets = file.readlines()

 4. 精度、召回率、F1值、准确率计算函数

# 获取标签
target = int(targets[i])
# 输出结果
output = torch.squeeze(model(img.to(device))).cpu()
# 取出最大值的索引
predict = torch.softmax(output, dim=0)
# 得到数值
predict_cla = torch.argmax(predict).item()
# 0为正类, 1为负类
if predict_cla == 0 and target == 0:
    TP += 1
if predict_cla == 1 and target == 1:
    TN += 1
if predict_cla == 0 and target == 1:
    FP += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
if predict_cla == 1 and target == 0:
    FN += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")

# 根据公式写的,如果有问题欢迎大家支持,一起学习
P = TP / (TP + FP + esp)
R = TP / (TP + FN + esp)
F1 = 2 * P * R / (P + R + esp)
acc = (TP + TN) / (TP + TN + FP + FN + esp)

5. 完整实现,修改路径可以用于自己的模型

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model_v3 import mobilenet_v3_small


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


    imgs_path = "I:/ZTC950V763_211118/CV/imgs/"
    targets_path = "I:/ZTC950V763_211118/CV/y_CV.txt"
    with open(targets_path, 'r') as file:
        targets = file.readlines()
    TP, TN, FP, FN = 0, 0, 0, 0
    esp = 1e-6
    i = 0
    for _, __, files in os.walk(imgs_path):
        for file in files:
            img = Image.open(imgs_path + str(file))
            target = int(targets[i])
            # plt.imshow(img)
            # [N, C, H, W]
            img = data_transform(img)
            # expand batch dimension
            img = torch.unsqueeze(img, dim=0)
            # read class_indict
            json_path = 'class_indices.json'
            assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

            json_file = open(json_path, "r")
            class_indict = json.load(json_file)

            # create model
            model = mobilenet_v3_small(num_classes=2).to(device)
            # load model weights
            model_weight_path = "C:/Users/00769111/PycharmProjects/mobilenet_juanyang/weights/No_freeze_MobileNetV3.pth"
            model.load_state_dict(torch.load(model_weight_path, map_location=device))
            model.eval()
            with torch.no_grad():
                # predict class
                output = torch.squeeze(model(img.to(device))).cpu()
                predict = torch.softmax(output, dim=0)
                predict_cla = torch.argmax(predict).item()
                if predict_cla == 0 and target == 0:
                    TP += 1
                if predict_cla == 1 and target == 1:
                    TN += 1
                if predict_cla == 0 and target == 1:
                    FP += 1
                    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
                if predict_cla == 1 and target == 0:
                    FN += 1
                    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
            i += 1
            if i % 200 == 0:
                P = TP / (TP + FP + esp)
                R = TP / (TP + FN + esp)
                F1 = 2 * P * R / (P + R + esp)
                acc = (TP + TN) / (TP + TN + FP + FN + esp)
                print(f"精度为: {P}\n")
                print(f"召回率为: {R}\n")
                print(f"F1值为: {F1}\n")
                print(f"准确率为: {acc}")
        P = TP / (TP + FP + esp)
        R = TP / (TP + FN + esp)
        F1 = 2 * P * R / (P + R + esp)
        acc = (TP + TN) / (TP + TN + FP + FN + esp)
        print("结果汇总\n")
        print(f"精度为: {P}\n")
        print(f"召回率为: {R}\n")
        print(f"F1值为: {F1}\n")
        print(f"准确率为: {acc}")


if __name__ == '__main__':
    main()

你可能感兴趣的:(Pytorch,pytorch,深度学习,计算机视觉)