DenseNet中的dense_block模块的实现

运行环境:TensorFlow 2.1 + Python 3.7

代码如下:

#denseNet121的实现
import tensorflow as tf
from tensorflow.keras.layers import Dense,Conv2D,BatchNormalization,AveragePooling2D,Flatten
from tensorflow.keras.layers import MaxPooling2D,Activation,Concatenate,GlobalAveragePooling2D
from tensorflow.keras import Input,Model

def dense_block(x, blocks, grow_rate=32):
    """Stack ``blocks`` densely connected conv blocks on top of ``x``.

    Each iteration feeds the running tensor through ``conv2d_block``, which
    returns that tensor with new feature maps concatenated onto it, so the
    channel count grows by ``grow_rate`` per iteration.

    Args:
        x: Input feature-map tensor.
        blocks: Number of conv blocks to stack.
        grow_rate: Feature maps added by each conv block. Defaults to 32,
            the value that was previously hard-coded.

    Returns:
        The tensor produced by the last conv block.
    """
    for _ in range(blocks):
        x = conv2d_block(x, grow_rate)
    return x
def conv2d_block(x, grow_rate):
    """Apply one bottleneck conv block and append its output to ``x``.

    The block is BN -> ReLU -> 1x1 conv (``4 * grow_rate`` filters) followed
    by BN -> ReLU -> 3x3 conv (``grow_rate`` filters). The result is
    concatenated with the block input along the channel axis.

    Args:
        x: Input feature-map tensor; axis 3 is treated as the channel axis.
        grow_rate: Number of feature maps this block adds.

    Returns:
        Tensor holding the channels of ``x`` followed by the ``grow_rate``
        newly computed channels.
    """
    # Axes are zero-based in TensorFlow; axis 3 is the channel axis.
    bn_axis = 3
    # 1.001e-5 is the epsilon used by the other BatchNormalization layers in
    # this file; keep every BN layer on the same value.
    bn_epsilon = 1.001e-5

    # Bottleneck: 1x1 convolution producing 4 * grow_rate feature maps.
    x1 = BatchNormalization(axis=bn_axis, epsilon=bn_epsilon)(x)
    x1 = Activation('relu')(x1)
    x1 = Conv2D(grow_rate * 4, (1, 1), padding='same')(x1)

    # 3x3 convolution producing the grow_rate new feature maps.
    x1 = BatchNormalization(axis=bn_axis, epsilon=bn_epsilon)(x1)
    x1 = Activation('relu')(x1)
    x1 = Conv2D(grow_rate, (3, 3), padding='same')(x1)

    # Dense connection: keep the input and append the new feature maps.
    x = Concatenate(axis=bn_axis)([x, x1])
    return x

def transition_block(x, reduction):
    """Compress the channels of ``x`` and downsample it.

    Applies batch normalization and ReLU, then a 1x1 convolution whose filter
    count is the incoming channel count multiplied by ``reduction`` (truncated
    to an int), then average pooling with stride 2.

    Args:
        x: Input feature-map tensor; axis 3 is treated as the channel axis.
        reduction: Factor applied to the channel count, e.g. 0.5.

    Returns:
        The pooled tensor.
    """
    channel_axis = 3

    normalized = BatchNormalization(axis=channel_axis, epsilon=1.001e-5)(x)
    activated = Activation('relu')(normalized)

    # The channel count is read after BN/ReLU, as in the original code.
    compressed_filters = int(activated.shape[channel_axis] * reduction)
    compressed = Conv2D(compressed_filters, (1, 1), padding='same')(activated)

    return AveragePooling2D(strides=(2, 2))(compressed)

def denseNet(inputs, blocks):
    """Build the DenseNet graph on ``inputs`` and return its output tensor.

    Args:
        inputs: Keras input tensor. The dense and transition blocks treat
            axis 3 as the channel axis.
        blocks: Sequence giving the number of conv blocks in each dense
            block, e.g. ``[6, 12, 24, 16]``.

    Returns:
        Output tensor of the final 1000-unit softmax layer.
    """
    # Stem: strided 7x7 convolution followed by strided 3x3 max pooling.
    x = Conv2D(24, (7, 7), padding='same', strides=2)(inputs)
    x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)

    last_index = len(blocks) - 1
    for i, block in enumerate(blocks):
        x = dense_block(x, block)
        # Every dense block except the last is followed by a transition block.
        if i != last_index:
            x = transition_block(x, 0.5)

    # Global average pooling already returns a (batch, channels) tensor, so
    # no Flatten layer is needed before the classifier.
    x = GlobalAveragePooling2D()(x)
    # Softmax rather than ReLU: the 1000 outputs are class scores and should
    # form a probability distribution.
    x = Dense(1000, activation='softmax')(x)
    return x

# Build the network for 224x224 three-channel input with dense blocks of
# 6, 12, 24 and 16 conv blocks, then print the layer summary.
blocks = [6, 12, 24, 16]
inputs = Input(shape=(224, 224, 3))

outputs = denseNet(inputs, blocks)
model = Model(inputs=inputs, outputs=outputs, name="densenet")
model.summary()

运行结果:

"""
Model: "densenet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 112, 112, 24) 3552        input_1[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 56, 56, 24)   0           conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 56, 56, 24)   96          max_pooling2d[0][0]              
__________________________________________________________________________________________________
activation (Activation)         (None, 56, 56, 24)   0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 56, 56, 128)  3200        activation[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 56, 56, 128)  512         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 56, 56, 128)  0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 56, 56, 32)   36896       activation_1[0][0]               
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 56, 56, 56)   0           max_pooling2d[0][0]              
                                                                 conv2d_2[0][0]                   
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 56, 56, 56)   224         concatenate[0][0]                
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 56, 56, 56)   0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 56, 56, 128)  7296        activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 56, 56, 128)  512         conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 56, 56, 128)  0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 56, 56, 32)   36896       activation_3[0][0]               
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 56, 56, 88)   0           concatenate[0][0]                
                                                                 conv2d_4[0][0]                   
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 56, 56, 88)   352         concatenate_1[0][0]              
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 56, 56, 88)   0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 56, 56, 128)  11392       activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 56, 56, 128)  512         conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 56, 56, 128)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 56, 56, 32)   36896       activation_5[0][0]               
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 56, 56, 120)  0           concatenate_1[0][0]              
                                                                 conv2d_6[0][0]                   
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 56, 56, 120)  480         concatenate_2[0][0]              
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 56, 56, 120)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 56, 56, 128)  15488       activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 56, 56, 128)  512         conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 56, 56, 128)  0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 56, 56, 32)   36896       activation_7[0][0]               
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 56, 56, 152)  0           concatenate_2[0][0]              
                                                                 conv2d_8[0][0]                   
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 56, 56, 152)  608         concatenate_3[0][0]              
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 56, 56, 152)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 56, 56, 128)  19584       activation_8[0][0]               
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 56, 56, 128)  512         conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 56, 56, 128)  0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 56, 56, 32)   36896       activation_9[0][0]               
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 56, 56, 184)  0           concatenate_3[0][0]              
                                                                 conv2d_10[0][0]                  
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 56, 56, 184)  736         concatenate_4[0][0]              
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 56, 56, 184)  0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 56, 56, 128)  23680       activation_10[0][0]              
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 56, 56, 128)  512         conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 56, 56, 128)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 56, 56, 32)   36896       activation_11[0][0]              
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 56, 56, 216)  0           concatenate_4[0][0]              
                                                                 conv2d_12[0][0]                  
"""

在上述结果中,concatenate_5 的输入是 concatenate_4 和 conv2d_12,它的输出是二者在通道维上的拼接(184+32=216 个通道)。这正是 dense_block 中密集连接的体现,也是 DenseNet 的核心所在。

注:input layer 之后的 7×7 卷积层的卷积核个数(输出通道数)在这里设为 24,而论文中提到应设为 2k(k 为 growth rate,本例中 k=32,即 64)。

参考资料:

https://blog.csdn.net/weixin_44791964/article/details/105472196

你可能感兴趣的:(tensorflow,python)