【pytorch】使用torchvision进行语义分割

面向初学者的PyTorch:使用torchvision进行语义分割

  • 1.加载模型
  • 2.加载并显示图像
  • 3.图像预处理
  • 4.Forward pass through the network
  • 5.输出
  • 6.Final Result


使用已经在COCO Train 2017数据集的子集上进行训练的FCN,该子集对应于PASCALVOC数据集。模型共支持20个类别。

1.加载模型

from torchvision import models             #加载模型
fcn = models.segmentation.fcn_resnet101(pretrained=True).eval()    #基于Resnet101的预先训练的FCN模型。

'''
如果模型尚未存在于缓存中,则pretrained=True标志将下载该模型。
该.val方法将以推理模式加载它
'''

.eval用法

2.加载并显示图像

from PIL import Image
import matplotlib.pyplot as plt
import torch
 
img = Image.open('C:/Users/ting/Desktop./qcs.jpg')
plt.imshow(img); plt.show()

【pytorch】使用torchvision进行语义分割_第1张图片

3.图像预处理

为了将图像准备为使用模型进行推断的正确格式,我们需要对其进行预处理并规范化!
因此,对于预处理步骤,我们执行以下步骤。

  • Resize the image to (256 x 256)
  • CenterCrop it to (224 x 224)
  • Convert it to Tensor – all the values in the image will be scaled so they lie between instead of the original, range. [0,1] [0, 255]
  • Normalize it with the Imagenet specific values where mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
  • And lastly, we unsqueeze the image dimensions so that it becomes from . This is required since we need a batch while passing it through the network . [1 x C x H x W] [C x H x W]
# Apply the transformations needed
import torchvision.transforms as T
trf = T.Compose([T.Resize(256),          #将图像尺寸调整为256×256
                 T.CenterCrop(224),         #中心裁剪,大小为224x224
                 T.ToTensor(),              #将图像转换为张量,并将值缩放到[0,1]范围
                 T.Normalize(mean = [0.485, 0.456, 0.406], 
                             std = [0.229, 0.224, 0.225])])  #用给定的均值和标准差对图像进行正则化。
inp = trf(img).unsqueeze(0)
print(inp)
tensor([[[[ 1.4098,  1.3584,  1.2899,  ...,  0.5364,  0.5364,  0.5364],
          [ 1.4612,  1.4098,  1.3413,  ...,  0.5536,  0.5536,  0.5536],
          [ 1.4954,  1.4440,  1.3755,  ...,  0.5536,  0.5536,  0.5536],
          ...,
          [-1.6555, -1.6555, -1.6898,  ..., -2.0494, -2.0152, -2.0152],
          [-1.5357, -1.5699, -1.6042,  ..., -2.0494, -2.0152, -2.0152],
          [-1.5185, -1.5357, -1.5699,  ..., -2.0494, -2.0152, -2.0152]],

         [[ 1.4132,  1.3606,  1.2906,  ...,  0.7129,  0.7129,  0.7129],
          [ 1.4657,  1.4132,  1.3431,  ...,  0.7304,  0.7304,  0.7304],
          [ 1.5007,  1.4482,  1.3782,  ...,  0.7304,  0.7304,  0.7304],
          ...,
          [-1.5455, -1.5455, -1.5805,  ..., -1.9657, -1.9657, -1.9657],
          [-1.4230, -1.4580, -1.4930,  ..., -1.9657, -1.9657, -1.9657],
          [-1.4055, -1.4230, -1.4580,  ..., -1.9657, -1.9657, -1.9657]],

         [[ 1.9254,  1.8731,  1.8034,  ...,  1.1934,  1.1934,  1.1934],
          [ 1.9777,  1.9254,  1.8557,  ...,  1.2108,  1.2108,  1.2108],
          [ 2.0125,  1.9603,  1.8905,  ...,  1.2108,  1.2108,  1.2108],
          ...,
          [-1.2293, -1.2293, -1.2641,  ..., -1.6999, -1.6824, -1.6824],
          [-1.1073, -1.1421, -1.1770,  ..., -1.6999, -1.6824, -1.6824],
          [-1.0898, -1.1073, -1.1421,  ..., -1.6999, -1.6824, -1.6824]]]])

4.Forward pass through the network

# Pass the input through the net
out = fcn(inp)['out']  #out是模型的最终输出。
print (out.shape)
torch.Size([1, 21, 224, 224])

out是模型的最终输出, [1 x 21 x H x W] 。

我们需要将这21个通道输出到一个2D图像或一个1通道图像,其中该图像的每个像素对应于一个类! 因此,2D图像(形状[HxW])的每个像素将与相应的类标签对应,对于该2D图像中的每个(x,y)像素将对应于表示类的0-20之间的数字。 How?我们为每个像素位置取一个最大索引,它表示类。

import numpy as np
om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
print (om.shape)
(224, 224)

out.squeeze()去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。此处对out进行降维,变为[21 x H x W] 。

torch.argmax() 返回指定维度最大值的序号。dim给定的定义是:the demention to reduce.也就是把dim这个维度的,变成这个维度的最大值的index。dim=0说明对第一维(21)操作,故结果为[H x W] 。

降维torch.squeeze(input, dim=None, out=None)
torch.argmax() 函数详解

print (np.unique(om))#np.unique() 函数 去除其中重复的元素 ,并按元素 由小到大 返回一个新的无元素重复的元组或者列表。
[ 0 15]

处理完的列表中共有2种元素,0(背景),15(人)。正如我们所看到的,现在我们有了一个2D图像,其中每个像素属于一个类。最后一件事是把这个2D图像转换成一个分割图像,每个类标签对应于一个RGB颜色,从而使图像易于观看。

5.输出

将此2D图像转换为RGB图像,其中每个(元素)标签映射到相应的颜色。

首先,列表label_colors根据索引存储每个类的颜色。因此,第一类的颜色是背景,存储在label_colors列表的第0个索引处。第二类,即飞机,存储在索引1中,以此类推。
现在,我们必须从我们拥有的2D图像中创建一个RGB图像。因此,我们所做的是为所有3个通道创建空的2D矩阵。 因此,r、g和b是构成最终图像的RGB通道的列表,这些列表中的每一个的形状都是[HxW](这与2D图像的形状相同)。

现在,我们循环存储在label_colors中的每个颜色,并在存在特定类标签的2D图像中获取索引。然后,对于每个通道,我们将其相应的颜色放置到存在该类标签的像素上。 最后,我们将3个独立的通道叠加起来,形成RGB图像。 好吧!现在,让我们使用这个函数来查看最终的输出!

# Define the helper function
def decode_segmap(image, nc=21):
   
  label_colors = np.array([(0, 0, 0),  # 0=background
               # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
               (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
               # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
               (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
               # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
               (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
               # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
               (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])
 
  r = np.zeros_like(image).astype(np.uint8)
  g = np.zeros_like(image).astype(np.uint8)
  b = np.zeros_like(image).astype(np.uint8)
   
  for l in range(0, nc):
    idx = image == l
    r[idx] = label_colors[l, 0]
    g[idx] = label_colors[l, 1]
    b[idx] = label_colors[l, 2]
     
  rgb = np.stack([r, g, b], axis=2)
  return rgb

rgb = decode_segmap(om)
plt.imshow(rgb); plt.show()

【pytorch】使用torchvision进行语义分割_第2张图片

6.Final Result

接下来,让我们把所有操作放入一个函数下,并测试更多的图像!

def segment(net, path):
  img = Image.open(path)
  plt.imshow(img); plt.axis('off'); plt.show()
  # Comment the Resize and CenterCrop for better inference results
  trf = T.Compose([T.Resize(256), 
                   T.CenterCrop(224), 
                   T.ToTensor(), 
                   T.Normalize(mean = [0.485, 0.456, 0.406], 
                               std = [0.229, 0.224, 0.225])])
  inp = trf(img).unsqueeze(0)
  out = net(inp)['out']
  om = torch.argmax(out.squeeze(), dim=0).detach().cpu().numpy()
  rgb = decode_segmap(om)
  plt.imshow(rgb); plt.axis('off'); plt.show()


print (segment(fcn ,'C:/Users/ting/Desktop./qcs.jpg'))

参考文献


你可能感兴趣的:(pytorch,jupyter,notebook,pytorch,python,深度学习)