Pytorch学习笔记(模型验证/测试)

模型验证

模型验证(测试,demo):利用已经训练好的模型,然后给它提供输入进行测试验证。

import torch
import torchvision.transforms
from PIL import Image
from torch import nn

# 需要测试的图片
image_path = "../imgs/airplane.png"
image = Image.open(image_path)
image = image.convert('RGB')  # png图片多了一个透明度通道,修改成rgb三个通道
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)


# 引入网络架构
class NNN(nn.Module):
    def __init__(self):
        super(NNN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 32, 5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, 5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x


# 读取网络模型  如果保存的模型是通过gpu训练出来的,需要添加 map_location=torch.device("cpu")
model_load = torch.load("NNN_5.pth", map_location=torch.device("cpu"))
# 原有的图片是没有bitch-size的,而我们的输入是需要的
image = torch.reshape(image, (1, 3, 32, 32))
model_load.eval()
with torch.no_grad():
    outputs = model_load(image)
print(outputs)

print(outputs.argmax(1))
  1. 找一张 你需要用训练出来的模型进行测试的图片
  2. 读取加载你保存的训练模型【用gpu训练的,要加上map_location,不然会报错】
  3. 把图片输入模型进行验证【注意输入图片的格式要求】
  4. 输出预测结果outputs.argmax(1)

你可能感兴趣的:(pytorch,pytorch,model,test)