使用已经在COCO Train 2017数据集的子集上进行训练的FCN,该子集对应于PASCALVOC数据集。模型共支持20个类别。
from torchvision import models #加载模型
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval() #基于Resnet101的预先训练的FCN模型。
'''
如果模型尚未存在于缓存中,则pretrained=True标志将下载该模型。
该.val方法将以推理模式加载它
'''
.eval用法
from PIL import Image
import matplotlib.pyplot as plt
import torch
img = Image.open('C:/Users/ting/Desktop./qcs.jpg')
plt.imshow(img); plt.show()
为了将图像准备为使用模型进行推断的正确格式,我们需要对其进行预处理并规范化!
因此,对于预处理步骤,我们执行以下步骤。
# Apply the transformations needed
import torchvision.transforms as T
trf = T.Compose([T.Resize(256), #将图像尺寸调整为256×256
T.CenterCrop(224), #中心裁剪,大小为224x224
T.ToTensor(), #将图像转换为张量,并将值缩放到[0,1]范围
T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])]) #用给定的均值和标准差对图像进行正则化。
inp = trf(img).unsqueeze(0)
print(inp)
tensor([[[[ 1.4098, 1.3584, 1.2899, ..., 0.5364, 0.5364, 0.5364],
[ 1.4612, 1.4098, 1.3413, ..., 0.5536, 0.5536, 0.5536],
[ 1.4954, 1.4440, 1.3755, ..., 0.5536, 0.5536, 0.5536],
...,
[-1.6555, -1.6555, -1.6898, ..., -2.0494, -2.0152, -2.0152],
[-1.5357, -1.5699, -1.6042, ..., -2.0494, -2.0152, -2.0152],
[-1.5185, -1.5357, -1.5699, ..., -2.0494, -2.0152, -2.0152]],
[[ 1.4132, 1.3606, 1.2906, ..., 0.7129, 0.7129, 0.7129],
[ 1.4657, 1.4132, 1.3431, ..., 0.7304, 0.7304, 0.7304],
[ 1.5007, 1.4482, 1.3782, ..., 0.7304, 0.7304, 0.7304],
...,
[-1.5455, -1.5455, -1.5805, ..., -1.9657, -1.9657, -1.9657],
[-1.4230, -1.4580, -1.4930, ..., -1.9657, -1.9657, -1.9657],
[-1.4055, -1.4230, -1.4580, ..., -1.9657, -1.9657, -1.9657]],
[[ 1.9254, 1.8731, 1.8034, ..., 1.1934, 1.1934, 1.1934],
[ 1.9777, 1.9254, 1.8557, ..., 1.2108, 1.2108, 1.2108],
[ 2.0125, 1.9603, 1.8905, ..., 1.2108, 1.2108, 1.2108],
...,
[-1.2293, -1.2293, -1.2641, ..., -1.6999, -1.6824, -1.6824],
[-1.1073, -1.1421, -1.1770, ..., -1.6999, -1.6824, -1.6824],
[-1.0898, -1.1073, -1.1421, ..., -1.6999, -1.6824, -1.6824]]]])
# Pass the input through the net
out = fcn(inp)['out'] #out是模型的最终输出。
print (out.shape)
torch.Size([1, 21, 224, 224])
out是模型的最终输出, [1 x 21 x H x W] 。
我们需要将这21个通道输出到一个2D图像或一个1通道图像,其中该图像的每个像素对应于一个类! 因此,2D图像(形状[HxW])的每个像素将与相应的类标签对应,对于该2D图像中的每个(x,y)像素将对应于表示类的0-20之间的数字。 How?我们为每个像素位置取一个最大索引,它表示类。
import numpy as np
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print (om.shape)
(224, 224)
out.squeeze()去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。此处对out进行降维,变为[21 x H x W] 。
torch.argmax() 返回指定维度最大值的序号。dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。dim=0说明对第一维(21)操作,故结果为[H x W] 。
降维torch.squeeze(input, dim=None, out=None)
torch.argmax() 函数详解
print (np.unique(om))#np.unique() 函数 去除其中重复的元素 ,并按元素 由小到大 返回一个新的无元素重复的元组或者列表。
[ 0 15]
处理完的列表中共有2种元素,0(背景),15(人)。正如我们所看到的,现在我们有了一个2D图像,其中每个像素属于一个类。最后一件事是把这个2D图像转换成一个分割图像,每个类标签对应于一个RGB颜色,从而使图像易于观看。
将此2D图像转换为RGB图像,其中每个(元素)标签映射到相应的颜色。
首先,列表label_colors根据索引存储每个类的颜色。因此,第一类的颜色是背景,存储在label_colors列表的第0个索引处。第二类,即飞机,存储在索引1中,以此类推。
现在,我们必须从我们拥有的2D图像中创建一个RGB图像。因此,我们所做的是为所有3个通道创建空的2D矩阵。 因此,r、g和b是构成最终图像的RGB通道的列表,这些列表中的每一个的形状都是[HxW](这与2D图像的形状相同)。
现在,我们循环存储在label_colors中的每个颜色,并在存在特定类标签的2D图像中获取索引。然后,对于每个通道,我们将其相应的颜色放置到存在该类标签的像素上。 最后,我们将3个独立的通道叠加起来,形成RGB图像。 好吧!现在,让我们使用这个函数来查看最终的输出!
# Define the helper function
def decode_segmap(image, nc=21):
label_colors = np.array([(0, 0, 0), # 0=background
# 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
(128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
# 6=bus, 7=car, 8=cat, 9=chair, 10=cow
(0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
# 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
(192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
# 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
(0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
b = np.zeros_like(image).astype(np.uint8)
for l in range(0, nc):
idx = image == l
r[idx] = label_colors[l, 0]
g[idx] = label_colors[l, 1]
b[idx] = label_colors[l, 2]
rgb = np.stack([r, g, b], axis=2)
return rgb
rgb = decode_segmap(om)
plt.imshow(rgb); plt.show()
接下来,让我们把所有操作放入一个函数下,并测试更多的图像!
def segment(net, path):
img = Image.open(path)
plt.imshow(img); plt.axis('off'); plt.show()
# Comment the Resize and CenterCrop for better inference results
trf = T.Compose([T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])])
inp = trf(img).unsqueeze(0)
out = net(inp)['out']
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
rgb = decode_segmap(om)
plt.imshow(rgb); plt.axis('off'); plt.show()
print (segment(fcn ,'C:/Users/ting/Desktop./qcs.jpg'))
参考文献