目录
加载的训练数据是图片!!!
加载的训练数据是张量!!!
在模型训练完成后,可以通过绘制ROC曲线来评估模型的分类性能。ROC曲线是以真正率(True Positive Rate,TPR)为纵轴,假正率(False Positive Rate,FPR)为横轴的曲线,它反映了模型在不同阈值下的分类性能。
当你的数据是图片时,你可以使用torchvision.datasets.ImageFolder
来加载数据集,并在加载的同时使用torchvision.transforms
中的函数对数据进行预处理。
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import numpy as np
# 定义预处理函数
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整图像大小为224x224
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化图像
])
# 加载测试集数据
test_data = torchvision.datasets.ImageFolder(root='test/', transform=transform)
# 定义模型
model = ...
# 加载模型参数
model.load_state_dict(torch.load('model.pth'))
# 设置模型为评估模式
model.eval()
# 预测标签和真实标签列表
preds, labels = [], []
# 遍历测试集数据,进行预测
for images, targets in test_data:
# 将图片数据添加一维作为batch
images = images.unsqueeze(0)
# 前向传播
outputs = model(images)
# 获取预测概率
prob = torch.nn.functional.softmax(outputs, dim=1)
# 获取预测标签
pred = prob.argmax(dim=1).item()
preds.append(prob[:, 1].item())
labels.append(targets)
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(labels, preds)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
1、引入模型定义 model = ,类似在头文件 import def_CNN ,model = def_CNN()
2、修改加载的模型文件路径,加载模型的形式为字典形式!
3、数据集加载路径,也需要修改!
4、其他细节问题,自行修改即可!
torch.load()
是一个用于从文件中加载PyTorch张量或模型的函数,它可以将保存在磁盘上的模型或张量加载到内存中。通常情况下,我们使用torchvision.datasets
中的数据集对象来加载数据,例如MNIST、CIFAR-10等。但是,如果您已经将数据集以张量的形式保存在磁盘上,那么您可以使用torch.load()
来加载数据。
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
# 导入模型和测试数据
model = torch.load('model.pth')
test_data = torch.load('test_data.pth')
# 设置模型为评估模式
model.eval()
# 预测测试数据的标签和分数
y_pred = []
y_score = []
with torch.no_grad():
for inputs, labels in test_data:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
score = outputs[:, 1].cpu().numpy() # 获取正类的预测分数
y_pred += predicted.tolist()
y_score += score.tolist()
y_true = np.array(test_data.dataset.targets)
# 计算FPR和TPR
fpr, tpr, _ = roc_curve(y_true, y_score)
# 计算ROC曲线下面积(AUC)
roc_auc = auc(fpr, tpr)
# 绘制ROC曲线
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
具体修改内容参考上面的!