语义分割是对图像在像素级别上进行分类的方法,在一张图像中,属于同一类的像素点都要被预测为相同的类。因此语义分割是从像素级别来理解图像。
注意,语义分割仅仅是把某一类划分出来,而针对每个个体没办法进行分割(实例分割)。
常见的语义分割网络有很多,如FCN、U-Net、SegNet、DeepLab等。
FCN(Fully Convolutional Networks)属于利用深度网络进行图片语义分割的开山之作,其主要思想为:
U-Net基于FCN网络提出,能够适应较小的训练集。其采用大量弹性形变的方法对数据进行增强,让模型更好的学习形变不变形。在不同特征融合方式上,U-Net采用通道维度上的拼接融合代替FCN的逐点相加。
SegNet的网络结构借鉴了自编码网络的思想,具有编码器网络和解码器网络。最后通过softmax分类器对每个像素点进行分类。网络在编码器处会执行卷积和最大池化,在解码器部分则会执行上采样和卷积。
本次使用VOC2012数据集,来源于:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html。
数据集中存在20个类别的1个背景类:
Person: person
Animal: bird, cat, cow, dog, horse, sheep
Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train
Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor
在Annotations文件夹中,存放有对应图片的标记文件,以XML格式存储。
在Pytorch中,提供训练好的fcn和deeplabv3网络,可以用作图像分割。
需要用到的模块主要是torchvision
,直接pip install torchvision
即可。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch
from torchvision import transforms
import torchvisio
我们加载torch中训练好的全卷积残差网络fcn_resnet101
,设置预训练。如果是第一次加载需要在网络上下载参数。
由于该网络已经训练好了,所以我们不再进行训练,使用其评估模式eval
。该该模式下,不启用 Batch Normalization 和 Dropout。即在测试过程中保证BN层均值方差不变,在Dropout层不随机舍弃神经元。
# 导入训练好的模块
model=torchvision.models.segmentation.fcn_resnet101(pretrained=True)
model.eval()
然后就需要把我们的图像读取进来啦,这里随机选用一张VOC2012数据。
对图片需要进行预处理:
# 读取图片
image=Image.open(r"F:\VOCdevkit\VOC2012\JPEGImages\2007_002488.jpg")
# 图片预处理
image_transf=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
image_tensor=image_transf(image).unsqueeze(0)
output=model(image_tensor)["out"]
输出的Tensor是结果分类的,为了方便可视化,需要做以下处理:
# 将输出转化为二维图像
outputarg=torch.argmax(output.squeeze(),dim=0).numpy()
# 获取指定颜色编码
label_colors=np.array([(0,0,0),(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128),
(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0),(192,128,0),
(64,0,128),(192,0,128),(64,128,128),(192,128,128),(0,64,0),(128,64,0),
(0,192,0),(128,192,0),(0,64,128)])
图像编码的话,先生成三个维度的初始数据,接着获取输出类别的位置,并在该位置上为三个维度附上颜色值。
# 我们对该图像进行编码,以便可视化
def decode_segmaps(image,label_colors,nc=21):
# 先生成三个通道等大小的影像
r=np.zeros_like(image).astype(np.uint8)
g=np.zeros_like(image).astype(np.uint8)
b=np.zeros_like(image).astype(np.uint8)
for cla in range(0,nc):
idx=(image==cla) # 某一类的位置索引
r[idx]=label_colors[cla,0]
g[idx]=label_colors[cla,1]
b[idx]=label_colors[cla,2]
# 三通道转为图像
rgbimage=np.stack([r,g,b],axis=2)
return rgbimage
最后输出即可
# 进行可视化
outputrgb=decode_segmaps(outputarg,label_colors)
plt.figure(figsize=(12,8))
plt.subplot(1,2,1)
plt.imshow(image)
plt.axis("off")
plt.subplot(1,2,2)
plt.imshow(outputrgb)
plt.axis("off")
plt.subplots_adjust(wspace=0.05)
plt.show()
基于VGG19搭建全卷积语义分割网络。
首先数据有一个标记集:
也有一个原始集:
针对一个图像,在训练阶段我们需要做的事情是:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch.utils.data as Data
from time import time
import copy
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision
from torch.nn import functional as F
import torch.optim as optim
from torchsummary import summary
全局信息,包括是否使用GPU,以及图像色带。
# 指定设备
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 标识类
colormap=[(0,0,0),(128,0,0),(0,128,0),(128,128,0),(0,0,128),(128,0,128),
(0,128,128),(128,128,128),(64,0,0),(192,0,0),(64,128,0),(192,128,0),
(64,0,128),(192,0,128),(64,128,128),(192,128,128),(0,64,0),(128,64,0),
(0,192,0),(128,192,0),(0,64,128)]
对于图像数据,我们主要进行以下几个工作:
# 将RGB图像处理为一个类
def img2lab(img,colormap):
cm2lbl=np.zeros(256**3)
# 将每个像素转化为一类数据
for i,cm in enumerate(colormap):
# 这步能够将每个rgb类型的颜色转为单独一类
cm2lbl[((cm[0]*256+cm[1])*256+cm[2])]=i
# 对一张图像进行转化
image=np.array(img,dtype="int64")
ix=((image[:,:,0]*256+image[:,:,1])*256+image[:,:,2]) # 这里会将每个像素点都做映射
# 做完映射后,留下的ix只剩下了两个维度
image2=cm2lbl[ix] # 将对应像素的类映射过去,ix的shape跟img展平后相同
# ix; high,width
# 所以image2是单维度的数据
return image2
# 随机裁剪图像数据
def rand_crop(data,label,high,width):
im_width,im_high=data.size
# 生成图像随机点
left=np.random.randint(0,im_width-width)
top=np.random.randint(0,im_high-high)
right=left+width
bottom=top+high
data=data.crop((left,top,right,bottom))
label=label.crop((left,top,right,bottom))
return data,label
# 图像转换操作
def img_transforms(data,label,high,width,colormap):
data,label=rand_crop(data,label,high,width)
data_tfs=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
data=data_tfs(data)
label=torch.from_numpy(img2lab(label,colormap))
return data,label
VOC2012的数据集路径保存在train.txt
中,我们需要获取该文件,通过np.loadtxt
保存路径信息。
# 读取路径函数
def read_image_path(root=r"F:\VOCdevkit\VOC2012\ImageSets\Segmentation\train.txt"):
image=np.loadtxt(root,dtype=str)
n=len(image)
data,label=[None]*n,[None]*n
for i,fname in enumerate(image):
data[i]=r"F:\VOCdevkit\VOC2012\JPEGImages\%s.jpg"%(fname)
label[i]=r"F:\VOCdevkit\VOC2012\SegmentationClass\%s.png"%(fname)
return data,label
接着我们需要定义一个Dataset
类,继承自torch.utils.data.Dataset
,作为DataLoade
中的数据源。
# 定义一个MyDataset继承于torch.utils.data.Dataset
class MyDataset(Data.Dataset):
def __init__(self,data_root,high,width,imtransform,colormap):
self.data_root=data_root
self.high=high
self.width=width
self.imtransform=imtransform
self.cm=colormap
data_list,label_list=read_image_path(data_root)
self.data_list=self._filter(data_list)
self.label_list=self._filter(label_list)
def _filter(self,images):
# 处理掉不符合尺寸的数据
# 这步需要打开图片,耗时会有点久
imlist=[]
for im in images:
img=Image.open(im)
if img.size[1]>self.high and img.size[0]>self.width:
# 注意此时还是路径
imlist.append(im)
return imlist
def __getitem__(self,idx):
# 这步是核心
img=self.data_list[idx]
lab=self.label_list[idx]
img=Image.open(img)
lab=Image.open(lab).convert("RGB")
img,lab=self.imtransform(img,lab,self.high,self.width,self.cm)
return img,lab
def __len__(self):
return len(self.data_list)
查看数据
# 读取数据
high,width=320,480
voc_train=MyDataset(r"F:\VOCdevkit\VOC2012\ImageSets\Segmentation\train.txt",high,width,img_transforms,colormap)
voc_val=MyDataset(r"F:\VOCdevkit\VOC2012\ImageSets\Segmentation\val.txt",high,width,img_transforms,colormap)
# 创建数据加载器每个batch使用4个图像
train_loader=Data.DataLoader(voc_train,batch_size=4,shuffle=True,pin_memory=True)
val_loader=Data.DataLoader(voc_val,batch_size=4,shuffle=True,pin_memory=True)
# 检查训练数据集的一个batch的样本维度是否正确
for step,(bx,by) in enumerate(train_loader):
if step>0:
break
print("bx.shape",bx.shape)
print("by.shape",by.shape)
print("bx",bx)
print("by",by)
本部分需要做的事情是将Tensor数据转化为图像数据,包括label
和image
的转化。
需要反标准化回去,并且将溢出的浮点抹去
def inv_normalize_img(data):
rgb_mean=np.array([0.485,0.456,0.406])
rgb_std=np.array([0.229,0.224,0.225])
data=data.astype("float32")*rgb_std+rgb_mean
return data.clip(0,1)
将标签转化为RGB图像
def label2img(prelab,colormap):
#从预测到的标签转化为图像,针对一个标签图
h,w=prelab.shape
prelab=prelab.reshape(h*w,-1)
img=np.zeros((h*w,3),dtype="int32")
for ii in range(len(colormap)):
index=np.where(prelab==ii)
img[index,:]=colormap[ii]
return img.reshape(h,w,3)
可视化图像
# 可视化一个batch图像
bx_numpy=bx.data.numpy()
bx_numpy=bx_numpy.transpose(0,2,3,1)
by_numpy=by.data.numpy()
plt.figure(figsize=(16,6))
for i in range(4):
plt.subplot(2,4,i+1)
plt.imshow(inv_normalize_img(bx_numpy[i]))
plt.axis("off")
plt.subplot(2,4,i+5)
plt.imshow(label2img(by_numpy[i],colormap))
plt.axis("off")
plt.subplots_adjust(wspace=0.1,hspace=0.1)
plt.show()
使用训练好的VGG19网络作为backbone,定义语义分割网络FCN8S。其核心在于:
FCN8S会在第五个最大池化层进行反卷积,得到大小为w/16
的特征,融合将其加上第四个最大池化后的数据后进行处理,再次反卷积得到w/8
的特征。最后通过分类器,将特征维度转换为类别数量,判断每个像素点在每个特征维度上的概率,即可实现图像分割。
# 定义语义分割网络FCN-8S
class FCN8S(nn.Module):
def __init__(self,num_class):
'''
:param num_class: 训练数据的类别
'''
super(FCN8S, self).__init__()
self.num_class=num_class
model_vgg19=torchvision.models.vgg19(pretrained=True)
self.backbone=model_vgg19.features
# 需要做的事情是反卷积、特征融合
# 传到这里的数据已经成了[batch,512,10,15]
self.relu=nn.ReLU(inplace=True)
self.deconv1=nn.ConvTranspose2d(512,512,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)
self.bn1=nn.BatchNorm2d(512)
# 512->256
self.deconv2=nn.ConvTranspose2d(512,256,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)
self.bn2=nn.BatchNorm2d(256)
#256->128
self.deconv3=nn.ConvTranspose2d(256,128,kernel_size=3,stride=2,padding=1,dilation=1,output_padding=1)
self.bn3=nn.BatchNorm2d(128)
# 128->64
self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn4 = nn.BatchNorm2d(64)
# 64->32
self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(32)
# 提取32维特征FCN32S
self.classifier=nn.Conv2d(32,num_class,kernel_size=1)
# VGG19中maxpool2所在的层
self.layers={"4":"maxpool_1",
"9":"maxpool_2",
"18":"maxpool_3",
"27":"maxpool_4",
"36":"maxpool_5"}
def forward(self,x):
output={}
for name,layer in self.backbone._modules.items():
x=layer(x)
if name in self.layers:
# 留下这层信息
output[self.layers[name]]=x
# 获取各层的信息
x5=output["maxpool_5"] # (b,512,x.H/32,x.W/32)
x4=output["maxpool_4"] # (b,512,x.H/16,x.W/16)
x3=output["maxpool_3"] # (b,256,x.H/8,x.W/8)
# 转置卷积
score=self.relu(self.deconv1(x5)) # 512 16 16
# 加上16s的信息
score=self.bn1(score+x4)
# 再反卷积
score=self.relu(self.deconv2(score)) # 256 8 8
# 加上8s的信息
score=self.bn2(score+x3)
# 一步一步慢慢回去
score=self.bn3(self.relu(self.deconv3(score))) # 128 4 4
score=self.bn4(self.relu(self.deconv4(score))) # 64 2 2
score=self.bn5(self.relu(self.deconv5(score))) # 32 1 1
# 最后将这32维转成输入维度(rgb:3,gray:1)
score=self.classifier(score)
return score # b,n_class,1,1
正常去训练就好,注意这里没有使用残差网络,所以可以保留最好的参数。
# 网络训练
def train_model(model,criterion,optimizer,traindataloader,valdataloader,num_epoch=25):
since=time()
best_model_wts=copy.deepcopy(model.state_dict())
best_loss=1e10
train_loss_all=[]
train_acc_all=[]
val_loss_all=[]
val_acc_all=[]
# 训练num_epoch次
for epoch in range(num_epoch):
print("Epoch {}/{}".format(epoch,num_epoch-1))
print("-"*10)
train_loss=0.0
train_num=0
val_loss=0.0
val_num=0
# 训练阶段
model.train()
for step,(bx,by) in enumerate(traindataloader):
optimizer.zero_grad()
bx=bx.float().to(device)
by=by.long().to(device)
out=model(bx)
# 柔和21个类做一个softmax后再做log
out=F.log_softmax(out,dim=1)
pre_lab=torch.argmax(out,1)
loss=criterion(out,by)
loss.backward()
optimizer.step()
train_loss+=loss.item()*len(by)
train_num+=len(by)
# 计算训练集上的精度
train_loss_all.append(train_loss/train_num)
print("{} Train Loss: {:.4f}".format(epoch,train_loss_all[-1]))
# 进入评估阶段
model.eval()
for step,(bx,by) in enumerate(valdataloader):
bx,by=bx.float().to(device),by.long().to(device)
out=model(bx)
out=F.log_softmax(out,dim=1)
pre_lab=torch.argmax(out,1)
loss=criterion(out,by)
val_loss+=loss.item()*len(by)
val_num+=len(by)
# 计算损失和精度
val_loss_all.append(val_loss/val_num)
print("{} Val Loss: {:.4f}".format(epoch,val_loss_all[-1]))
# 保留最好的参数
if val_loss_all[-1]<best_loss:
best_loss=val_loss_all[-1]
best_model_wts=copy.deepcopy(model.state_dict())
time_use=time()-since
print("Train and val complete in {:.0f}m {:.0f}s".format(time_use//60,time_use%60))
# 输出训练信息
train_process=pd.DataFrame(
data={"epoch":range(num_epoch),
"train_loss_all":train_loss_all,
"val_loss_all":val_loss_all}
)
# 加载最好的模型
model.load_state_dict(best_model_wts)
return model,train_process
# 定义损失函数和优化器
lr=0.0003
criterion=nn.NLLLoss()
optimizer=optim.Adam(fcn8s.parameters(),lr=lr,weight_decay=1e-4)
# 迭代训练
fcn8s,train_process=train_model(fcn8s,criterion,optimizer,train_loader,val_loader,num_epoch=5)
torch.save(fcn8s,"fcn8s.pkl")
# 可视化训练过程
plt.figure(figsize=(10,6))
plt.plot(train_process.epoch,train_process.train_loss_all,"ro-",label="Train Loss")
plt.plot(train_process.epoch,train_process.val_loss_all,"bs-",label="Val Loss")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.show()
查看在验证集上的效果
# 从验证集中获取一个batch的数据
for step,(bx,by) in enumerate(val_loader):
if step>0:
break
fcn8s.eval()
bx=bx.float().to(device)
by=by.long().to(device)
out=fcn8s(bx)
# 拿出来结果后需要在维度上做一个log_softmax
out=F.log_softmax(out,dim=1)
# 判别最可能的概率
pre_lab=torch.argmax(out,1)
bx_numpy=bx.cpu().data.numpy()
bx_numpy=bx_numpy.transpose(0,2,3,1)
by_numpy=by.cpu().data.numpy()
pre_lab_numpy=pre_lab.cpu().numpy()
for i in range(4):
plt.subplot(3,4,i+1)
plt.imshow(inv_normalize_img(bx_numpy[i]))
plt.axis("off")
plt.subplot(3,4,i+5)
plt.imshow(label2img(by_numpy[i],colormap))
plt.axis("off")
plt.imshow(label2img(pre_lab_numpy[i],colormap))
plt.axis("off")
plt.subplots_adjust(wspace=0.05,hspace=0.05)
plt.show()