TensorFlow 2.7 实现 U-Net 网络结构

我觉得之前那篇文章并没有严格按照论文中的结构来搭建模型，所以重新写了一遍，并在合适的位置标注了其在代码中的对应位置。

[图 1：U-Net 网络结构示意图]

 

unet.py

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow.keras.layers import (
    Concatenate,
    Conv2D,
    Cropping2D,
    Input,
    MaxPooling2D,
    UpSampling2D,
)
from tensorflow.keras.models import Model
import tensorflow as tf


def _double_conv(x, filters, stage):
    """Apply two unpadded 3x3 ReLU convolutions.

    The layers are named ``conv{stage}_1`` and ``conv{stage}_2`` so that the
    layer names match the stage numbering of the U-Net diagram.
    """
    x = Conv2D(filters, 3, activation='relu', name=f'conv{stage}_1')(x)
    return Conv2D(filters, 3, activation='relu', name=f'conv{stage}_2')(x)


def _center_crop_like(skip, target, name):
    """Center-crop ``skip`` spatially so its height/width equal ``target``'s.

    This is the paper's "copy and crop" step. Every 3x3 convolution is
    unpadded, so encoder maps are larger than the decoder maps they are
    concatenated with. Their border pixels are discarded rather than
    resampled, which keeps both maps spatially aligned.

    Raises:
        ValueError: If a spatial size is unknown (``None``) or ``skip`` is
            smaller than ``target``.
    """
    cropping = []
    for axis in (1, 2):  # height, width (channels-last layout)
        src, dst = skip.shape[axis], target.shape[axis]
        if src is None or dst is None:
            raise ValueError(
                f"{name}: static spatial sizes are required for cropping, "
                f"got {skip.shape} and {target.shape}")
        diff = src - dst
        if diff < 0:
            raise ValueError(
                f"{name}: skip map {skip.shape} is smaller than "
                f"decoder map {target.shape}")
        # Put the extra pixel (if the difference is odd) on the bottom/right.
        cropping.append((diff // 2, diff - diff // 2))
    return Cropping2D(cropping=tuple(cropping), name=name)(skip)


def _up_block(x, skip, filters, stage):
    """Run one expansive-path stage of U-Net.

    The stage consists of:
    - a 2x2 up-convolution (``up{stage}``, ``up{stage}_conv``);
    - a center crop of the encoder map ``skip`` (``conv{stage-1}_feature``);
    - a concatenation (``concat{stage-4}``);
    - two 3x3 convolutions (``conv{stage+1}_*``).
    """
    # "up-conv 2x2": nearest x2 upsampling, then a 2x2 conv that sets the
    # channel count to `filters`.
    up = UpSampling2D(name=f'up{stage}')(x)
    up = Conv2D(filters, 2, padding='same', name=f'up{stage}_conv')(up)
    cropped = _center_crop_like(skip, up, name=f'conv{stage - 1}_feature')
    merged = Concatenate(name=f'concat{stage - 4}')([up, cropped])
    return _double_conv(merged, filters, stage + 1)


def unet(input_shape=(572, 572, 1), num_classes=2, output_activation='softmax'):
    """Build the original U-Net (Ronneberger et al., 2015) as a Keras model.

    All 3x3 convolutions are unpadded ("valid"), so the network shrinks its
    input: the default 572x572 tile produces a 388x388 output map.

    Args:
        input_shape: ``(height, width, channels)`` of the input tile. Height
            and width must be static (not ``None``) because the
            skip-connection crops are computed from static shapes.
        num_classes: Number of output channels (one per class).
        output_activation: Activation of the final 1x1 convolution. The paper
            applies a pixel-wise soft-max over the class channels, hence the
            ``'softmax'`` default. Pass ``'sigmoid'`` for independent
            per-channel outputs.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'mytf2_unet'``.

    Raises:
        ValueError: If an encoder feature map cannot be center-cropped to the
            size of the decoder map it is concatenated with.
    """
    inputs = Input(shape=input_shape, name='input')

    # Contracting path. MaxPooling2D defaults to pool_size=(2, 2), and
    # strides=None means the stride equals the pool size.
    conv1 = _double_conv(inputs, 64, 1)
    conv2 = _double_conv(MaxPooling2D(name='maxpool1')(conv1), 128, 2)
    conv3 = _double_conv(MaxPooling2D(name='maxpool2')(conv2), 256, 3)
    conv4 = _double_conv(MaxPooling2D(name='maxpool3')(conv3), 512, 4)
    conv5 = _double_conv(MaxPooling2D(name='maxpool4')(conv4), 1024, 5)

    # Expansive path. Stage 5 consumes conv4's map, ..., stage 8 consumes
    # conv1's map.
    x = conv5
    decoder_plan = ((conv4, 512), (conv3, 256), (conv2, 128), (conv1, 64))
    for stage, (skip, filters) in enumerate(decoder_plan, start=5):
        x = _up_block(x, skip, filters, stage)

    # Final 1x1 convolution maps the 64 feature channels to class scores.
    out = Conv2D(num_classes, 1, activation=output_activation, name='out')(x)

    return Model(inputs=inputs, outputs=out, name='mytf2_unet')

if __name__ == '__main__':
    # Script entry point: construct the network and print its layer table.
    unet_model = unet()
    unet_model.summary()

输出:

D:\program\miniconda\envs\py38_tf2\python.exe E:/paper/nets/unet.py
Model: "mytf2_unet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input (InputLayer)             [(None, 572, 572, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv1_1 (Conv2D)               (None, 570, 570, 64  640         ['input[0][0]']                  
                                )                                                                 
                                                                                                  
 conv1_2 (Conv2D)               (None, 568, 568, 64  36928       ['conv1_1[0][0]']                
                                )                                                                 
                                                                                                  
 maxpool1 (MaxPooling2D)        (None, 284, 284, 64  0           ['conv1_2[0][0]']                
                                )                                                                 
                                                                                                  
 conv2_1 (Conv2D)               (None, 282, 282, 12  73856       ['maxpool1[0][0]']               
                                8)                                                                
                                                                                                  
 conv2_2 (Conv2D)               (None, 280, 280, 12  147584      ['conv2_1[0][0]']                
                                8)                                                                
                                                                                                  
 maxpool2 (MaxPooling2D)        (None, 140, 140, 12  0           ['conv2_2[0][0]']                
                                8)                                                                
                                                                                                  
 conv3_1 (Conv2D)               (None, 138, 138, 25  295168      ['maxpool2[0][0]']               
                                6)                                                                
                                                                                                  
 conv3_2 (Conv2D)               (None, 136, 136, 25  590080      ['conv3_1[0][0]']                
                                6)                                                                
                                                                                                  
 maxpool3 (MaxPooling2D)        (None, 68, 68, 256)  0           ['conv3_2[0][0]']                
                                                                                                  
 conv4_1 (Conv2D)               (None, 66, 66, 512)  1180160     ['maxpool3[0][0]']               
                                                                                                  
 conv4_2 (Conv2D)               (None, 64, 64, 512)  2359808     ['conv4_1[0][0]']                
                                                                                                  
 maxpool4 (MaxPooling2D)        (None, 32, 32, 512)  0           ['conv4_2[0][0]']                
                                                                                                  
 conv5_1 (Conv2D)               (None, 30, 30, 1024  4719616     ['maxpool4[0][0]']               
                                )                                                                 
                                                                                                  
 conv5_2 (Conv2D)               (None, 28, 28, 1024  9438208     ['conv5_1[0][0]']                
                                )                                                                 
                                                                                                  
 up5 (UpSampling2D)             (None, 56, 56, 1024  0           ['conv5_2[0][0]']                
                                )                                                                 
                                                                                                  
 up5_conv (Conv2D)              (None, 56, 56, 512)  2097664     ['up5[0][0]']                    
                                                                                                  
 tf.image.resize (TFOpLambda)   (None, 56, 56, 512)  0           ['conv4_2[0][0]']                
                                                                                                  
 concat1 (Concatenate)          (None, 56, 56, 1024  0           ['up5_conv[0][0]',               
                                )                                 'tf.image.resize[0][0]']        
                                                                                                  
 conv6_1 (Conv2D)               (None, 54, 54, 512)  4719104     ['concat1[0][0]']                
                                                                                                  
 conv6_2 (Conv2D)               (None, 52, 52, 512)  2359808     ['conv6_1[0][0]']                
                                                                                                  
 up6 (UpSampling2D)             (None, 104, 104, 51  0           ['conv6_2[0][0]']                
                                2)                                                                
                                                                                                  
 up6_conv (Conv2D)              (None, 104, 104, 25  524544      ['up6[0][0]']                    
                                6)                                                                
                                                                                                  
 tf.image.resize_1 (TFOpLambda)  (None, 104, 104, 25  0          ['conv3_2[0][0]']                
                                6)                                                                
                                                                                                  
 concat2 (Concatenate)          (None, 104, 104, 51  0           ['up6_conv[0][0]',               
                                2)                                'tf.image.resize_1[0][0]']      
                                                                                                  
 conv7_1 (Conv2D)               (None, 102, 102, 25  1179904     ['concat2[0][0]']                
                                6)                                                                
                                                                                                  
 conv7_2 (Conv2D)               (None, 100, 100, 25  590080      ['conv7_1[0][0]']                
                                6)                                                                
                                                                                                  
 up7 (UpSampling2D)             (None, 200, 200, 25  0           ['conv7_2[0][0]']                
                                6)                                                                
                                                                                                  
 up7_conv (Conv2D)              (None, 200, 200, 12  131200      ['up7[0][0]']                    
                                8)                                                                
                                                                                                  
 tf.image.resize_2 (TFOpLambda)  (None, 200, 200, 12  0          ['conv2_2[0][0]']                
                                8)                                                                
                                                                                                  
 concat3 (Concatenate)          (None, 200, 200, 25  0           ['up7_conv[0][0]',               
                                6)                                'tf.image.resize_2[0][0]']      
                                                                                                  
 conv8_1 (Conv2D)               (None, 198, 198, 12  295040      ['concat3[0][0]']                
                                8)                                                                
                                                                                                  
 conv8_2 (Conv2D)               (None, 196, 196, 12  147584      ['conv8_1[0][0]']                
                                8)                                                                
                                                                                                  
 up8 (UpSampling2D)             (None, 392, 392, 12  0           ['conv8_2[0][0]']                
                                8)                                                                
                                                                                                  
 up8_conv (Conv2D)              (None, 392, 392, 64  32832       ['up8[0][0]']                    
                                )                                                                 
                                                                                                  
 tf.image.resize_3 (TFOpLambda)  (None, 392, 392, 64  0          ['conv1_2[0][0]']                
                                )                                                                 
                                                                                                  
 concat4 (Concatenate)          (None, 392, 392, 12  0           ['up8_conv[0][0]',               
                                8)                                'tf.image.resize_3[0][0]']      
                                                                                                  
 conv9_1 (Conv2D)               (None, 390, 390, 64  73792       ['concat4[0][0]']                
                                )                                                                 
                                                                                                  
 conv9_2 (Conv2D)               (None, 388, 388, 64  36928       ['conv9_1[0][0]']                
                                )                                                                 
                                                                                                  
 out (Conv2D)                   (None, 388, 388, 2)  130         ['conv9_2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 31,030,658
Trainable params: 31,030,658
Non-trainable params: 0
__________________________________________________________________________________________________

Process finished with exit code 0

你可能感兴趣的:(深度学习,tensorflow2,深度学习)