I ran into this problem while practicing loading a trained model with PyTorch and using it to classify an image.
Solution: use torch.reshape() to convert the input tensor to the shape the network expects.
The code that raised the error:
import torch
import torchvision
from PIL import Image
from model import *

image = Image.open("./img/dog.png")
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)   # shape: torch.Size([3, 32, 32])
model = MyNet()
model.load_state_dict(torch.load("./testmodel/mynet_7.pth"))
output = model(image)      # fails here: the input shape does not match the network
print(output)
The code in my model.py file:
#!/usr/bin/env python
# _*_ coding: utf-8 _*_
# @Time : 2023-09-22 15:57
# @Author : Kanbara
# @File : model.py
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(64*4*4, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x


# Sanity check that this file runs without problems
if __name__ == '__main__':
    mynet = MyNet()
    input = torch.ones([64, 3, 32, 32])   # dummy batch: [batch_size, C, H, W]
    output = mynet(input)
    print(output.shape)                   # expected: torch.Size([64, 10])
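For reference, the in_features value 64*4*4 in the first Linear layer follows from the three MaxPool2d(2) layers, which halve the 32x32 input three times (32 -> 16 -> 8 -> 4), leaving 64 channels of 4x4 feature maps, i.e. 64*4*4 = 1024 features after Flatten. A minimal sketch (my own addition, not from the original post) that prints the shape after each layer:

import torch
from model import MyNet

net = MyNet()
x = torch.ones([1, 3, 32, 32])   # a single batched dummy input
for layer in net.model:          # nn.Sequential is iterable, layer by layer
    x = layer(x)
    print(layer.__class__.__name__, tuple(x.shape))
# Conv2d    (1, 32, 32, 32)
# MaxPool2d (1, 32, 16, 16)
# Conv2d    (1, 32, 16, 16)
# MaxPool2d (1, 32, 8, 8)
# Conv2d    (1, 64, 8, 8)
# MaxPool2d (1, 64, 4, 4)
# Flatten   (1, 1024)    <- 64 * 4 * 4
# Linear    (1, 64)
# Linear    (1, 10)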
The network in model.py passed the sanity check above, so the model itself is not the problem. The actual cause is that the input image tensor is missing the batch_size dimension:
print(image.shape)
>>torch.Size([3, 32, 32])
According to the network definition, however, the input must have four dimensions, with batch_size as the first. Adding the following line, which uses torch.reshape:
image = torch.reshape(image, (1, 3, 32, 32))
solves the problem. The code then runs normally; the full corrected version follows the short note below.
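Note (my own addition, not part of the original fix): an equivalent one-liner is Tensor.unsqueeze, which inserts a size-1 batch dimension at position 0 without hard-coding the channel and spatial sizes:

image = image.unsqueeze(0)   # torch.Size([3, 32, 32]) -> torch.Size([1, 3, 32, 32])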
import torch
import torchvision
from PIL import Image
from model import *

image = Image.open("./img/dog.png")
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
model = MyNet()
model.load_state_dict(torch.load("./testmodel/mynet_7.pth"))
image = torch.reshape(image, (1, 3, 32, 32))   # add the batch dimension
output = model(image)                          # input is now [1, 3, 32, 32]
print(output)
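Two defensive tweaks worth mentioning (my own additions, not from the original post): if dog.png carries an alpha channel, ToTensor yields a 4-channel [4, 32, 32] tensor and the reshape to (1, 3, 32, 32) fails because the element counts differ, so converting the image to RGB on open avoids that; and for inference it is good practice to call model.eval() and wrap the forward pass in torch.no_grad() (neither changes the output of this particular network, which has no dropout or batch-norm layers, but they prevent surprises and skip gradient bookkeeping). A hardened version of the script under those assumptions:

import torch
import torchvision
from PIL import Image
from model import *

image = Image.open("./img/dog.png").convert("RGB")   # drop any alpha channel
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image).unsqueeze(0)                # [1, 3, 32, 32]
model = MyNet()
model.load_state_dict(torch.load("./testmodel/mynet_7.pth"))
model.eval()                                         # inference mode
with torch.no_grad():                                # no gradients needed
    output = model(image)
print(output)
print(output.argmax(1))                              # predicted class index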