使用 TensorFlow 2.x 复现 3D UX-Net（uxnet3D）

models.py

from nets.model_layers import UXNETBlock, DownSampleBlock, ResBlock3D
from tensorflow import keras

def uxnet3D(input_shape, num_classes, feature_size=48, output_activation='sigmoid'):
    """Build a 3D UX-Net-style encoder/decoder segmentation model.

    Encoder: a stride-2 Conv3D stem with ``feature_size`` filters, followed by
    four stages of ``UXNETBlock`` + ``DownSampleBlock``.
    Decoder: at each level the encoder feature map goes through a
    ``ResBlock3D``, is concatenated with the 2x-upsampled output of the level
    below, is refined by a second ``ResBlock3D`` and is upsampled again.
    Finally the decoder output is concatenated with a ``ResBlock3D`` of the raw
    input and projected to ``num_classes`` channels by a 1x1x1 Conv3D.

    Layer names are fixed strings (``embedding``, ``uxblock_1`` ...
    ``res_block_4_1``, ``cat_out4_out5Up``, ``out4_up`` ..., ``outputs``), so
    weights can be loaded by name.

    Args:
        input_shape: Per-sample input shape, channels last
            (e.g. ``[128, 128, 32, 1]``). Every known (non-None) spatial size
            must be divisible by 32.
        num_classes: Number of output channels.
        feature_size: Number of filters of the stem convolution, i.e. the
            channel width of the first (highest-resolution) level.
        output_activation: Activation of the final 1x1x1 convolution.

    Returns:
        A ``keras.Model`` named ``'uxnet3D'``.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 32.
    """
    # The stem and the four DownSampleBlocks each use stride 2 with 'same'
    # padding (ceil division), while every UpSampling3D doubles exactly. The
    # skip connections therefore only line up when every spatial size is a
    # multiple of 2**5. Fail early instead of deep inside a Concatenate layer.
    downsample_factor = 2 ** 5
    for axis, size in enumerate(input_shape[:-1]):
        if size is not None and size % downsample_factor != 0:
            raise ValueError(
                f"uxnet3D: spatial dimension {axis} of input_shape={input_shape} "
                f"is {size}, which is not divisible by {downsample_factor}."
            )

    inputs = keras.Input(shape=input_shape)

    # Level 1: stride-2 stem.
    embeddings = keras.layers.Conv3D(
        feature_size, 3, 2, padding="same", name='embedding')(inputs)

    # Levels 2-5: UX-Net block followed by a downsampling block.
    # skips[i] holds the feature map of level i + 1.
    skips = [embeddings]
    x = embeddings
    for stage in range(1, 5):
        x = UXNETBlock(f"uxblock_{stage}")(x)
        x = DownSampleBlock(f"down_{stage}")(x)
        skips.append(x)

    # Bottleneck (level 5).
    x = ResBlock3D(x.shape[-1], "res_block_5")(x)
    up = keras.layers.UpSampling3D(name='out5_up')(x)

    # Decoder, levels 4 down to 1. Each level keeps its encoder channel width.
    for level in range(4, 0, -1):
        skip = skips[level - 1]
        dim = skip.shape[-1]
        x = ResBlock3D(dim, f"res_block_{level}_1")(skip)
        x = keras.layers.Concatenate(
            name=f"cat_out{level}_out{level + 1}Up")([x, up])
        x = ResBlock3D(dim, f"res_block_{level}_2")(x)
        up = keras.layers.UpSampling3D(name=f"out{level}_up")(x)

    # Full resolution: residual block on the raw input joined with the
    # upsampled decoder output.
    x = ResBlock3D(inputs.shape[-1], "res_block_0")(inputs)
    x = keras.layers.Concatenate(name="cat_out0_out1Up")([x, up])

    outputs = keras.layers.Conv3D(
        num_classes, 1, activation=output_activation, name='outputs')(x)

    return keras.Model(inputs, outputs, name='uxnet3D')



# Demo: build the network for 128x128x32 single-channel volumes with two
# output channels, print the layer table and save a diagram of the graph
# to '<model name>.png' (plot_model needs pydot and Graphviz installed).
# NOTE(review): this runs on every import of this module; consider moving it
# under an `if __name__ == "__main__":` guard.
model = uxnet3D(
    input_shape=[128, 128, 32, 1],
    num_classes=2,
)
model.summary()
keras.utils.plot_model(
    model,
    to_file=f'{model.name}.png',
    show_shapes=True,
)

layers.py（models.py 通过 `from nets.model_layers import ...` 导入这些类，因此该文件需保存为 nets/model_layers.py）


import tensorflow as tf
from tensorflow import keras


class UXNETHalfBlock(keras.layers.Layer):
    """Two pre-norm residual sub-blocks: a 7x7x7 Conv3D mixer, then a 1x1x1 conv MLP.

    With ``C`` the channel count of the input (read in ``build``)::

        x1  = inputs + Conv3D(C, 7)(LayerNorm(inputs))
        out = x1 + Conv3D(C, 1)(gelu(Conv3D(4C, 1)(LayerNorm(x1))))

    All convolutions use stride 1 and ``'same'`` padding, so the output has the
    same shape as the input.

    NOTE(review): despite its name, ``depth_wise_conv_1`` is a dense Conv3D
    (no ``groups``). It mixes all channels and has C*C*343 kernel weights. A
    depthwise conv (``groups=C``) appears to be what the name and the UX-Net
    design intend. It is left unchanged because switching alters the parameter
    count and checkpoint compatibility; confirm before changing.

    Args:
        block_name: Layer name used when ``name`` is not given. It is stored
            and returned by ``get_config``.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``trainable``, ``dtype``, ...). The base ``get_config`` emits these
            keys, so they must be accepted for ``from_config`` (which calls
            ``cls(**config)``) to work.
    """

    def __init__(self, block_name, **kwargs):
        # The layer is named after block_name unless an explicit name is given,
        # e.g. when rebuilt from get_config().
        kwargs.setdefault("name", block_name)
        super().__init__(**kwargs)
        self.block_name = block_name

        # LayerNormalization normalizes over the last (channel) axis by default.
        self.layer_norm_1 = keras.layers.LayerNormalization()
        self.layer_norm_2 = keras.layers.LayerNormalization()

        self.add1 = keras.layers.Add()
        self.add2 = keras.layers.Add()

    def get_config(self):
        """Return the base layer config plus ``block_name``."""
        config = super().get_config()
        config.update(
            {
                "block_name": self.block_name
            }
        )
        return config

    def build(self, input_shape):
        """Create the convolutions; their widths depend on the input channel count."""
        num_filters = input_shape[-1]
        self.depth_wise_conv_1 = keras.layers.Conv3D(num_filters, 7, padding='same')
        # Expand to 4x channels, then project back: a per-voxel MLP.
        self.depth_conv_scale_1 = keras.layers.Conv3D(int(num_filters * 4), 1, padding='same')
        self.depth_conv_scaleBack_1 = keras.layers.Conv3D(num_filters, 1, padding='same')

    def call(self, inputs):
        # Sub-block 1: norm -> 7x7x7 conv -> residual add.
        x = self.layer_norm_1(inputs)
        x = self.depth_wise_conv_1(x)
        x1 = self.add1([inputs, x])

        # Sub-block 2: norm -> 1x1 expand -> GELU -> 1x1 project -> residual add.
        x = self.layer_norm_2(x1)
        x = self.depth_conv_scale_1(x)
        x = tf.nn.gelu(x)
        x = self.depth_conv_scaleBack_1(x)
        x = self.add2([x, x1])

        return x


class UXNETBlock(keras.layers.Layer):
    """Two ``UXNETHalfBlock`` layers applied in sequence.

    The sub-layers are named ``f"{block_name}_part1"`` and
    ``f"{block_name}_part2"``. Each half block preserves the input shape, so
    this block does too.

    Args:
        block_name: Layer name used when ``name`` is not given. It is also the
            prefix of the sub-layer names, and is stored and returned by
            ``get_config``.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``trainable``, ``dtype``, ...). The base ``get_config`` emits these
            keys, so they must be accepted for ``from_config`` (which calls
            ``cls(**config)``) to work.
    """

    def __init__(self, block_name, **kwargs):
        # The layer is named after block_name unless an explicit name is given,
        # e.g. when rebuilt from get_config().
        kwargs.setdefault("name", block_name)
        super().__init__(**kwargs)

        self.block_name = block_name

        self.block1 = UXNETHalfBlock(block_name=f"{block_name}_part1")
        self.block2 = UXNETHalfBlock(block_name=f"{block_name}_part2")

    def get_config(self):
        """Return the base layer config plus ``block_name``."""
        config = super().get_config()
        config.update(
            {
                "block_name": self.block_name
            }
        )
        return config

    def call(self, inputs):
        x = self.block1(inputs)
        x = self.block2(x)

        return x

class DownSampleBlock(keras.layers.Layer):
    """Strided 3x3x3 Conv3D -> BatchNormalization -> GELU.

    The convolution has twice the input channel count as filters and uses
    stride 2 with ``'same'`` padding. It doubles the channels and halves every
    spatial size (rounding up for odd sizes).

    Args:
        block_name: Layer name used when ``name`` is not given. It is stored
            and returned by ``get_config``.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``trainable``, ``dtype``, ...). The base ``get_config`` emits these
            keys, so they must be accepted for ``from_config`` (which calls
            ``cls(**config)``) to work.
    """

    def __init__(self, block_name, **kwargs):
        # The layer is named after block_name unless an explicit name is given,
        # e.g. when rebuilt from get_config().
        kwargs.setdefault("name", block_name)
        super().__init__(**kwargs)

        self.block_name = block_name

    def build(self, input_shape):
        """Create the conv and norm; the filter count depends on the input channels."""
        num_filters = int(input_shape[-1] * 2)
        self.conv = keras.layers.Conv3D(num_filters, 3, 2, padding='same')
        self.norm = keras.layers.BatchNormalization()

    def get_config(self):
        """Return the base layer config plus ``block_name``."""
        config = super().get_config()
        config.update(
            {
                "block_name": self.block_name
            }
        )
        return config

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.norm(x)
        x = tf.nn.gelu(x)
        return x


class ResBlock3D(keras.layers.Layer):
    """Bottleneck residual block with a 1x1x1 projection shortcut.

    With ``C`` the input channel count and ``O = out_channels``::

        main  = gelu(BN(Conv3D(O, 1)(gelu(BN(Conv3D(C, 3)(gelu(BN(Conv3D(C, 1)(x)))))))))
        short = gelu(BN(Conv3D(O, 1)(x)))
        out   = main + short

    Every convolution uses stride 1 and ``'same'`` padding, so the output keeps
    the input's spatial size and has ``out_channels`` channels. The shortcut
    projection is always applied, even when ``C == out_channels``.

    Args:
        out_channels: Number of output channels.
        block_name: Layer name used when ``name`` is not given. It is stored
            and returned by ``get_config``.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``trainable``, ``dtype``, ...). The base ``get_config`` emits these
            keys, so they must be accepted for ``from_config`` (which calls
            ``cls(**config)``) to work.
    """

    def __init__(self, out_channels, block_name, **kwargs):
        # The layer is named after block_name unless an explicit name is given,
        # e.g. when rebuilt from get_config().
        kwargs.setdefault("name", block_name)
        super().__init__(**kwargs)

        self.add = keras.layers.Add()
        self.out_channels = out_channels
        self.block_name = block_name

    def get_config(self):
        """Return the base layer config plus ``block_name`` and ``out_channels``."""
        config = super().get_config()
        config.update(
            {
                "block_name": self.block_name,
                "out_channels": self.out_channels
            }
        )
        return config

    def build(self, input_shape):
        """Create the conv/norm pairs; the main-path width follows the input channels."""
        num_filters = input_shape[-1]
        # Main path: 1x1 -> 3x3 at the input width, then 1x1 to out_channels.
        self.conv1 = keras.layers.Conv3D(num_filters, 1, 1, padding='same')
        self.norm1 = keras.layers.BatchNormalization()

        self.conv2 = keras.layers.Conv3D(num_filters, 3, 1, padding='same')
        self.norm2 = keras.layers.BatchNormalization()

        self.conv3 = keras.layers.Conv3D(self.out_channels, 1, 1, padding='same')
        self.norm3 = keras.layers.BatchNormalization()

        # Shortcut: 1x1 projection so both branches have out_channels channels.
        self.conv4 = keras.layers.Conv3D(self.out_channels, 1, 1, padding='same')
        self.norm4 = keras.layers.BatchNormalization()

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.norm1(x)
        x = tf.nn.gelu(x)

        x = self.conv2(x)
        x = self.norm2(x)
        x = tf.nn.gelu(x)

        x = self.conv3(x)
        x = self.norm3(x)
        x1 = tf.nn.gelu(x)

        x2 = self.conv4(inputs)
        x2 = self.norm4(x2)
        x2 = tf.nn.gelu(x2)

        return self.add([x1, x2])
    

你可能感兴趣的：tensorflow2、tensorflow、人工智能、python