models.py
from nets.model_layers import UXNETBlock, DownSampleBlock, ResBlock3D
from tensorflow import keras
def uxnet3D(input_shape, num_classes):
    """Build the `uxnet3D` Keras functional model.

    Layout, as wired below:
      * a stride-2 `Conv3D` stem with 48 filters ('embedding');
      * four encoder stages, each a `UXNETBlock` followed by a
        `DownSampleBlock`; the output of every stage is kept as a skip
        connection;
      * a decoder that, from the deepest feature map upwards, passes each
        skip through a `ResBlock3D`, concatenates it with the up-sampled
        deeper features, fuses the result with a second `ResBlock3D` and
        up-samples again;
      * a final 1x1x1 `Conv3D` with `num_classes` filters and a sigmoid
        activation.

    Args:
        input_shape: Shape of one sample without the batch axis, as passed
            to `keras.Input`, e.g. `[128, 128, 32, 1]`. The last entry is
            the channel count.
        num_classes: Number of filters of the output convolution.

    Returns:
        A `keras.Model` named 'uxnet3D'.

    Raises:
        ValueError: If a known (non-`None`) spatial dimension is not a
            multiple of 32.
    """
    # The stem and the four DownSampleBlocks each use stride 2 with 'same'
    # padding, and the decoder up-samples five times by the UpSampling3D
    # default factor. The skip tensors and the up-sampled tensors only have
    # matching shapes when every spatial size is divisible by 2 ** 5;
    # otherwise a Concatenate layer fails deep in the graph with a far less
    # helpful message, so the requirement is checked up front.
    total_stride = 2 ** 5
    for axis, size in enumerate(tuple(input_shape)[:-1]):
        if size is not None and size % total_stride != 0:
            raise ValueError(
                f"uxnet3D requires every spatial dimension to be a multiple "
                f"of {total_stride}; got {size} on axis {axis} of "
                f"input_shape={list(input_shape)}."
            )

    inputs = keras.Input(shape=input_shape)
    embeddings = keras.layers.Conv3D(
        48, 3, 2, padding="same", name='embedding'
    )(inputs)

    # Encoder: skips[1] is the stem output, skips[k + 1] the output of
    # encoder stage k (k = 1..4).
    skips = {1: embeddings}
    x = embeddings
    for stage in range(1, 5):
        x = UXNETBlock(f"uxblock_{stage}")(x)
        x = DownSampleBlock(f"down_{stage}")(x)
        skips[stage + 1] = x

    # Deepest level: refine, then up-sample towards level 4.
    deepest = skips[5]
    deepest = ResBlock3D(deepest.shape[-1], "res_block_5")(deepest)
    upsampled = keras.layers.UpSampling3D(name='out5_up')(deepest)

    # Decoder levels 4 -> 1. Each level keeps the channel count of its own
    # skip tensor after fusing with the up-sampled deeper features.
    for level in range(4, 0, -1):
        skip = skips[level]
        channels = skip.shape[-1]
        fused = ResBlock3D(channels, f"res_block_{level}_1")(skip)
        fused = keras.layers.Concatenate(
            name=f"cat_out{level}_out{level + 1}Up"
        )([fused, upsampled])
        fused = ResBlock3D(channels, f"res_block_{level}_2")(fused)
        upsampled = keras.layers.UpSampling3D(name=f"out{level}_up")(fused)

    # Full-resolution level: the raw input is used as the last skip.
    full_res = ResBlock3D(inputs.shape[-1], "res_block_0")(inputs)
    full_res = keras.layers.Concatenate(
        name="cat_out0_out1Up"
    )([full_res, upsampled])

    # Sigmoid scores each class channel independently (not a softmax).
    outputs = keras.layers.Conv3D(
        num_classes, 1, activation='sigmoid', name='outputs'
    )(full_res)
    model = keras.Model(inputs, outputs, name='uxnet3D')
    return model
# Build a two-class model for 128 x 128 x 32 single-channel volumes, print
# its layer summary and write a diagram of the graph (with tensor shapes)
# to '<model name>.png'.
model = uxnet3D(
    input_shape=[128, 128, 32, 1],
    num_classes=2,
)
model.summary()
keras.utils.plot_model(
    model,
    to_file=f'{model.name}.png',
    show_shapes=True,
)
layers.py
import tensorflow as tf
from tensorflow import keras
class UXNETHalfBlock(keras.layers.Layer):
    """Two residual sub-blocks applied to a 5-D feature map.

    First sub-block:  LayerNorm -> 7x7x7 Conv3D -> add input.
    Second sub-block: LayerNorm -> 1x1x1 Conv3D (4x channels) -> GELU ->
    1x1x1 Conv3D (back to the input channel count) -> add the first
    sub-block's output.

    The output has the same channel count as the input.

    Args:
        block_name: Name of the layer; also stored in the config.
        **kwargs: Standard `keras.layers.Layer` keyword arguments such as
            `trainable` or `dtype`. A `name` entry is ignored because the
            layer is always named after `block_name`.
    """

    def __init__(self, block_name, **kwargs):
        # get_config() returns the base-layer keys (`name`, `trainable`,
        # `dtype`) alongside `block_name`, and from_config() feeds all of
        # them back into __init__. Accepting **kwargs makes that round trip
        # (and therefore clone_model / load_model) work instead of raising
        # TypeError on the unexpected `name` keyword.
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)
        self.block_name = block_name
        self.layer_norm_1 = keras.layers.LayerNormalization()
        self.layer_norm_2 = keras.layers.LayerNormalization()
        self.add1 = keras.layers.Add()
        self.add2 = keras.layers.Add()

    def get_config(self):
        """Return the base layer config extended with `block_name`."""
        config = super().get_config()
        config["block_name"] = self.block_name
        return config

    def build(self, input_shape):
        """Create the convolutions once the input channel count is known."""
        num_filters = input_shape[-1]
        # NOTE(review): despite its name this is a regular Conv3D (no
        # `groups` argument), not a depthwise convolution - confirm whether
        # that is intended.
        self.depth_wise_conv_1 = keras.layers.Conv3D(
            num_filters, 7, padding='same'
        )
        # Point-wise expansion to 4x the channels and projection back.
        self.depth_conv_scale_1 = keras.layers.Conv3D(
            int(num_filters * 4), 1, padding='same'
        )
        self.depth_conv_scaleBack_1 = keras.layers.Conv3D(
            num_filters, 1, padding='same'
        )

    def call(self, inputs):
        """Apply both residual sub-blocks to `inputs`."""
        # Sub-block 1: normalised large-kernel convolution + skip.
        x = self.layer_norm_1(inputs)
        x = self.depth_wise_conv_1(x)
        x1 = self.add1([inputs, x])

        # Sub-block 2: normalised expand -> GELU -> project + skip.
        x = self.layer_norm_2(x1)
        x = self.depth_conv_scale_1(x)
        x = tf.nn.gelu(x)
        x = self.depth_conv_scaleBack_1(x)
        x = self.add2([x, x1])
        return x
class UXNETBlock(keras.layers.Layer):
    """Two `UXNETHalfBlock` layers applied one after the other.

    The sub-layers are named '<block_name>_part1' and '<block_name>_part2'.

    Args:
        block_name: Name of the layer; also stored in the config.
        **kwargs: Standard `keras.layers.Layer` keyword arguments such as
            `trainable` or `dtype`. A `name` entry is ignored because the
            layer is always named after `block_name`.
    """

    def __init__(self, block_name, **kwargs):
        # from_config() passes the base-layer keys produced by get_config()
        # (`name`, `trainable`, `dtype`) back into __init__; without
        # **kwargs that call raises TypeError and the layer cannot be
        # cloned or reloaded.
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)
        self.block_name = block_name
        self.block1 = UXNETHalfBlock(block_name=f"{block_name}_part1")
        self.block2 = UXNETHalfBlock(block_name=f"{block_name}_part2")

    def get_config(self):
        """Return the base layer config extended with `block_name`."""
        config = super().get_config()
        config["block_name"] = self.block_name
        return config

    def call(self, inputs):
        """Run `inputs` through the first and then the second half block."""
        x = self.block1(inputs)
        x = self.block2(x)
        return x
class DownSampleBlock(keras.layers.Layer):
    """Stride-2 3x3x3 convolution that doubles the channel count.

    The convolution is followed by batch normalisation and GELU.

    Args:
        block_name: Name of the layer; also stored in the config.
        **kwargs: Standard `keras.layers.Layer` keyword arguments such as
            `trainable` or `dtype`. A `name` entry is ignored because the
            layer is always named after `block_name`.
    """

    def __init__(self, block_name, **kwargs):
        # from_config() passes the base-layer keys produced by get_config()
        # (`name`, `trainable`, `dtype`) back into __init__; without
        # **kwargs that call raises TypeError and the layer cannot be
        # cloned or reloaded.
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)
        self.block_name = block_name

    def build(self, input_shape):
        """Create the sub-layers once the input channel count is known."""
        # Twice as many output channels as input channels.
        num_filters = int(input_shape[-1] * 2)
        self.conv = keras.layers.Conv3D(num_filters, 3, 2, padding='same')
        self.norm = keras.layers.BatchNormalization()

    def get_config(self):
        """Return the base layer config extended with `block_name`."""
        config = super().get_config()
        config["block_name"] = self.block_name
        return config

    def call(self, inputs):
        """Apply convolution, batch normalisation and GELU to `inputs`."""
        x = self.conv(inputs)
        x = self.norm(x)
        x = tf.nn.gelu(x)
        return x
class ResBlock3D(keras.layers.Layer):
    """Residual block made of 3-D convolutions with a projected shortcut.

    Main path: 1x1x1 conv -> BN -> GELU -> 3x3x3 conv -> BN -> GELU ->
    1x1x1 conv to `out_channels` -> BN -> GELU. The first two convolutions
    keep the input channel count.
    Shortcut:  1x1x1 conv to `out_channels` -> BN -> GELU.
    The two results are summed.

    Args:
        out_channels: Number of channels produced by the block.
        block_name: Name of the layer; also stored in the config.
        **kwargs: Standard `keras.layers.Layer` keyword arguments such as
            `trainable` or `dtype`. A `name` entry is ignored because the
            layer is always named after `block_name`.
    """

    def __init__(self, out_channels, block_name, **kwargs):
        # from_config() passes the base-layer keys produced by get_config()
        # (`name`, `trainable`, `dtype`) back into __init__; without
        # **kwargs that call raises TypeError and the layer cannot be
        # cloned or reloaded.
        kwargs.pop("name", None)
        super().__init__(name=block_name, **kwargs)
        self.add = keras.layers.Add()
        self.out_channels = out_channels
        self.block_name = block_name

    def get_config(self):
        """Return the base config plus `block_name` and `out_channels`."""
        config = super().get_config()
        config["block_name"] = self.block_name
        config["out_channels"] = self.out_channels
        return config

    def build(self, input_shape):
        """Create the sub-layers once the input channel count is known."""
        num_filters = input_shape[-1]
        # Main path.
        self.conv1 = keras.layers.Conv3D(num_filters, 1, 1, padding='same')
        self.norm1 = keras.layers.BatchNormalization()
        self.conv2 = keras.layers.Conv3D(num_filters, 3, 1, padding='same')
        self.norm2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv3D(
            self.out_channels, 1, 1, padding='same'
        )
        self.norm3 = keras.layers.BatchNormalization()
        # Shortcut projection so both branches end with `out_channels`.
        self.conv4 = keras.layers.Conv3D(
            self.out_channels, 1, 1, padding='same'
        )
        self.norm4 = keras.layers.BatchNormalization()

    def call(self, inputs):
        """Return the sum of the main path and the projected shortcut."""
        # Main path.
        x = self.conv1(inputs)
        x = self.norm1(x)
        x = tf.nn.gelu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = tf.nn.gelu(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x1 = tf.nn.gelu(x)

        # Shortcut path, computed directly from the block input.
        x2 = self.conv4(inputs)
        x2 = self.norm4(x2)
        x2 = tf.nn.gelu(x2)
        return self.add([x1, x2])