U2-Net for Video Image Segmentation (From Theory to Practice)

I. A Brief Introduction to U2-Net

1. The U2-Net network structure

[Figure 1: Overall U2-Net architecture]

The whole network is a symmetric U shape built on the classic encoder-decoder design, and every stage feeding a Sup (side output) is itself a small U-shaped structure; deep supervision is used, which effectively combines shallow and deep semantic information. The network performs 5 downsampling and 5 upsampling steps: upsampling is implemented with torch.nn.functional.interpolate(), and downsampling with torch.nn.MaxPool2d(), i.e. stride-2 max pooling. Each En_x stage uses an RSU block, whose structure is shown in the two figures below.
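As a quick illustration of the two sampling primitives (a minimal, self-contained sketch, not taken from the repo's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)                                   # an NCHW feature map

pool = nn.MaxPool2d(2, stride=2, ceil_mode=True)                # downsampling: stride-2 max pooling
print(pool(x).shape)                                            # torch.Size([1, 3, 32, 32])

up = F.interpolate(x, size=(128, 128), mode='bilinear',
                   align_corners=True)                          # upsampling: bilinear interpolation
print(up.shape)                                                 # torch.Size([1, 3, 128, 128])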

[Figure 2: RSU-L block structure]

[Figure 3: RSU block parameters]

The role of the RSU block is to capture multi-scale features at each stage (L is the number of layers in its encoder, Cin and Cout are the input and output channel counts, and M is the number of channels in the RSU's internal layers). The block consists of three parts:

(1) an input convolution layer, which converts the input feature map into an intermediate map with the same number of channels as the output, used for local feature extraction;

(2) a symmetric encoder-decoder structure of height L, which takes the intermediate map as input and learns to extract multi-scale semantic information;

(3) a residual connection that fuses the local features with the multi-scale features.

U2-Net uses both element-wise addition and channel concatenation (contrasted in the sketch below).
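A minimal sketch of the difference (tensor names here are illustrative only):

import torch

a = torch.randn(1, 64, 32, 32)          # e.g. an upsampled decoder feature
b = torch.randn(1, 64, 32, 32)          # e.g. the matching encoder feature

cat = torch.cat((a, b), dim=1)          # concatenation stacks channels: (1, 128, 32, 32)
add = a + b                             # element-wise add keeps the shape: (1, 64, 32, 32)

In the code below, torch.cat() joins upsampled decoder features with the encoder skip features, while addition appears as the residual return hx1d + hxin at the end of every RSU.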

2. The loss function

[Figure 4: Deep-supervision training loss]

Because there are 6 side outputs (Sup1-Sup6), the loss has 6 side terms, plus one term for the fused output; each individual term is the standard binary cross-entropy loss.
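For reference, the loss from the U^2-Net paper (transcribed here in LaTeX, since the original post showed it as images):

\mathcal{L} = \sum_{m=1}^{M} w_{side}^{(m)}\, \ell_{side}^{(m)} + w_{fuse}\, \ell_{fuse}

with M = 6 side outputs, where each term \ell is the standard binary cross-entropy over all pixels (r, c):

\ell = -\sum_{(r,c)} \Big[ P_{G(r,c)} \log P_{S(r,c)} + \big(1 - P_{G(r,c)}\big) \log\big(1 - P_{S(r,c)}\big) \Big]

Here P_G is the ground-truth probability map and P_S the predicted map; all weights w are 1 in the reference implementation.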

[Figure 5: Binary cross-entropy term]

II. The code

Read alongside the diagrams, the network code is fairly easy to follow; the other files are commented throughout, mainly for my own later reference.

1. U2net.py

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class REBNCONV(nn.Module):                                                          # Conv + BN + ReLU block
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self,x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


def _upsample_like(src,tar):

    # src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
    src = F.interpolate(src,size=tar.shape[2:],mode='bilinear',align_corners=True)     #     https://www.cnblogs.com/wanghui-garcia/p/11399034.html
    # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):                          #En_1   

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)              #CBR1
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)              #CBR2

        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)           
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)             

        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)                                      # dilation=2

        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))

        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))

        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))

        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))

        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))

        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)


        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)

### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NETP,self).__init__()

        self.stage1 = RSU7(in_ch,16,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,16,64)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(64,16,64)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(64,16,64)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(64,16,64)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(64,16,64)

        # decoder
        self.stage5d = RSU4F(128,16,64)
        self.stage4d = RSU4(128,16,64)
        self.stage3d = RSU5(128,16,64)
        self.stage2d = RSU6(128,16,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(64,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #decoder
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
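A quick shape sanity check for the model above (my own sketch, not part of the original file): all seven outputs come back at the input's spatial resolution.

import torch

# assumes U2NET from the file above is importable
net = U2NET(3, 1)
x = torch.randn(1, 3, 320, 320)
with torch.no_grad():
    outs = net(x)                       # (d0, d1, ..., d6)
for d in outs:
    print(d.shape)                      # each: torch.Size([1, 1, 320, 320])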

2. data_loader.py

# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color                                            # scikit-image is a SciPy-based image-processing package that treats images as numpy arrays
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

#==========================dataset load==========================
class RescaleT(object):                                                                      # resize the original image to the specified output size

    def __init__(self,output_size):
        assert isinstance(output_size,(int,tuple))
        self.output_size = output_size                                                        # desired output size

    def __call__(self,sample):
        imidx, image, label = sample['imidx'], sample['image'],sample['label']                # unpack the image index, image and label

        h, w = image.shape[:2]                                                                # image height and width
    
        if isinstance(self.output_size,int):
            if h > w:
                new_h, new_w = self.output_size*h/w,self.output_size                          # recompute height and width from the target size, keeping the aspect ratio
            else:
                new_h, new_w = self.output_size,self.output_size*w/h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        # img = transform.resize(image,(new_h,new_w),mode='constant')
        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        img = transform.resize(image,(self.output_size,self.output_size),mode='constant')      # note: resizes to output_size x output_size (the aspect-preserving new_h/new_w above is computed but unused)
        lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)  # resize the label the same way (nearest-neighbor, values preserved)    # skimage.transform.resize

        return {'imidx':imidx, 'image':img,'label':lbl}

class Rescale(object):                                                                    # rescale to the given size

    def __init__(self,output_size):
        assert isinstance(output_size,(int,tuple))
        self.output_size = output_size

    def __call__(self,sample):                                                             # __call__ lets an instance be invoked like a function: obj(sample)

        imidx, image, label = sample['imidx'], sample['image'],sample['label']

        if random.random() >= 0.5:                                                         # random flip along the first (vertical) axis
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]

        if isinstance(self.output_size,int):
            if h > w:
                new_h, new_w = self.output_size*h/w,self.output_size
            else:
                new_h, new_w = self.output_size,self.output_size*w/h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        img = transform.resize(image,(new_h,new_w),mode='constant')
        lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        return {'imidx':imidx, 'image':img,'label':lbl}

class RandomCrop(object):                                                                   # randomly crop image and label to the given output size

    def __init__(self,output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
    def __call__(self,sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h, left: left + new_w]                                 # randomly crop the image
        label = label[top: top + new_h, left: left + new_w]                                 # crop the label at the same location

        return {'imidx':imidx,'image':image, 'label':label}                                 # return the cropped image and label

class ToTensor(object):                                                             # normalize the image and label and convert them to tensors
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):

        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpImg = np.zeros((image.shape[0],image.shape[1],3))
        tmpLbl = np.zeros(label.shape)

        image = image/np.max(image)                                                         # scale the image to [0,1]
        if(np.max(label)<1e-6):
            label = label
        else:
            label = label/np.max(label)

        if image.shape[2]==1:
            tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
            tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
            tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
        else:
            tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
            tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
            tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225

        tmpLbl[:,:,0] = label[:,:,0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}

class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors."""
    def __init__(self,flag=0):
        self.flag = flag

    def __call__(self, sample):

        imidx, image, label =sample['imidx'], sample['image'], sample['label']

        tmpLbl = np.zeros(label.shape)

        if(np.max(label)<1e-6):
            label = label
        else:
            label = label/np.max(label)

        # change the color space
        if self.flag == 2: # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0],image.shape[1],6))
            tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
            if image.shape[2]==1:
                tmpImgt[:,:,0] = image[:,:,0]
                tmpImgt[:,:,1] = image[:,:,0]
                tmpImgt[:,:,2] = image[:,:,0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # normalize image to range [0,1]
            tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
            tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
            tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
            tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
            tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
            tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
            tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
            tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
            tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
            tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
            tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])

        elif self.flag == 1: #with Lab color
            tmpImg = np.zeros((image.shape[0],image.shape[1],3))

            if image.shape[2]==1:
                tmpImg[:,:,0] = image[:,:,0]
                tmpImg[:,:,1] = image[:,:,0]
                tmpImg[:,:,2] = image[:,:,0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
            tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
            tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))

            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
            tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
            tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])

        else: # with rgb color
            tmpImg = np.zeros((image.shape[0],image.shape[1],3))
            image = image/np.max(image)
            if image.shape[2]==1:
                tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
                tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
                tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
            else:
                tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
                tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
                tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225

        tmpLbl[:,:,0] = label[:,:,0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}

class SalObjDataset(Dataset):                                                                   # returns the image index, normalized image and label
    def __init__(self,img_name_list,lbl_name_list,transform=None):
        # self.root_dir = root_dir
        # self.image_name_list = glob.glob(image_dir+'*.png')
        # self.label_name_list = glob.glob(label_dir+'*.png')
        self.image_name_list = img_name_list                                                     # absolute paths of all images
        self.label_name_list = lbl_name_list                                                     # absolute paths of all labels
        self.transform = transform                                                               # transform pipeline: rescale, crop, to-tensor

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self,idx):

        # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
        # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])

        image = io.imread(self.image_name_list[idx])                                             # read one image by its absolute path
        # print(type(image))                                                                       # numpy array; io.imread returns RGB order, not BGR

        # import cv2
        # cv2.imshow("",cv2.cvtColor(np.uint8(image),cv2.COLOR_BGR2RGB))
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()

        # print(image.shape)                                                                       # e.g. (375, 500, 3): sizes vary but all images are 3-channel
        # print("=======================================================")
        imname = self.image_name_list[idx]
        imidx = np.array([idx])                                                                  # image index as a numpy array

        if(0==len(self.label_name_list)):                                                        # if there is no label, create an all-zero one
            label_3 = np.zeros(image.shape)
        else:                                                                                    # otherwise read the matching label
            label_3 = io.imread(self.label_name_list[idx])

        label = np.zeros(label_3.shape[0:2])                                                     # single-channel label buffer

        if(3==len(label_3.shape)):
            label = label_3[:,:,0]
        elif(2==len(label_3.shape)):
            label = label_3

        if(3==len(image.shape) and 2==len(label.shape)):
            label = label[:,:,np.newaxis]                                                        # np.newaxis inserts an axis of size 1 at its position
        elif(2==len(image.shape) and 2==len(label.shape)):
            image = image[:,:,np.newaxis]
            label = label[:,:,np.newaxis]

        sample = {'imidx':imidx, 'image':image, 'label':label}

        if self.transform:
            sample = self.transform(sample)                                                      # apply the transform pipeline

        return sample
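A minimal usage sketch for the pipeline above (the paths are placeholders I made up):

import glob
from torch.utils.data import DataLoader
from torchvision import transforms
# RescaleT, RandomCrop, ToTensorLab, SalObjDataset come from the file above

imgs = sorted(glob.glob('data/images/*.jpg'))       # hypothetical image paths
lbls = sorted(glob.glob('data/labels/*.png'))       # hypothetical label paths
ds = SalObjDataset(imgs, lbls,
                   transform=transforms.Compose([RescaleT(320),
                                                 RandomCrop(288),
                                                 ToTensorLab(flag=0)]))
loader = DataLoader(ds, batch_size=2, shuffle=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['label'].shape)   # (2, 3, 288, 288) and (2, 1, 288, 288)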

3. u2net_train.py

import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

import numpy as np
import glob
import os

from my_U2_Net.data_loader import RescaleT,RandomCrop,ToTensorLab,SalObjDataset,ToTensor,Rescale
from my_U2_Net.model import U2NET,U2NETP

# ------- 1. define loss function --------

bce_loss = nn.BCELoss(reduction='mean')

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):

    loss0 = bce_loss(d0,labels_v)
    loss1 = bce_loss(d1,labels_v)
    loss2 = bce_loss(d2,labels_v)
    loss3 = bce_loss(d3,labels_v)
    loss4 = bce_loss(d4,labels_v)
    loss5 = bce_loss(d5,labels_v)
    loss6 = bce_loss(d6,labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6              # the fused output plus the 6 side outputs give 7 BCE terms
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item(),loss5.item(),loss6.item()))

    return loss0, loss

def main():
    # ------- 2. set the directory of training dataset --------
    model_name = 'u2net' #'u2netp'                                     # which model variant to train



    # params_path = os.path.join("../saved_models", model_name,model_name.pth)


    data_dir = r'F:\PASCAL_VOC\VOCdevkit\VOC2007\my_segmentations'     # parent directory containing the training images and segmentation labels
    tra_image_dir = 'JPEGImages'                                       # directory name of the images
    tra_label_dir = 'SegmentationClass'                                # directory name of the labels

    image_ext = '.jpg'
    label_ext = '.png'                                                 # label file extension

    model_dir = './saved_models/' + model_name +'/'                    # directory for the saved weights
    params_path = model_dir +model_name  +".pth"                       # relative path of the saved weights
    epoch_num = 100000                                                 # total number of training epochs
    batch_size_train = 2                                               # training batch size
    batch_size_val = 1                                                 # validation batch size

    train_num = 0
    val_num = 0

    # tra_img_name_list = glob.glob(data_dir + "\\" + tra_image_dir + '*')
    tra_img_name_list = glob.glob(os.path.join(data_dir, tra_image_dir,'*'))   # glob.glob() collects the absolute paths of all training images into a list
    print("hahah")
    # print(tra_img_name_list)                                                   # list of absolute paths of all training images
    # print("-------------------------------------------------------------------------------------------------")
    tra_lbl_name_list = []
    for img_path in tra_img_name_list:                                         # walk over every training-image path
        img_name = img_path.split("\\")[-1]                                    # file name, e.g. 000032.jpg
        aaa = img_name.split(".")
        bbb = aaa[0:-1]                                                        # name parts without the extension, e.g. ['000032']
        imidx = bbb[0]                                                         # file stem, e.g. 000032
        for i in range(1,len(bbb)):                                            # re-join stems that themselves contained dots
            imidx = imidx + "." + bbb[i]

        tra_lbl_name_list.append(data_dir+ "\\"  + tra_label_dir+ "\\"  + imidx + label_ext)
    print(tra_lbl_name_list)                                                   # absolute label paths, one-to-one with the training images

    print("---")
    print("train images: ", len(tra_img_name_list))                            #422
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)                                         # total number of training images

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))     # RescaleT(320) resizes to the given size, RandomCrop(288) randomly crops to the given size
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)

    # ------- 3. define model --------
    # define the net
    if(model_name=='u2net'):
        net = U2NET(3, 1)                                            # specify the input and output channel counts
    elif(model_name=='u2netp'):                                      # instantiate the chosen network
        net = U2NETP(3,1)

    if torch.cuda.is_available():
        net.cuda()                                                   # move the network to the GPU

    if os.path.exists(params_path):                                  # resume from previously saved weights if present
        net.load_state_dict(torch.load(params_path))
    else:
        print("No parameters!")


    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000              #save the model every 2000 iterations

    for epoch in range(0, epoch_num):
        net.train()                                                  # training mode

        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs, labels = data['image'], data['label']            # fetch images and labels

            inputs = inputs.type(torch.FloatTensor)                  # cast to float tensors
            labels = labels.type(torch.FloatTensor)

            # wrap them in Variable
            if torch.cuda.is_available():                            # move to CUDA if available
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
                                                                                            requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

            # y zero the parameter gradients
            optimizer.zero_grad()                                    # clear accumulated gradients

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)               # network outputs
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v) # loss over the fused output and the 6 side outputs

            loss.backward()                                          # backpropagate
            optimizer.step()                                         # update the parameters

            # # print statistics
            running_loss += loss.item()                              # accumulated total loss
            running_tar_loss += loss2.item()

            # delete temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))

            if ite_num % save_frq == 0:

                # torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume train
                ite_num4val = 0
        # torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
        torch.save(net.state_dict(),params_path)
if __name__ == "__main__":
    main()
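The label-path construction above splits on "\\" and rebuilds the stem by hand, which only works on Windows. A portable sketch of the same pairing, reusing the script's variables (equivalent behavior, under the same naming assumption):

import os, glob

tra_img_name_list = glob.glob(os.path.join(data_dir, tra_image_dir, '*'))
tra_lbl_name_list = [os.path.join(data_dir, tra_label_dir,
                                  os.path.splitext(os.path.basename(p))[0] + label_ext)
                     for p in tra_img_name_list]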

4. u2net_test.py

import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from my_U2_Net.data_loader import RescaleT
from my_U2_Net.data_loader import ToTensor
from my_U2_Net.data_loader import ToTensorLab
from my_U2_Net.data_loader import SalObjDataset                        # data loading (returns image index, normalized image, label)

from my_U2_Net.model import U2NET # full size version 173.6 MB         # import the two model variants
from my_U2_Net.model import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):                                           # min-max normalization
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir):

    predict = pred.squeeze()                               # drop singleton dimensions
    predict_np = predict.cpu().data.numpy()                # move to the CPU as a numpy array

    im = Image.fromarray(predict_np*255).convert('RGB')    # back to PIL, rescaled from [0,1] to [0,255]
    img_name = image_name.split("\\")[-1]                  # file name without the directory
    # print(image_name)
    # print(img_name)
    image = io.imread(image_name)                          # io.imread returns a uint8 numpy array in RGB order, values 0-255
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    # pb_np = np.array(imo)                                  # (unused)

    aaa = img_name.split(".")                              # split the file name on dots
    bbb = aaa[0:-1]                                        # name parts without the extension
    # print(aaa)                                             #['5', 'jpg']
    # print(bbb)                                             #['5']
    # print("---------------------------------------------")
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')                           # save the prediction to the output directory

def main():

    # --------- 1. get image path and name ---------
    model_name='u2net'#u2netp                              # name of the saved model


    image_dir = './test_data/test_images/'                 # directory of the images to predict
    prediction_dir = './test_data/' + model_name + '_results/'  # directory where predictions are saved
    os.makedirs(prediction_dir, exist_ok=True)             # make sure the output directory exists before save_output() writes to it
    # model_dir = '../saved_models/'+ model_name + '/' + model_name + '.pth'
    # model_dir = r"../saved_models/u2net/u2net_bce_itr_422_train_3.743319_tar_0.546805.pth"
    model_dir = r"\my_U2_Net\saved_models\u2net\u2net.pth"          # path of the model weights (raw string so the backslashes are not escapes)

    img_name_list = glob.glob(image_dir + '*')             # all files under the image directory (with paths)
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)     # build the data loader

    # --------- 3. model define ---------
    if(model_name=='u2net'):                               # pick the weights matching the chosen variant
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NETP---4.7 MB")
        net = U2NETP(3,1)
    net.load_state_dict(torch.load(model_dir))             # load the trained weights
    if torch.cuda.is_available():
        net.cuda()                                         # move the network to the GPU
    net.eval()                                             # evaluation mode

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split("/")[-1])   #test_images\5.jpg
        # print(data_test)                                   #'imidx': tensor([[0]], dtype=torch.int32), 'image': tensor([[[[ 1.4051,  ...'label': tensor([[[[0., 0., 0.,  ...,
        inputs_test = data_test['image']                   # the image to test
        inputs_test = inputs_test.type(torch.FloatTensor)  # cast to float

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
            # Variable is a legacy wrapper around Tensor; since PyTorch 0.4
            # plain tensors support autograd directly and Variable is deprecated
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7 = net(inputs_test)             # forward pass

        # normalization
        pred = d1[:,0,:,:]
        pred = normPRED(pred)                               # min-max normalize the prediction

        # save results to test_results folder
        save_output(img_name_list[i_test],pred,prediction_dir)  # original image name, prediction, output directory

        del d1,d2,d3,d4,d5,d6,d7                            # free the outputs before the next iteration

if __name__ == "__main__":
    main()
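If a hard binary mask is wanted instead of the soft saliency map, thresholding the normalized prediction is enough (my own sketch; the 0.5 cutoff is arbitrary):

import numpy as np
from PIL import Image

def to_binary_mask(pred_np, thresh=0.5):
    # pred_np: H x W float array in [0, 1], e.g. normPRED output moved to numpy
    return Image.fromarray(((pred_np > thresh) * 255).astype(np.uint8))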

Note: this run was trained on the VOC2007 data, with the pretrained model loaded.

The output:

Original image:

[Figure 6: Original image]

The resulting mask:

[Figure 7: Predicted mask]

Extracting the target:

crop.py

# -*- coding: utf-8 -*-

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

def crop(img_file, mask_file):
    name, *_ = img_file.split(".")                                        # split off the extension
    img_array = np.array(Image.open(img_file))                            # original image, PIL to numpy
    mask = np.array(Image.open(mask_file))                                # load the mask image
    res = np.concatenate((img_array, mask[:, :, [0]]), -1)               # append the mask's first channel as an alpha channel
    img = Image.fromarray(res.astype('uint8'), mode='RGBA')              # numpy back to PIL, as RGBA
    img.show()
    return img


if __name__ == "__main__":


    model = "u2net"
    # model = "u2netp"

    img_root = "test_data/test_images"                                    #真实图片的保存路径
    mask_root = "test_data/{}_results".format(model)                      #掩码的根目录
    crop_root = "test_data/{}_crops".format(model)                        #裁剪图片的保存的目录
    os.makedirs(crop_root, mode=0o775, exist_ok=True)                     #创建保存图片的路径

    # name: the directory to create; mode: permission bits for it (default 0o777, octal);
    # exist_ok: if False (the default), a FileExistsError is raised when the target
    # directory already exists; if True, an existing directory is not an error.

    for img_file in os.listdir(img_root):                                # walk over all source images
        print("crop image {}".format(img_file))
        name, *_ = img_file.split(".")                                   # split the file name from its extension
        res = crop(
            img_file=os.path.join(img_root,  img_file),
            mask_file=os.path.join(mask_root, name + ".png")
        )                                                                # call the crop() function defined above
        res.save(os.path.join(crop_root, name + "_crop.png"))            # save to the output directory
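Because crop() returns an RGBA image whose alpha channel is the predicted mask, compositing the cutout onto a solid background takes one more call (a sketch; the file name is from a hypothetical run):

from PIL import Image

rgba = Image.open('test_data/u2net_crops/5_crop.png')       # a crop produced above (hypothetical name)
bg = Image.new('RGBA', rgba.size, (255, 255, 255, 255))     # plain white background
Image.alpha_composite(bg, rgba).convert('RGB').save('5_on_white.jpg')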

The resulting image looks like this:

5. video_for_video.py

To segment video I merged the files above into one script. The code has not been optimized yet, but it is good enough for a quick video-segmentation test; I will optimize it later if needed (one possible slimmed-down loop is sketched after the listing).

# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color                                            # scikit-image is a SciPy-based image-processing package that treats images as numpy arrays
import numpy as np
import random
import math
import matplotlib.pyplot as plt

from torchvision import transforms, utils
from PIL import Image
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import glob
from my_U2_Net.model import U2NET,U2NETP                    # import the two model variants
import cv2


capture = cv2.VideoCapture(0)                               # open the default webcam
class RescaleT(object):                                                                      # resize the original image to the specified output size

    def __init__(self,output_size):
        assert isinstance(output_size,(int,tuple))
        self.output_size = output_size                                                        # desired output size

    def __call__(self,sample):
        imidx, image, label,frame = sample['imidx'], sample['image'],sample['label'],sample['frame']                 # unpack index, image, label and the raw frame

        h, w = image.shape[:2]                                                                # image height and width

        if isinstance(self.output_size,int):
            if h > w:
                new_h, new_w = self.output_size*h/w,self.output_size                          # recompute height and width from the target size, keeping the aspect ratio
            else:
                new_h, new_w = self.output_size,self.output_size*w/h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        # img = transform.resize(image,(new_h,new_w),mode='constant')
        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        img = transform.resize(image,(self.output_size,self.output_size),mode='constant')      # note: resizes to output_size x output_size (the aspect-preserving new_h/new_w above is computed but unused)
        lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)  # resize the label the same way (nearest-neighbor, values preserved)    # skimage.transform.resize

        return {'imidx':imidx, 'image':img,'label':lbl,"frame":frame}

class Rescale(object):                                                                    # rescale to the given size

    def __init__(self,output_size):
        assert isinstance(output_size,(int,tuple))
        self.output_size = output_size

    def __call__(self,sample):                                                             # __call__ lets an instance be invoked like a function: obj(sample)

        imidx, image, label,frame = sample['imidx'], sample['image'],sample['label'],sample['frame']

        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]

        if isinstance(self.output_size,int):
            if h > w:
                new_h, new_w = self.output_size*h/w,self.output_size
            else:
                new_h, new_w = self.output_size,self.output_size*w/h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        img = transform.resize(image,(new_h,new_w),mode='constant')
        lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        return {'imidx':imidx, 'image':img,'label':lbl,"frame":frame}

class RandomCrop(object):                                                                   # randomly crop image and label to the given output size

    def __init__(self,output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
    def __call__(self,sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h, left: left + new_w]                                 # randomly crop the image
        label = label[top: top + new_h, left: left + new_w]                                 # crop the label at the same location

        return {'imidx':imidx,'image':image, 'label':label}                                 # return the cropped image and label

class ToTensor(object):                                                             # normalize the image and label and convert them to tensors
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):

        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        tmpImg = np.zeros((image.shape[0],image.shape[1],3))
        tmpLbl = np.zeros(label.shape)

        image = image/np.max(image)                                                         # scale the image to [0,1]
        if(np.max(label)<1e-6):
            label = label
        else:
            label = label/np.max(label)

        if image.shape[2]==1:
            tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
            tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
            tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
        else:
            tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
            tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
            tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225

        tmpLbl[:,:,0] = label[:,:,0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}

class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors."""
    def __init__(self,flag=0):
        self.flag = flag

    def __call__(self, sample):

        imidx, image, label,frame =sample['imidx'], sample['image'], sample['label'], sample['frame']

        tmpLbl = np.zeros(label.shape)

        if(np.max(label)<1e-6):
            label = label
        else:
            label = label/np.max(label)

        # change the color space
        if self.flag == 2: # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0],image.shape[1],6))
            tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
            if image.shape[2]==1:
                tmpImgt[:,:,0] = image[:,:,0]
                tmpImgt[:,:,1] = image[:,:,0]
                tmpImgt[:,:,2] = image[:,:,0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # normalize image to range [0,1]
            tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
            tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
            tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
            tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
            tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
            tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
            tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
            tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
            tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
            tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
            tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])

        elif self.flag == 1: #with Lab color
            tmpImg = np.zeros((image.shape[0],image.shape[1],3))

            if image.shape[2]==1:
                tmpImg[:,:,0] = image[:,:,0]
                tmpImg[:,:,1] = image[:,:,0]
                tmpImg[:,:,2] = image[:,:,0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
            tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
            tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))

            tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
            tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
            tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])

        else: # with rgb color
            tmpImg = np.zeros((image.shape[0],image.shape[1],3))
            image = image/np.max(image)
            if image.shape[2]==1:
                tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
                tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
                tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
            else:
                tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
                tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
                tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225

        tmpLbl[:,:,0] = label[:,:,0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl),"frame":frame}

class SalObjDataset(Dataset):                                                                   # returns the image index, normalized image and label
    def __init__(self,img_name_list,lbl_name_list,transform=None):
        # self.root_dir = root_dir
        # self.image_name_list = glob.glob(image_dir+'*.png')
        # self.label_name_list = glob.glob(label_dir+'*.png')
        self.image_name_list = img_name_list                                                     # absolute paths of all images
        self.label_name_list = lbl_name_list                                                     # absolute paths of all labels
        self.transform = transform                                                               # transform pipeline: rescale, crop, to-tensor

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self,idx):
        # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
        # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
        # image = io.imread(self.image_name_list[idx])                                             # read one image by its absolute path
        # print(type(image))                                                                       #


        while True:
            ref, frame = capture.read()                                                          # grab one frame from the webcam
            image = frame
            # image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)                                     # BGR to RGB (disabled: the frame stays in BGR order)

            print(image.shape)                                                                   # (480, 640, 3)
            # imname = self.image_name_list[idx]
            imidx = np.array([idx])                                                              # image index as a numpy array

            if(0==len(self.label_name_list)):                                                        # if there is no label, create an all-zero one
                label_3 = np.zeros(image.shape)
            else:                                                                                    # otherwise read the matching label
                label_3 = io.imread(self.label_name_list[idx])

            label = np.zeros(label_3.shape[0:2])                                                     # single-channel label buffer

            if(3==len(label_3.shape)):
                label = label_3[:,:,0]
            elif(2==len(label_3.shape)):
                label = label_3

            if(3==len(image.shape) and 2==len(label.shape)):
                label = label[:,:,np.newaxis]                                                        # np.newaxis inserts an axis of size 1 at its position
            elif(2==len(image.shape) and 2==len(label.shape)):
                image = image[:,:,np.newaxis]
                label = label[:,:,np.newaxis]

            sample = {'imidx':imidx, 'image':image, 'label':label,"frame":frame}

            if self.transform:
                sample = self.transform(sample)                                                      # apply the transform pipeline
            return sample

def main():

    model_name = 'u2net'#u2netp                              # name of the saved model
    model_dir = r"\my_U2_Net\saved_models\u2net\u2net.pth"          # path of the model weights
    img_name_list = [i for i in range(10000)]                # dummy list: its length only bounds how many frames are processed
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )

    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)     # build the data loader


    if(model_name=='u2net'):                               # pick the weights matching the chosen variant
        print("...load U2NET---173.6 MB")
        net = U2NET(3,1)
    elif(model_name=='u2netp'):
        print("...load U2NETP---4.7 MB")
        net = U2NETP(3,1)
    net.load_state_dict(torch.load(model_dir))             # load the trained weights
    if torch.cuda.is_available():
        net.cuda()                                         # move the network to the GPU
    net.eval()                                             # evaluation mode

    for i_test, data_test in enumerate(test_salobj_dataloader):

        inputs_test = data_test['image']                   # the image to test
        inputs_test = inputs_test.type(torch.FloatTensor)  # cast to float

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
            # Variable is a legacy wrapper around Tensor; since PyTorch 0.4
            # plain tensors support autograd directly and Variable is deprecated
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7 = net(inputs_test)             # forward pass

        pred = d1[:,0,:,:]
        pred = (pred-torch.min(pred))/(torch.max(pred)-torch.min(pred))  # min-max normalize the prediction


        predict = pred.squeeze()  # drop singleton dimensions
        predict_np = predict.cpu().data.numpy()  # move to the CPU as a numpy array
        im = Image.fromarray(predict_np * 255).convert('RGB')  # back to PIL, rescaled from [0,1] to [0,255]

        imo = im.resize((640, 480), resample=Image.BILINEAR)  # the resulting mask, resized back to the frame size

        # img_array = np.asarray(Image.fromarray(np.uint8(data_test['image'])))

        img_array = np.uint8(data_test["frame"][0])
        # print(data_test)
        # print(data_test["frame"][0])
        # print(data_test["frame"][0].shape)
        # cv2.imshow("", np.uint8(data_test["frame"][0]))
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()


        mask = np.asarray(Image.fromarray(np.uint8(imo)))  # mask as a numpy array

        # cv2.imshow("", np.uint8(mask))
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()
        # print(img_array.shape)
        # print("ccccccccccccccccccc")
        # res = np.concatenate((img_array, mask[:, :, [0]]), -1)  # append the mask's first channel as an alpha channel
        # img = cv2.cvtColor(res, cv2.COLOR_RGB2BGRA)
        # img = Image.fromarray(img.astype('uint8'), mode='RGBA')
        # img.show()
        # b, g, r, a = cv2.split(img)
        # img = cv2.merge([a,r,b,g,])

        img = Image.fromarray(np.uint8(img_array * (mask / 255)))   # apply the soft mask to the raw frame

        cv2.imshow("",np.uint8(img))
        if cv2.waitKey(1) & 0xFF == ord('q'): break          # press q to quit

        del d1,d2,d3,d4,d5,d6,d7                            # free the outputs before the next frame


if __name__ == "__main__":
    main()
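As one possible cleanup of the loop above (a sketch, not the author's final version): read frames directly instead of going through Dataset/DataLoader, run inference under torch.no_grad(), and keep the preprocessing equivalent to ToTensorLab(flag=0). Note that, like the original script, this normalizes the BGR frame with RGB statistics; add a cv2.cvtColor if you want that fixed.

import cv2
import numpy as np
import torch

def preprocess(frame, size=320):
    img = cv2.resize(frame, (size, size)).astype(np.float32)
    img = img / img.max()                                        # scale to [0, 1]
    img = (img - (0.485, 0.456, 0.406)) / (0.229, 0.224, 0.225)  # ImageNet stats, as in ToTensorLab(flag=0)
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()

# net: a loaded U2NET as above, assumed here to stay on the CPU
cap = cv2.VideoCapture(0)
while True:
    ok, frame = cap.read()
    if not ok:
        break
    with torch.no_grad():                                        # no autograd graph during inference
        d1, *rest = net(preprocess(frame))
    pred = d1[0, 0].numpy()
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    mask = cv2.resize(pred, (frame.shape[1], frame.shape[0]))
    cv2.imshow('seg', np.uint8(frame * mask[:, :, None]))        # soft-mask the raw frame
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cap.release()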

The result is shown below (this was originally a video; only one frame is shown here):

The frame shows a keyboard on a desk; the background has been separated out.

[Figure 8: Video segmentation result]

References:

https://arxiv.org/pdf/2005.09007.pdf

https://github.com/NathanUA/U-2-Net

https://zhuanlan.zhihu.com/p/44958351
