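Our task is to take an existing classification model from PyTorch's torchvision.models module, together with its pretrained weights, and use it to predict the class of a single image.
Download squeezenet1_1-f364aa15.pth in advance and put it under your PyTorch installation path; mine is the pthfile shown below.
Where to download: https://github.com/pytorch/vision/tree/master/torchvision/models
Open the source file for the model you want and look up the weight URL inside it.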
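Since the file is downloaded by hand, it can be worth checking that the download is intact. The sketch below assumes the torch.hub naming convention, in which the hex string after the last hyphen of the filename (here f364aa15) is the beginning of the file's SHA-256 digest. The helper names hash_prefix_from_name and file_matches_name are made up for this illustration and are not part of PyTorch.

```python
import hashlib
import re
from pathlib import Path


def hash_prefix_from_name(filename: str) -> str:
    """Return the hex prefix embedded in names like 'model-f364aa15.pth'."""
    match = re.search(r"-([a-f0-9]+)\.", filename)
    if match is None:
        raise ValueError(f"no hash prefix found in {filename!r}")
    return match.group(1)


def file_matches_name(path: Path) -> bool:
    """Check that the file's SHA-256 digest starts with the prefix in its name."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        # Read in 1 MiB chunks so large weight files are not loaded at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(hash_prefix_from_name(path.name))


print(hash_prefix_from_name("squeezenet1_1-f364aa15.pth"))  # f364aa15
```

Calling file_matches_name(Path(pthfile)) on the downloaded file should return True if the convention holds and the download is complete.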
import torch
import torchvision.models as models
pthfile = r'G:\Anaconda3\envs\pytorch\Lib\site-packages\torchvision\models\squeezenet1_1-f364aa15.pth'
net = torch.load(pthfile)
print(net)
This is the weight file: an OrderedDict that maps each parameter name to its tensor.
OrderedDict([('features.0.weight', tensor([[[[ 1.2094e-01, 1.7803e-01, -1.5971e-02],
[ 2.6995e-01, 3.4009e-01, 7.6897e-02],
[ 1.3524e-01, 1.5867e-01, -2.9714e-02]],
[[-3.1629e-01, -3.5250e-01, -2.3707e-01],
[-4.1020e-01, -4.3460e-01, -2.9274e-01],
[-3.4065e-01, -3.6389e-01, -2.3710e-01]],
[[ 1.6492e-01, 1.7115e-01, 2.4033e-01],
[ 1.5148e-01, 9.3240e-02, 1.9888e-01],
[ 2.4104e-01, 2.0141e-01, 2.8289e-01]]],
The weights have to be loaded into the matching model architecture.
net = models.squeezenet1_1(pretrained=False)
pthfile = r'G:\Anaconda3\envs\pytorch\Lib\site-packages\torchvision\models\squeezenet1_1-f364aa15.pth'
net.load_state_dict(torch.load(pthfile))
print(net)
Now we have the model together with its trained weights.
SqueezeNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
(1): ReLU(inplace=True)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
(3): Fire(
(squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
(squeeze_activation): ReLU(inplace=True)
(expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
(expand1x1_activation): ReLU(inplace=True)
(expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(expand3x3_activation): ReLU(inplace=True)
)
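Note: these weights were trained on the ImageNet dataset (ILSVRC-2012, roughly 1.28 million training images), which has 1000 classes.
Next, let's pass an image in and see what gets predicted.
The image is loaded like this: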
from PIL import Image
img = Image.open("liya.jpg")
img
Before feeding the image to the model, it has to be preprocessed.
from torchvision import transforms
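transform = transforms.Compose([          # chain the preprocessing steps
    transforms.Resize(256),               # resize so the shorter side is 256 pixels
    transforms.CenterCrop(224),           # crop the central 224x224 region
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize(                 # normalize each channel
        mean=[0.485, 0.456, 0.406],       # ImageNet per-channel means
        std=[0.229, 0.224, 0.225]         # ImageNet per-channel standard deviations
    )])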
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
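To make the ToTensor and Normalize steps concrete, here is a minimal pure-Python sketch of the arithmetic applied to a single 8-bit RGB pixel. It only illustrates the formula and is not how torchvision implements it; normalize_pixel is a hypothetical helper name.

```python
# Per-channel statistics taken from the transform above (RGB order).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the values the network sees.

    Step 1 mirrors ToTensor for 8-bit images: scale 0..255 to 0.0..1.0.
    Step 2 mirrors Normalize: subtract the channel mean, divide by the channel std.
    """
    scaled = [value / 255.0 for value in rgb]
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]


print(normalize_pixel((0, 0, 0)))        # black pixel: negative in every channel
print(normalize_pixel((255, 255, 255)))  # white pixel: positive in every channel
```

The unsqueeze call then adds a leading batch dimension, so a (3, 224, 224) image becomes a (1, 3, 224, 224) batch.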
Switch the model to eval() mode and run the image through it
net.eval()
out = net(batch_t)
print(out.shape)
The output is a 2-D tensor: one row for our single image, holding the prediction scores for the 1000 classes
torch.Size([1, 1000])
Let's see what these values look like
out
tensor([[ 8.4795, 4.4297, 3.3579, 3.8392, 10.8696, 5.6860, 10.6676, 5.7791,
5.3900, 3.8367, 4.2046, 4.0783, 5.1175, 7.4449, 8.3978, 4.0603,
3.3622, 7.2784, 6.4025, 8.6113, 4.8858, 3.6329, 3.9229, 3.3437,
4.3338, 5.2695, 7.7577, 8.1239, 8.1069, 6.5631, 8.8629, 6.2040,
7.4533, 11.0477, 10.5223, 8.9174, 10.3797, 9.1047, 7.8688, 11.0185,
8.9409, 9.6788, 7.4256, 10.1015, 10.7663, 6.5917, 9.4496, 8.9061,
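Clearly these are not probabilities; they are raw scores, so we need to apply softmax to them.
But first we need the class label names, otherwise how would we know whether the prediction is a cat or a person? Copy the ImageNet 1000 labels from https://blog.csdn.net/weixin_34304013/article/details/93708121 into an empty TXT file and remove the outermost {}.
Open it and take a look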
with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]
["0: 'tench, Tinca tinca',",
"1: 'goldfish, Carassius auratus',",
"2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',",
"3: 'tiger shark, Galeocerdo cuvieri',",
"4: 'hammerhead, hammerhead shark',",
"5: 'electric ray, crampfish, numbfish, torpedo',",
"6: 'stingray',",
"7: 'cock',",
"8: 'hen',",
"9: 'ostrich, Struthio camelus',",
"10: 'brambling, Fringilla montifringilla',",
"11: 'goldfinch, Carduelis carduelis',",
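Each entry read from the file still carries its index, quotes, and trailing comma. If cleaner names are wanted, a small parser along these lines could strip them. This is a sketch that assumes every line has the form index: 'name', as in the listing above; parse_label_line is a hypothetical helper name.

```python
def parse_label_line(line):
    """Split a line such as "0: 'tench, Tinca tinca'," into (0, 'tench, Tinca tinca')."""
    index_text, _, name_text = line.strip().partition(":")
    name = name_text.strip().rstrip(",").strip()
    # Drop one pair of surrounding quotes, single or double.
    if len(name) >= 2 and name[0] == name[-1] and name[0] in "'\"":
        name = name[1:-1]
    return int(index_text), name


sample = ["0: 'tench, Tinca tinca',", "1: 'goldfish, Carassius auratus',"]
names = dict(parse_label_line(line) for line in sample)
print(names[1])  # goldfish, Carassius auratus
```

Keeping the result in a dict keyed by index means a lookup does not depend on the order of lines in the file.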
Good. Now let's find the index of the largest predicted value
_, index = torch.max(out, 1)
index
tensor([638])
With a batch of 12 images there would be 12 values. We have only one, so we take the first.
index[0]
tensor(638)
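The index part of torch.max(out, 1) is a per-row argmax: one winning class position for each image in the batch. A pure-Python sketch of the same idea, using made-up scores for two images and three classes:

```python
def argmax_per_row(batch):
    """Return, for each row of scores, the position of its largest value."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch]


batch_scores = [
    [0.1, 2.5, 0.3],   # image 0: class 1 scores highest
    [4.0, 0.2, 1.1],   # image 1: class 0 scores highest
]
print(argmax_per_row(batch_scores))  # [1, 0]
```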
As mentioned above, convert the raw scores into probabilities (multiplied by 100 to give percentages)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
tensor([6.2143e-05, 1.0829e-06, 3.7080e-07, 5.9999e-07, 6.7829e-04, 3.8039e-06,
5.5422e-04, 4.1748e-06, 2.8290e-06, 5.9853e-07, 8.6468e-07, 7.6210e-07,
2.1543e-06, 2.2085e-05, 5.7270e-05, 7.4849e-07, 3.7240e-07, 1.8697e-05,
7.7869e-06, 7.0898e-05, 1.7088e-06, 4.8817e-07, 6.5240e-07, 3.6557e-07,
9.8389e-07, 2.5079e-06, 3.0194e-05, 4.3548e-05, 4.2816e-05, 9.1437e-06,
9.1187e-05, 6.3855e-06, 2.2272e-05, 8.1050e-04, 4.7929e-04, 9.6296e-05,
4.1560e-04, 1.1613e-04, 3.3743e-05, 7.8723e-04, 9.8578e-05, 2.0618e-04,
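For reference, the softmax step can be written out by hand. The sketch below uses three made-up scores rather than the real 1000, and subtracts the maximum before exponentiating, a common trick to avoid overflow. softmax_percent is a hypothetical helper name.

```python
import math


def softmax_percent(scores):
    """Convert raw scores into percentages that sum to 100."""
    top = max(scores)                            # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [100.0 * e / total for e in exps]


logits = [2.0, 1.0, 0.1]                         # made-up scores for three classes
percent = softmax_percent(logits)
print([round(p, 2) for p in percent])
print(sum(percent))                              # the percentages add up to 100 (up to rounding)
```

Softmax preserves the ordering of the scores, which is why the class with the largest raw score is also the class with the largest percentage.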
print(classes[index[0]], percentage[index[0]].item())
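638: 'maillot', 36.48484802246094
The prediction is "maillot" (a women's tight-fitting garment).
Why call .item()? It converts the one-element tensor into a plain Python number. Here is what you get without it: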
percentage[index[0]]
tensor(36.4848, grad_fn=)
Let's look at the top-5 predictions
_, indices = torch.sort(out, descending=True)
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
[("638: 'maillot',", 36.48484802246094),
("459: 'brassiere, bra, bandeau',", 21.750808715820312),
("639: 'maillot, tank suit',", 12.231551170349121),
("445: 'bikini, two-piece',", 11.10608959197998),
("578: 'gown',", 5.4838056564331055)]
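Sorting all 1000 scores works, but only the five best are needed. As a sketch of that selection in plain Python, heapq.nlargest picks the k best indices without a full sort. The labels and percentages below are hypothetical, and top_k is an illustrative helper name.

```python
import heapq


def top_k(percentages, labels, k=5):
    """Return the k (label, percentage) pairs with the highest percentage, best first."""
    best = heapq.nlargest(k, range(len(percentages)), key=percentages.__getitem__)
    return [(labels[i], percentages[i]) for i in best]


# Hypothetical labels and percentages, for illustration only.
labels = ["cat", "dog", "fish", "bird"]
percentages = [10.0, 55.0, 5.0, 30.0]
print(top_k(percentages, labels, k=2))  # [('dog', 55.0), ('bird', 30.0)]
```

On tensors, torch.topk plays a similar role and returns the k largest values together with their indices.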