主要参考资料：U-Net 的官方网站和论文。
U-Net 最早用于生物医学图像的分割，后来在目标检测、图像转换以及 Tone Mapping、Reverse Tone Mapping 等很多领域都有应用。它的一个特点是：把前期（下采样路径中）卷积层的输出与后期（上采样路径中）对应层的输出拼接（concatenate）在一起，作为后续网络层的输入。我觉得它的过程很类似图像金字塔的分解与重建：前半部分通过下采样提取信息，后半部分进行重建；区别在于，拉普拉斯金字塔重建是将图像复原，而这里生成的是具有“新特性”的图像。之所以叫它 U-Net，是因为它的结构看起来像个 U 形；如果把中间各层之间的级联也考虑进去，更像一把琴。
从图中可以看出，U-Net 主要由 Conv+ReLU、max pool、up-conv、conv 1x1 几个部分构成（跳跃连接处还有一个 copy and crop 操作）。下面我们首先在 TensorFlow 里将这几个部分分别封装成函数。
- conv+ReLU
def conv_relu_layer(net, numfilters, name, kernel_size=(3, 3)):
    """Apply one unpadded ('valid') convolution followed by ReLU.

    This is the basic "conv 3x3, ReLU" unit of U-Net. With 'valid' padding,
    the default stride of 1 and the default 3x3 kernel, each call trims
    2 pixels from both spatial dimensions.

    Args:
        net: input feature tensor.
        numfilters: number of output filters of the convolution.
        name: prefix for the layer name; the layer is named
            "<name>_conv_relu".
        kernel_size: convolution window size. Defaults to (3, 3), the size
            previously hard-coded here, so existing callers are unaffected.

    Returns:
        The ReLU-activated output tensor of ``tf.layers.conv2d``.
    """
    return tf.layers.conv2d(
        net,
        filters=numfilters,
        kernel_size=kernel_size,
        # Lower-case to match the other helpers; tf.layers normalizes the
        # padding string case-insensitively, so the layer is unchanged.
        padding='valid',
        activation=tf.nn.relu,
        name="{}_conv_relu".format(name),
    )
- maxpool
def maxpool(net, name):
    """Halve the spatial size of ``net`` with 2x2, stride-2 max-pooling.

    With 'valid' padding, each pooled dimension is floor(size / 2).

    Args:
        net: input feature tensor.
        name: prefix for the layer name; the layer is named "<name>_maxpool".

    Returns:
        The pooled tensor produced by ``tf.layers.max_pooling2d``.
    """
    layer_name = "{}_maxpool".format(name)
    window = (2, 2)  # used for both the pooling window and the stride
    return tf.layers.max_pooling2d(
        net,
        pool_size=window,
        strides=window,
        padding='valid',
        name=layer_name,
    )
- up_conv
def up_conv(net, numfilters, name):
    """Upsample ``net`` with a 2x2, stride-2 transposed convolution plus ReLU.

    With 'valid' padding and kernel == stride == 2, each spatial dimension
    is exactly doubled.

    Args:
        net: input feature tensor.
        numfilters: number of output filters.
        name: prefix for the layer name; the layer is named "<name>_up_conv".

    Returns:
        The ReLU-activated output tensor of ``tf.layers.conv2d_transpose``.
    """
    # NOTE(review): the U-Net paper applies ReLU only after the 3x3
    # convolutions, not after the up-convolution; confirm this is intended.
    two_by_two = (2, 2)
    layer_name = "{}_up_conv".format(name)
    upsampled = tf.layers.conv2d_transpose(
        net,
        filters=numfilters,
        kernel_size=two_by_two,
        strides=two_by_two,
        padding='valid',
        activation=tf.nn.relu,
        name=layer_name,
    )
    return upsampled
- copy_crop
def copy_crop(skip_connect, net):
    """Center-crop ``skip_connect`` to the spatial size of ``net`` and concatenate.

    This is U-Net's "copy and crop" step. The skip tensor comes from an
    earlier layer and is larger than ``net``, because unpadded ('valid')
    convolutions drop border pixels. Its central region is cut out so that
    both tensors cover the same spatial area. They are then joined on axis 3.

    Args:
        skip_connect: 4-D tensor whose static sizes on axes 1 and 2 are
            known and at least as large as those of ``net``.
        net: 4-D tensor whose static sizes on axes 1 and 2 are known.

    Returns:
        ``tf.concat([cropped_skip_connect, net], axis=3)``.

    Raises:
        ValueError: if a static size on axis 1 or 2 is unknown, or if ``net``
            is larger than ``skip_connect`` on either of those axes.
    """
    # as_list() yields plain ints (or None), unlike TF1 Dimension objects.
    skip_h, skip_w = skip_connect.get_shape().as_list()[1:3]
    net_h, net_w = net.get_shape().as_list()[1:3]
    if None in (skip_h, skip_w, net_h, net_w):
        raise ValueError(
            "copy_crop needs statically known sizes on axes 1 and 2, got "
            "skip_connect={} and net={}".format(
                (skip_h, skip_w), (net_h, net_w)))
    if net_h > skip_h or net_w > skip_w:
        raise ValueError(
            "net {} is larger than skip_connect {}; cannot crop".format(
                (net_h, net_w), (skip_h, skip_w)))

    # Crop the centre rather than the top-left corner. Valid convolutions
    # remove border pixels symmetrically, so the region aligned with `net`
    # sits in the middle of the skip tensor.
    offset_h = (skip_h - net_h) // 2
    offset_w = (skip_w - net_w) // 2
    skip_connect_crop = tf.slice(
        skip_connect,
        [0, offset_h, offset_w, 0],
        [-1, net_h, net_w, -1],  # -1 keeps the whole batch / last axis
    )
    return tf.concat([skip_connect_crop, net], axis=3)
- conv1x1
def conv1x1(net, numfilters, name):
    """Map ``net`` to ``numfilters`` output channels with a 1x1 convolution.

    No activation is passed to ``tf.layers.conv2d``, so the output is linear.

    Args:
        net: input feature tensor.
        numfilters: number of output filters.
        name: prefix for the layer name; the layer is named "<name>_conv1x1".

    Returns:
        The output tensor of ``tf.layers.conv2d``.
    """
    layer_name = "{}_conv1x1".format(name)
    projection = tf.layers.conv2d(
        net,
        filters=numfilters,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding='SAME',
        name=layer_name,
    )
    return projection
# Build the U-Net graph with the helper layers above (TF1-style tf.placeholder).
# The spatial sizes noted below follow from:
#   - the 572x572 placeholder,
#   - 3x3 'valid' convs (-2 px each),
#   - 2x2/stride-2 pooling (/2),
#   - 2x2/stride-2 'valid' transposed convs (x2).
# They assume tf.layers' default channels_last layout (N, H, W, C).
#
# Define input data: a fixed batch of 64 three-channel 572x572 images.
# NOTE(review): `input` shadows the Python builtin of the same name. Renaming
# it would be cleaner, but other code may refer to this module-level name.
input = tf.placeholder(dtype=tf.float32,shape = (64,572,572,3))
# Define the downsample (contracting) path. Each level applies two 3x3 convs;
# the second conv's output is kept (skip_conN) for the upsample path.
# Level 1: 572 -> 570 -> 568, 64 channels.
network = conv_relu_layer(input,numfilters=64,name='lev1_layer1')
skip_con1 = conv_relu_layer(network,numfilters=64,name='lev1_layer2')
# Level 2: pool 568 -> 284, convs 282 -> 280, 128 channels.
network = maxpool(skip_con1,'lev2_layer1')
network = conv_relu_layer(network,128,'lev2_layer2')
skip_con2 = conv_relu_layer(network,128,'lev2_layer3')
# Level 3: pool 280 -> 140, convs 138 -> 136, 256 channels.
# NOTE(review): 'lev3_layer1' is passed to both the pool and the first conv,
# unlike the other levels. Their layer-name suffixes differ, so the names do
# not collide. Renaming would change variable names saved in checkpoints.
network = maxpool(skip_con2,'lev3_layer1')
network = conv_relu_layer(network,256,'lev3_layer1')
skip_con3 = conv_relu_layer(network,256,'lev3_layer2')
# Level 4: pool 136 -> 68, convs 66 -> 64, 512 channels.
network = maxpool(skip_con3,'lev4_layer1')
network = conv_relu_layer(network,512,'lev4_layer2')
skip_con4 = conv_relu_layer(network,512,'lev4_layer3')
# Level 5 (bottom of the "U"): pool 64 -> 32, convs 30 -> 28, 1024 channels.
network = maxpool(skip_con4,'lev5_layer1')
network = conv_relu_layer(network,1024,'lev5_layer2')
network = conv_relu_layer(network,1024,'lev5_layer3')
# Define the upsample (expanding) path. Each level does three steps:
#   1. an up-conv that doubles the spatial size;
#   2. copy_crop, which slices the matching skip tensor to the same spatial
#      size and concatenates it on axis 3;
#   3. two 3x3 convs.
# Level 6: up 28 -> 56 (512 ch), concat with skip_con4 -> 1024 ch,
# convs 54 -> 52, 512 channels.
network = up_conv(network,512,'lev6_layer1')
network = copy_crop(skip_con4,network)
network = conv_relu_layer(network,numfilters=512,name='lev6_layer2')
network = conv_relu_layer(network,numfilters=512,name='lev6_layer3')
# Level 7: up 52 -> 104, concat with skip_con3, convs 102 -> 100, 256 channels.
network = up_conv(network,256,name='lev7_layer1')
network = copy_crop(skip_con3,network)
network = conv_relu_layer(network,256,name='lev7_layer2')
network = conv_relu_layer(network,256,'lev7_layer3')
# Level 8: up 100 -> 200, concat with skip_con2, convs 198 -> 196, 128 channels.
network = up_conv(network,128,name='lev8_layer1')
network = copy_crop(skip_con2,network)
network = conv_relu_layer(network,128,name='lev8_layer2')
network = conv_relu_layer(network,128,'lev8_layer3')
# Level 9: up 196 -> 392, concat with skip_con1, convs 390 -> 388, 64 channels.
network = up_conv(network,64,name='lev9_layer1')
network = copy_crop(skip_con1,network)
network = conv_relu_layer(network,64,name='lev9_layer2')
network = conv_relu_layer(network,64,name='lev9_layer3')
# Final 1x1 conv to 2 output channels (presumably 2 classes): (64, 388, 388, 2).
network = conv1x1(network,2,name='lev9_layer4')
利用 TensorBoard 可以得到如下的网络架构图。