导入训练好的模型,Pytorch

import cv2
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.nn as nn

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5], std=[0.5])])

data_test = datasets.MNIST(root="./data/",
                           transform=transform,
                           train=False)

data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=64,
                                               shuffle=True,
                                               drop_last=True)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 使用序列工具快速构建
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(7 * 7 * 32, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = out.view(out.size(0), -1)  # reshape
        out = self.fc(out)
        return out


model = CNN()
check = torch.load("cnn.pth")
model.load_state_dict(check)
correct = 0
for data in data_loader_test:
    X_test,y_test = data
    inputs = Variable(X_test)
    output = model(inputs)
    _, pred = torch.max(output, 1)
    correct += torch.sum(pred == y_test.data)
    print(correct)
    correct = 0


print("Successfully load")
img = torchvision.utils.make_grid(X_test)
img = img.cpu().numpy().transpose(1, 2, 0)
print("Predict Label is:", [i for i in pred])
print("Real Label is :", [i for i in y_test])
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

代码借鉴了深度学习之PyTorch实战(3)——实战手写数字识别
有兴趣的可以去看看原博客,我主要是想讲一下自己遇到的问题。
这部分代码主要是对训练好的模型进行测试,其中最关键的问题就是其中的CNN模型,首先一定要与训练模型中的类结构一致,然后就是加载模型

model = CNN()
check = torch.load("cnn.pth")
model.load_state_dict(check)

其中最重要的就是

model = CNN()

我刚开始一直出现
‘collections.OrderedDict’ object is not callable
或者
RuntimeError: Error(s) in loading state_ dict for Model :
主要是我对这个CNN类没有与训练模型保持一致,所以导致出现各种问题。

你可能感兴趣的:(导入训练好的模型,Pytorch)