我觉得之前的文章并没有严格按照论文中的结构来搭建模型,所以这里重新实现了一遍,并在代码的相应位置做了标注。
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras.layers import (Input,Conv2D,UpSampling2D,
Concatenate,MaxPooling2D,
)
from tensorflow.keras.models import Model
import tensorflow as tf
def _double_conv(x, filters, stage):
    """Apply two consecutive 3x3 ReLU convolutions to ``x``.

    The layers are named ``conv{stage}_1`` and ``conv{stage}_2``. No padding
    argument is passed, so Keras' default ('valid') applies and each
    convolution removes one pixel from every border.
    """
    for idx in (1, 2):
        x = Conv2D(filters, 3, activation='relu',
                   name='conv{}_{}'.format(stage, idx))(x)
    return x


def _center_crop(skip, target, name):
    """Center-crop ``skip`` to the spatial size of ``target``.

    Both tensors must have statically known height and width, because the
    crop amounts are computed while the graph is being built.

    Raises:
        ValueError: if any of the involved spatial dimensions is unknown.
    """
    sizes = (skip.shape[1], skip.shape[2], target.shape[1], target.shape[2])
    if any(size is None for size in sizes):
        raise ValueError(
            'unet() needs a fully defined input height and width '
            'to crop the skip connections')
    dh = sizes[0] - sizes[2]
    dw = sizes[1] - sizes[3]
    top, left = dh // 2, dw // 2
    # Any odd remainder goes to the bottom/right so the crop stays centered.
    cropping = ((top, dh - top), (left, dw - left))
    return tf.keras.layers.Cropping2D(cropping=cropping, name=name)(skip)


def unet(input_shape=(572, 572, 1), num_classes=2):
    """Build the U-Net model.

    The encoder feature maps are center-cropped (not resampled) before being
    concatenated with the upsampled decoder maps, so that skip features stay
    spatially aligned with the decoder features.

    Args:
        input_shape: (height, width, channels) of the input image. Height and
            width must be known integers.
        num_classes: number of channels produced by the final 1x1 convolution.

    Returns:
        A ``tensorflow.keras.Model`` named ``'mytf2_unet'``. With the default
        arguments it maps (None, 572, 572, 1) to (None, 388, 388, 2).
    """
    inputs = Input(shape=input_shape, name='input')

    # Contracting path: conv1_* .. conv4_*, each followed by max pooling.
    x = inputs
    skips = []
    for stage, filters in enumerate((64, 128, 256, 512), start=1):
        x = _double_conv(x, filters, stage)
        skips.append(x)
        # MaxPooling2D defaults: pool_size=(2, 2), strides equal to pool_size.
        x = MaxPooling2D(name='maxpool{}'.format(stage))(x)

    # Bottleneck: conv5_*.
    x = _double_conv(x, 1024, 5)

    # Expanding path: up5..up8, each followed by conv6_* .. conv9_*.
    for stage, filters in zip((5, 6, 7, 8), (512, 256, 128, 64)):
        skip = skips.pop()  # deepest remaining encoder output
        x = UpSampling2D(name='up{}'.format(stage))(x)
        x = Conv2D(filters, 2, padding='same',
                   name='up{}_conv'.format(stage))(x)
        cropped = _center_crop(skip, x, 'conv{}_feature'.format(9 - stage))
        x = Concatenate(name='concat{}'.format(stage - 4))([x, cropped])
        x = _double_conv(x, filters, stage + 1)

    # 1x1 convolution producing the per-pixel class maps.
    # NOTE(review): the paper applies a pixel-wise softmax here; sigmoid is
    # kept to preserve the existing behavior - confirm which is intended.
    out = Conv2D(num_classes, 1, activation='sigmoid', name='out')(x)
    return Model(inputs=inputs, outputs=out, name='mytf2_unet')
if __name__ == '__main__':
    # Script entry point: build the U-Net and print its layer summary.
    unet().summary()
D:\program\miniconda\envs\py38_tf2\python.exe E:/paper/nets/unet.py
Model: "mytf2_unet"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input (InputLayer) [(None, 572, 572, 1 0 []
)]
conv1_1 (Conv2D) (None, 570, 570, 64 640 ['input[0][0]']
)
conv1_2 (Conv2D) (None, 568, 568, 64 36928 ['conv1_1[0][0]']
)
maxpool1 (MaxPooling2D) (None, 284, 284, 64 0 ['conv1_2[0][0]']
)
conv2_1 (Conv2D) (None, 282, 282, 12 73856 ['maxpool1[0][0]']
8)
conv2_2 (Conv2D) (None, 280, 280, 12 147584 ['conv2_1[0][0]']
8)
maxpool2 (MaxPooling2D) (None, 140, 140, 12 0 ['conv2_2[0][0]']
8)
conv3_1 (Conv2D) (None, 138, 138, 25 295168 ['maxpool2[0][0]']
6)
conv3_2 (Conv2D) (None, 136, 136, 25 590080 ['conv3_1[0][0]']
6)
maxpool3 (MaxPooling2D) (None, 68, 68, 256) 0 ['conv3_2[0][0]']
conv4_1 (Conv2D) (None, 66, 66, 512) 1180160 ['maxpool3[0][0]']
conv4_2 (Conv2D) (None, 64, 64, 512) 2359808 ['conv4_1[0][0]']
maxpool4 (MaxPooling2D) (None, 32, 32, 512) 0 ['conv4_2[0][0]']
conv5_1 (Conv2D) (None, 30, 30, 1024 4719616 ['maxpool4[0][0]']
)
conv5_2 (Conv2D) (None, 28, 28, 1024 9438208 ['conv5_1[0][0]']
)
up5 (UpSampling2D) (None, 56, 56, 1024 0 ['conv5_2[0][0]']
)
up5_conv (Conv2D) (None, 56, 56, 512) 2097664 ['up5[0][0]']
tf.image.resize (TFOpLambda) (None, 56, 56, 512) 0 ['conv4_2[0][0]']
concat1 (Concatenate) (None, 56, 56, 1024 0 ['up5_conv[0][0]',
) 'tf.image.resize[0][0]']
conv6_1 (Conv2D) (None, 54, 54, 512) 4719104 ['concat1[0][0]']
conv6_2 (Conv2D) (None, 52, 52, 512) 2359808 ['conv6_1[0][0]']
up6 (UpSampling2D) (None, 104, 104, 51 0 ['conv6_2[0][0]']
2)
up6_conv (Conv2D) (None, 104, 104, 25 524544 ['up6[0][0]']
6)
tf.image.resize_1 (TFOpLambda) (None, 104, 104, 25 0 ['conv3_2[0][0]']
6)
concat2 (Concatenate) (None, 104, 104, 51 0 ['up6_conv[0][0]',
2) 'tf.image.resize_1[0][0]']
conv7_1 (Conv2D) (None, 102, 102, 25 1179904 ['concat2[0][0]']
6)
conv7_2 (Conv2D) (None, 100, 100, 25 590080 ['conv7_1[0][0]']
6)
up7 (UpSampling2D) (None, 200, 200, 25 0 ['conv7_2[0][0]']
6)
up7_conv (Conv2D) (None, 200, 200, 12 131200 ['up7[0][0]']
8)
tf.image.resize_2 (TFOpLambda) (None, 200, 200, 12 0 ['conv2_2[0][0]']
8)
concat3 (Concatenate) (None, 200, 200, 25 0 ['up7_conv[0][0]',
6) 'tf.image.resize_2[0][0]']
conv8_1 (Conv2D) (None, 198, 198, 12 295040 ['concat3[0][0]']
8)
conv8_2 (Conv2D) (None, 196, 196, 12 147584 ['conv8_1[0][0]']
8)
up8 (UpSampling2D) (None, 392, 392, 12 0 ['conv8_2[0][0]']
8)
up8_conv (Conv2D) (None, 392, 392, 64 32832 ['up8[0][0]']
)
tf.image.resize_3 (TFOpLambda) (None, 392, 392, 64 0 ['conv1_2[0][0]']
)
concat4 (Concatenate) (None, 392, 392, 12 0 ['up8_conv[0][0]',
8) 'tf.image.resize_3[0][0]']
conv9_1 (Conv2D) (None, 390, 390, 64 73792 ['concat4[0][0]']
)
conv9_2 (Conv2D) (None, 388, 388, 64 36928 ['conv9_1[0][0]']
)
out (Conv2D) (None, 388, 388, 2) 130 ['conv9_2[0][0]']
==================================================================================================
Total params: 31,030,658
Trainable params: 31,030,658
Non-trainable params: 0
__________________________________________________________________________________________________
Process finished with exit code 0