U^2-Net is an image-segmentation network I used a while back. It worked well enough to leave a lasting impression, so as a learning exercise I reimplemented it in TensorFlow 2.x. This post comes a bit late, but I'm writing it down so I don't forget. Related links:
Paper
Official code (PyTorch)
Detailed walkthrough
pip install tensorflow-gpu  # on recent TF 2.x releases, plain `pip install tensorflow` already includes GPU support
import tensorflow.keras as k
import tensorflow as tf
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Conv2D,MaxPool2D,BatchNormalization,ReLU,UpSampling2D
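Optionally, a quick check that TensorFlow actually sees the GPU before building anything:

print(tf.__version__)
print("GPUs:", tf.config.list_physical_devices("GPU"))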
An RSU block is essentially a small U-Net: the first half extracts features through a series of downsampling steps, and the second half fuses them back through upsampling and concatenation. Note that the four consecutive RSU variants below differ only in depth and could be written much more compactly (a parameterized sketch appears after RSU4); here I simply write each one out in full. The code follows:
# Basic building block: 3x3 convolution -> BatchNorm -> ReLU
class REBNCONV(Model):
    def __init__(self, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv = Sequential()
        # dilation_rate=dirate lets the same block double as a dilated conv in RSU4F
        self.conv.add(Conv2D(out_ch, kernel_size=(3, 3), strides=(1, 1), padding="same", dilation_rate=dirate))
        self.conv.add(BatchNormalization())
        self.conv.add(ReLU())

    def call(self, inputs, training=None, mask=None):
        # pass the training flag through so BatchNorm switches between train/inference behavior
        x = self.conv(inputs, training=training)
        return x
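As a quick sanity check, REBNCONV leaves the spatial size unchanged (stride 1 with "same" padding, which also preserves size under dilation) and only changes the channel count. The dummy 288x288 input below is my own choice for illustration:

x = tf.random.normal((1, 288, 288, 3))
print(REBNCONV(out_ch=64, dirate=2)(x).shape)  # (1, 288, 288, 64)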
# First RSU block: a 7-level mini U-Net
class RSU7(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        # encoder
        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
        self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
        self.pool3 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv4 = REBNCONV(mid_ch, dirate=1)
        self.pool4 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv5 = REBNCONV(mid_ch, dirate=1)
        self.pool5 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv6 = REBNCONV(mid_ch, dirate=1)
        self.rebnconv7 = REBNCONV(mid_ch, dirate=2)  # dilated bottleneck
        # decoder
        self.rebnconv6d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        h0 = self.rebnconv0(inputs)
        hx1 = self.rebnconv1(h0)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        hx6d = self.rebnconv6d(tf.concat((hx7, hx6), axis=3))
        # upsample (UpSampling2D holds no weights, so creating it inside call is harmless)
        hx6d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx6d)
        hx5d = self.rebnconv5d(tf.concat((hx6d_up, hx5), axis=3))
        hx5d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx5d)
        hx4d = self.rebnconv4d(tf.concat((hx5d_up, hx4), axis=3))
        hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)
        hx3d = self.rebnconv3d(tf.concat((hx4d_up, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
        hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
        hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
        return hx1d + h0  # residual connection with the input feature map
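A minimal shape check confirms that RSU7 returns a map at the input resolution. The 288x288 size is my assumption; it divides cleanly through all five poolings (288/32 = 9):

x = tf.random.normal((1, 288, 288, 3))
print(RSU7(mid_ch=12, out_ch=3)(x).shape)  # (1, 288, 288, 3)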
# Second RSU block: same pattern as RSU7, one pooling level shallower
class RSU6(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
        self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
        self.pool3 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv4 = REBNCONV(mid_ch, dirate=1)
        self.pool4 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv5 = REBNCONV(mid_ch, dirate=1)
        self.rebnconv6 = REBNCONV(mid_ch, dirate=2)  # dilated bottleneck
        self.rebnconv5d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)
        hx1 = self.rebnconv1(hx0)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx6 = self.rebnconv6(hx5)
        hx5d = self.rebnconv5d(tf.concat((hx6, hx5), axis=3))
        hx5d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx5d)
        hx4d = self.rebnconv4d(tf.concat((hx5d_up, hx4), axis=3))
        hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)
        hx3d = self.rebnconv3d(tf.concat((hx4d_up, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
        hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
        hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
        return hx1d + hx0
# Third RSU block: 5 levels
class RSU5(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
        self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
        self.pool3 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv4 = REBNCONV(mid_ch, dirate=1)
        self.rebnconv5 = REBNCONV(mid_ch, dirate=2)  # dilated bottleneck
        self.rebnconv4d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)
        hx1 = self.rebnconv1(hx0)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)
        hx4d = self.rebnconv4d(tf.concat((hx5, hx4), axis=3))
        hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)
        hx3d = self.rebnconv3d(tf.concat((hx4d_up, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
        hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
        hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
        return hx1d + hx0
# Fourth RSU block: 4 levels
class RSU4(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.pool1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv2 = REBNCONV(mid_ch, dirate=1)
        self.pool2 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))
        self.rebnconv3 = REBNCONV(mid_ch, dirate=1)
        self.rebnconv4 = REBNCONV(mid_ch, dirate=2)  # dilated bottleneck
        self.rebnconv3d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)
        hx1 = self.rebnconv1(hx0)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(tf.concat((hx4, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
        hx2d = self.rebnconv2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
        hx1d = self.rebnconv1d(tf.concat((hx2d_up, hx1), axis=3))
        return hx1d + hx0
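As mentioned earlier, RSU7 through RSU4 differ only in their depth, so they could be collapsed into a single parameterized class. Below is a minimal sketch of that idea; the RSU name, the height parameter, and the loop structure are my own, not from the official code:

class RSU(Model):
    """Generic RSU block: height=7 reproduces RSU7, height=4 reproduces RSU4."""
    def __init__(self, height=7, mid_ch=12, out_ch=3):
        super(RSU, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        # encoder: height-1 conv blocks, with a pooling step between consecutive ones
        self.enc = [REBNCONV(mid_ch, dirate=1) for _ in range(height - 1)]
        self.pools = [MaxPool2D(pool_size=(2, 2), strides=(2, 2)) for _ in range(height - 2)]
        self.bottom = REBNCONV(mid_ch, dirate=2)  # dilated bottleneck
        # decoder: height-2 mid-channel blocks plus the final out_ch block
        self.dec = [REBNCONV(mid_ch, dirate=1) for _ in range(height - 2)] + [REBNCONV(out_ch, dirate=1)]

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)
        feats = [self.enc[0](hx0)]
        for conv, pool in zip(self.enc[1:], self.pools):
            feats.append(conv(pool(feats[-1])))
        hx = self.bottom(feats[-1])
        for i, conv in enumerate(self.dec):
            # fuse with the matching encoder feature, deepest first
            hx = conv(tf.concat((hx, feats[-(i + 1)]), axis=3))
            if i < len(self.dec) - 1:
                hx = UpSampling2D((2, 2), interpolation="bilinear")(hx)
        return hx + hx0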
After four RSU blocks, the feature maps have become quite small (with a 288x288 input, which the resolution comments in the network code below assume, they are down to 18x18). Further downsampling would throw away too much information, so at this depth the author replaces pooling and upsampling with dilated convolutions. In other words, every intermediate feature map inside this dilated block keeps the same resolution as the block's input. The code follows:
# Dilated block: dilation rates 1/2/4/8 grow the receptive field without shrinking the map
class RSU4F(Model):
    def __init__(self, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconv0 = REBNCONV(out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, dirate=8)
        self.rebnconv3d = REBNCONV(mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(out_ch, dirate=1)

    def call(self, inputs, training=None, mask=None):
        hx0 = self.rebnconv0(inputs)
        hx1 = self.rebnconv1(hx0)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)
        # skip connections still use concat; only pooling and upsampling are gone
        hx3d = self.rebnconv3d(tf.concat((hx4, hx3), axis=3))
        hx2d = self.rebnconv2d(tf.concat((hx3d, hx2), axis=3))
        hx1d = self.rebnconv1d(tf.concat((hx2d, hx1), axis=3))
        return hx1d + hx0
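Since RSU4F contains no pooling, a quick check at the 18x18 resolution where the network actually uses it (the channel sizes below match stage5) confirms that the output keeps the input size:

x = tf.random.normal((1, 18, 18, 512))
print(RSU4F(mid_ch=256, out_ch=512)(x).shape)  # (1, 18, 18, 512)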
Note that every per-level probability map is produced with a sigmoid activation. The full network implementation follows:
# U^2-Net: the full network, built from the RSU blocks above
class U2NET(Model):
    def __init__(self, out_ch=1):
        super(U2NET, self).__init__()
        # encoder (resolution comments assume a 288x288 input)
        self.stage1 = RSU7(32, 64)
        self.pool1_1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))  # 144x144
        self.stage2 = RSU6(32, 128)
        self.pool2_1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))  # 72x72
        self.stage3 = RSU5(64, 256)
        self.pool3_1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))  # 36x36
        self.stage4 = RSU4(128, 512)
        self.pool4_1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))  # 18x18
        self.stage5 = RSU4F(256, 512)
        self.pool5_1 = MaxPool2D(pool_size=(2, 2), strides=(2, 2))  # 9x9
        self.stage6 = RSU4F(256, 512)
        # decoder
        self.stage5d = RSU4F(256, 512)
        self.stage4d = RSU4(128, 256)
        self.stage3d = RSU5(64, 128)
        self.stage2d = RSU6(32, 64)
        self.stage1d = RSU7(16, 64)
        # side output of each decoder level
        self.side1 = Conv2D(out_ch, kernel_size=(3, 3), padding="same")
        self.side2 = Conv2D(out_ch, kernel_size=(3, 3), padding="same")
        self.side3 = Conv2D(out_ch, kernel_size=(3, 3), padding="same")
        self.side4 = Conv2D(out_ch, kernel_size=(3, 3), padding="same")
        self.side5 = Conv2D(out_ch, kernel_size=(3, 3), padding="same")
        self.side6 = Conv2D(out_ch, kernel_size=(3, 3), padding="same")
        # final fused output
        self.outconv = Conv2D(out_ch, kernel_size=(1, 1))

    def call(self, inputs, training=None, mask=None):
        # encoder
        hx1 = self.stage1(inputs)
        hx = self.pool1_1(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool2_1(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool3_1(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool4_1(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool5_1(hx5)
        hx6 = self.stage6(hx)
        hx6_up = UpSampling2D((2, 2), interpolation="bilinear")(hx6)
        # decoder
        hx5d = self.stage5d(tf.concat((hx6_up, hx5), axis=3))
        hx5d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx5d)
        hx4d = self.stage4d(tf.concat((hx5d_up, hx4), axis=3))
        hx4d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx4d)
        hx3d = self.stage3d(tf.concat((hx4d_up, hx3), axis=3))
        hx3d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx3d)
        hx2d = self.stage2d(tf.concat((hx3d_up, hx2), axis=3))
        hx2d_up = UpSampling2D((2, 2), interpolation="bilinear")(hx2d)
        hx1d = self.stage1d(tf.concat((hx2d_up, hx1), axis=3))
        # side outputs, each upsampled back to the input resolution
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = UpSampling2D((2, 2), interpolation="bilinear")(d2)
        d3 = self.side3(hx3d)
        d3 = UpSampling2D((4, 4), interpolation="bilinear")(d3)
        d4 = self.side4(hx4d)
        d4 = UpSampling2D((8, 8), interpolation="bilinear")(d4)
        d5 = self.side5(hx5d)
        d5 = UpSampling2D((16, 16), interpolation="bilinear")(d5)
        d6 = self.side6(hx6)
        d6 = UpSampling2D((32, 32), interpolation="bilinear")(d6)
        # fuse all six side maps with a 1x1 convolution
        out = self.outconv(tf.concat((d1, d2, d3, d4, d5, d6), axis=3))
        sig = k.activations.sigmoid  # sigmoid turns every map into probabilities
        return sig(out), sig(d1), sig(d2), sig(d3), sig(d4), sig(d5), sig(d6)
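A minimal smoke test, plus a hedged sketch of the deep-supervision loss (the paper trains on the sum of binary cross-entropy over the fused map and all six side maps, with equal weights; the u2net_loss name is mine, not from the official code):

model = U2NET(out_ch=1)
preds = model(tf.random.normal((1, 288, 288, 3)))
for p in preds:
    print(p.shape)  # every map is (1, 288, 288, 1)

bce = tf.keras.losses.BinaryCrossentropy()
def u2net_loss(y_true, preds):
    # sum BCE over all seven probability maps returned by call()
    return tf.add_n([bce(y_true, p) for p in preds])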
That wraps up this post. If you need training code, or if you spot a mistake anywhere above, feel free to leave a comment.