U-Net 是卷积神经网络的一种:编码端通过卷积(及池化)逐步下采样,解码端通过上采样与转置卷积恢复分辨率,并用跳跃连接把编码端的特征图拼接到解码端。它于 2015 年在论文 U-Net: Convolutional Networks for Biomedical Image Segmentation 中被提出。
这里采用 3D 卷积实现 U-Net 网络结构,原因是单个样本是 4 维的核磁共振影像;TensorFlow 2 实现的模型每次输入一批影像,所以模型输入是一个 5 维张量,形状为 (影像数量, 影像维度1, 影像维度2, 影像维度3, 影像维度4)。
层数 | 神经网络层 | 卷积核形状 | 输出张量形状 | |
Encoder | 1 | batchnormalization | (batch size , 240, 240, 155, 4) | |
2 | conv3d | (3,3,3) | (batch size, 240, 240, 155, 8) | |
3 | batchnormalization | (batch size, 240, 240, 155, 8) | ||
4 | conv3d | (3,3,3) | (batch size, 240, 240, 155, 16) | |
5 | batchnormalization | (batch size, 240, 240, 155, 16) | ||
6 | conv3d | (3,3,2) | (batch size, 238, 238, 154, 16) | |
7 | batchnormalization | (batch size, 238, 238, 154, 16) | ||
8 | conv3d | (3,3,1) | (batch size, 118, 118, 77, 32) | |
9 | batchnormalization | (batch size, 118, 118, 77, 32) | ||
10 | conv3d | (3,3,1) | (batch size, 58, 58, 39, 64) | |
11 | batchnormalization | (batch size, 58, 58, 39, 64) | ||
12 | maxpooling3d | (2,2,1) | (batch size, 29, 29, 39, 64) | |
Decoder | 13 | batchnormalization | (batch size,29,29,39,64) | |
14 | upsampling3d | (2,2,1) | (batch size,58,58,39,64) | |
15 | conv3dTranspose | (3,3,1) | (batch size,58,58,39,64) | |
16 | concat | (batch size,58,58,39,128) | ||
17 | batchnormalization | (batch size,58,58,39,128) | ||
18 | upsampling3d | (2,2,2) | (batch size,116,116,78,128) | |
19 | conv3dTranspose | (3,3,1) | (batch size,118,118,78,32) | |
20 | conv3d | (3,3,3) | (batch size,116,116,76,32) | |
21 | batchnormalization | (batch size,116,116,76,32) | ||
22 | conv3dTranspose | (3,3,2) | (batch size,118,118,77,16) | |
23 | concat | (batch size,118,118,77,48) | ||
24 | batchnormalization | (batch size,118,118,77,48) | ||
25 | upsampling3d | (2,2,2) | (batch size,236,236,154,48) | |
26 | conv3dTranspose | (3,3,1) | (batch size,238,238,154,16) | |
27 | concat | (batch size,238,238,154,32) | ||
28 | batchnormalization | (batch size,238,238,154,32) | ||
29 | conv3dTranspose | (3,3,1) | (batch size,240,240,154,8) | |
30 | conv3dTranspose | (1,1,5) | (batch size,240,240,158,4) | |
31 | conv3d | (1,1,4) | (batch size,240,240,155,1) |
from tensorflow.keras.layers import BatchNormalization,Conv3D,MaxPooling3D,Conv3DTranspose,UpSampling3D
import tensorflow as tf
class unet_encoder(tf.keras.Model):
    """Encoder (contracting path) of the 3-D U-Net.

    A stack of BatchNormalization / Conv3D stages followed by a max-pool.
    The outputs of the 3rd, 4th and 5th convolution stages (each taken after
    its BatchNormalization) are the skip-connection maps used by the decoder.
    """

    def __init__(self):
        # tf.keras.Model must be initialised before any layer is assigned as
        # an attribute; without this call attribute tracking raises.
        super().__init__()
        self.b1 = BatchNormalization()
        self.conv1 = Conv3D(8, 3, activation='relu', padding='same')
        self.b2 = BatchNormalization()
        self.conv2 = Conv3D(16, 3, activation='relu', padding='same')
        self.b3 = BatchNormalization()
        # From here on Conv3D uses its default 'valid' padding (and strides),
        # so the spatial dimensions shrink.
        self.conv3 = Conv3D(16, (3, 3, 2), activation='relu')
        self.b4 = BatchNormalization()
        self.conv4 = Conv3D(32, (3, 3, 1), activation='relu', strides=2)
        self.b5 = BatchNormalization()
        self.conv5 = Conv3D(64, (3, 3, 1), activation='relu', strides=2)
        self.b6 = BatchNormalization()
        self.maxpool1 = MaxPooling3D((2, 2, 1))

    def call(self, x, features=None):
        """Encode ``x`` and record the skip-connection maps.

        Args:
            x: 5-D input batch. The article's table assumes
                (batch, 240, 240, 155, 4) — other sizes may not line up with
                the decoder; TODO confirm before changing the input size.
            features: optional list. Its contents are replaced with the three
                skip maps, shallowest first, so the decoder can read them as
                ``features[-3]``, ``features[-2]``, ``features[-1]``.

        Returns:
            The max-pooled bottleneck tensor.
        """
        x = self.b1(x)
        x = self.conv1(x)
        x = self.b2(x)
        x = self.conv2(x)
        x = self.b3(x)
        # First skip-connection feature map.
        x = self.conv3(x)
        x = self.b4(x)
        skip1 = x
        # Second skip-connection feature map.
        x = self.conv4(x)
        x = self.b5(x)
        skip2 = x
        # Third skip-connection feature map.
        x = self.conv5(x)
        x = self.b6(x)
        skip3 = x
        if features is not None:
            # Replace in place (rather than append) so the caller's list holds
            # exactly this call's maps and does not grow across calls.
            features[:] = [skip1, skip2, skip3]
        # Output variable.
        outputs = self.maxpool1(x)
        return outputs
class unet_decoder(tf.keras.Model):
    """Decoder (expanding path) of the 3-D U-Net.

    Upsamples the encoder bottleneck back to the input resolution and
    concatenates the three encoder skip maps on the channel axis. The kernel
    and padding choices make the shapes line up for the (240, 240, 155, 4)
    input in the article's table; other input sizes may fail at the
    concatenations — TODO confirm before changing the input size.
    """

    def __init__(self):
        # tf.keras.Model must be initialised before any layer is assigned.
        super().__init__()
        self.b1 = BatchNormalization()
        self.up1 = UpSampling3D((2, 2, 1))
        self.conv1tp = Conv3DTranspose(64, (3, 3, 1), activation='relu', padding='same')
        self.b2 = BatchNormalization()
        self.up2 = UpSampling3D((2, 2, 2))
        self.conv2tp = Conv3DTranspose(32, (3, 3, 1), activation='relu')
        self.conv2 = Conv3D(32, 3, activation='relu')
        self.b3 = BatchNormalization()
        self.conv3tp = Conv3DTranspose(16, (3, 3, 2), activation='relu')
        self.b4 = BatchNormalization()
        self.up4 = UpSampling3D((2, 2, 2))
        self.conv4tp = Conv3DTranspose(16, (3, 3, 1), activation='relu')
        self.b5 = BatchNormalization()
        self.conv5tp = Conv3DTranspose(8, (3, 3, 1), activation='relu')
        self.conv6tp = Conv3DTranspose(4, (1, 1, 5), activation='relu')
        # NOTE(review): the output layer uses ReLU, so predictions are
        # unbounded above. The training script pairs this model with
        # BinaryCrossentropy() (from_logits=False), which expects values in
        # [0, 1]; a sigmoid output is usually what that loss needs — confirm.
        self.conv_out = Conv3D(1, (1, 1, 4), activation='relu')

    def call(self, x, features):
        """Decode the bottleneck ``x`` using the encoder skip maps.

        Args:
            x: bottleneck tensor produced by the encoder.
            features: sequence whose last three entries are the encoder skip
                maps, shallowest first; ``features[-1]`` (deepest) is
                concatenated first.

        Returns:
            Tensor with a single output channel (``conv_out`` has 1 filter).
        """
        x = self.b1(x)
        x = self.up1(x)
        x = self.conv1tp(x)
        x = tf.concat((features[-1], x), axis=-1)
        x = self.b2(x)
        x = self.up2(x)
        x = self.conv2tp(x)
        x = self.conv2(x)
        x = self.b3(x)
        x = self.conv3tp(x)
        x = tf.concat((features[-2], x), axis=-1)
        x = self.b4(x)
        x = self.up4(x)
        x = self.conv4tp(x)
        x = tf.concat((features[-3], x), axis=-1)
        x = self.b5(x)
        x = self.conv5tp(x)
        x = self.conv6tp(x)
        outputs = self.conv_out(x)
        return outputs
class Unet3D(tf.keras.Model):
    """Complete 3-D U-Net: an encoder followed by a decoder with skip links."""

    def __init__(self, encoder, decoder):
        # Required before assigning sub-models as attributes of a Keras Model.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def call(self, x):
        """Run the encoder, then the decoder, on the batch ``x``."""
        # A fresh list per call: the encoder fills it with this batch's skip
        # maps and the decoder reads them. Keeping it local avoids holding
        # tensors between calls and avoids storing tensors in a list that
        # Keras would track as a model attribute.
        features = []
        x = self.encoder(x, features)
        outputs = self.decoder(x, features)
        return outputs
软件 | 版本 |
Python | 3.8.11 |
Tensorflow | 2.6.0-gpu |
CUDA | 11.2 |
cuDNN | 8.1.0 |
nibabel | 3.2.2 |
处理器 | 型号 | 显存 |
GPU | NVIDIA GeForce RTX 3090 | 24G |
MSD脑瘤数据集(百度飞桨 AI Studio)
import nibabel as nib
import numpy as np
def nearest_4d(img, size):
    """Resize a 4-D array with nearest-neighbour sampling over its first three axes.

    Output voxel ``(i, j, k)`` copies the channel vector of input voxel
    ``(i * img.shape[0] // X, j * img.shape[1] // Y, k * img.shape[2] // Z)``
    where ``(X, Y, Z, C) = size``.

    Args:
        img: array indexed as ``img[x, y, z, c]``.
        size: target shape ``(X, Y, Z, C)``. The input channel count must
            equal ``C`` or be 1 (it is then broadcast).

    Returns:
        A new float64 array of shape ``size``.
    """
    res = np.zeros(size)
    out_x, out_y, out_z = res.shape[:3]
    # Source index for every output index along each spatial axis.
    idx = np.arange(out_x) * img.shape[0] // out_x
    idy = np.arange(out_y) * img.shape[1] // out_y
    idz = np.arange(out_z) * img.shape[2] // out_z
    # One vectorised gather instead of a Python triple loop over every voxel;
    # the trailing channel axis is carried along and broadcast into ``res``.
    res[...] = img[np.ix_(idx, idy, idz)]
    return res
# 按照数据文件路径以迭代器的方式读取数据
class DataIterator:
    """Iterate over (image, label) array pairs loaded from paired file paths.

    Each pair is loaded (NIfTI via nibabel, or .npy via NumPy), scaled by its
    maximum, transposed, optionally resized with nearest-neighbour sampling
    and yielded as float arrays. Pairs are formed positionally with ``zip``,
    so ``image_paths[i]`` must correspond to ``label_paths[i]``.
    """

    def __init__(self, image_paths, label_paths, size=None,
                 transp_shape=(0, 1, 2, 3), mode='nib'):
        """
        Args:
            image_paths: image file paths.
            label_paths: label file paths, in the same order as the images.
            size: ``None`` (no resizing), a 3-tuple of spatial sizes, or a
                4-tuple ``(X, Y, Z, C)`` for the image.
            transp_shape: axis permutation applied to image and label.
            mode: ``'nib'`` to read with nibabel, ``'np'`` for ``np.load``;
                any other value makes every pair be skipped.
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.size = size
        self.transp = transp_shape
        # read_and_resize dispatches on this; it was previously never stored.
        self.mode = mode

    @staticmethod
    def _scale_by_max(arr):
        """Divide ``arr`` by its maximum, leaving all-zero arrays unchanged."""
        peak = np.max(arr)
        # An all-background label has max 0; dividing would produce NaNs.
        return arr / peak if peak != 0 else arr

    def read_and_resize(self, img_path, lbl_path):
        """Load one pair; return ``(None, None)`` for an unknown ``mode``."""
        if self.mode == 'nib':
            img = nib.load(img_path).get_fdata(caching='fill', dtype='float32')
            lbl = nib.load(lbl_path).get_fdata(caching='fill', dtype='float32')
        elif self.mode == 'np':
            img = np.load(img_path)
            lbl = np.load(lbl_path)
        else:
            return None, None
        # Integer arrays (possible with np.load) cannot be divided in place;
        # work in float32, matching the nibabel branch.
        img = self._scale_by_max(np.asarray(img, dtype=np.float32))
        lbl = self._scale_by_max(np.asarray(lbl, dtype=np.float32))
        img = img.transpose(self.transp)
        if len(lbl.shape) < len(img.shape):
            lbl = np.expand_dims(lbl, axis=-1)
        lbl = lbl.transpose(self.transp)
        if self.size is not None:
            size = tuple(self.size)
            spatial = size[:3]
            img_target = size if len(size) == 4 else spatial + (img.shape[-1],)
            # Labels keep their own channel count; resizing them to the image
            # channel count would replicate the mask across channels.
            lbl_target = spatial + (lbl.shape[-1],)
            img = nearest_4d(img, img_target)
            lbl = nearest_4d(lbl, lbl_target)
        return img, lbl

    def __iter__(self):
        """Yield ``(image, label)`` arrays, skipping pairs that failed to load."""
        for img_path, lbl_path in zip(self.image_paths, self.label_paths):
            img, lbl = self.read_and_resize(img_path, lbl_path)
            if isinstance(img, np.ndarray) and isinstance(lbl, np.ndarray):
                yield (img, lbl)
# 数据生成器,因为训练用的标签数据少了一个维度,所以在返回数据对象之前给数据对象扩充维度
class DataGenerator:
    """Group samples from a DataIterator into NumPy batches.

    One iteration is one pass over the files: full batches of ``batch_size``
    are yielded, followed by a final smaller batch if samples remain. Because
    the training labels are one dimension short, a trailing channel axis is
    added to both images and labels when the stacked images are not yet 5-D.
    """

    def __init__(self, image_paths, label_paths, size=None, batch_size=32,
                 transp_shape=(0, 1, 2, 3), mode='nib'):
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.batch_size = batch_size
        # Keep the iterable itself (not a one-shot iterator) so that every
        # iteration over this generator starts a fresh pass over the files.
        self.dataiter = DataIterator(image_paths, label_paths, size, transp_shape, mode)

    @staticmethod
    def _make_batch(imgs, lbls):
        """Stack samples and add a channel axis if the images are below 5-D."""
        batch_imgs = np.stack(imgs)
        batch_lbls = np.stack(lbls)
        if len(batch_imgs.shape) < 5:
            batch_imgs = np.expand_dims(batch_imgs, axis=-1)
            batch_lbls = np.expand_dims(batch_lbls, axis=-1)
        return batch_imgs, batch_lbls

    def __iter__(self):
        """Yield ``(images, labels)`` batches for one pass over the data."""
        imgs, lbls = [], []
        for img, lbl in self.dataiter:
            imgs.append(img)
            lbls.append(lbl)
            if len(imgs) == self.batch_size:
                yield self._make_batch(imgs, lbls)
                imgs, lbls = [], []
        # Do not silently drop the last, incomplete batch.
        if imgs:
            yield self._make_batch(imgs, lbls)
import tensorflow as tf
from tensorflow.keras import losses,optimizers
from model import unet_encoder,unet_decoder,Unet3D
from DataGenerator import DataGenerator
from datetime import datetime
from time import time
import os
# Data paths.
image_dir_path = './data/train/'
label_dir_path = './data/labels/'
# os.listdir returns entries in arbitrary order, and images are paired with
# labels positionally; sort both listings so each image meets its own label.
# NOTE(review): assumes image and label files share the same names — confirm.
images_paths = sorted(os.listdir(image_dir_path))
labels_paths = sorted(os.listdir(label_dir_path))
image_paths = [os.path.join(image_dir_path, p) for p in images_paths]
label_paths = [os.path.join(label_dir_path, p) for p in labels_paths]
date_mark = str(datetime.now())
# Model definition.
encoder_model = unet_encoder()
decoder_model = unet_decoder()
unet = Unet3D(encoder_model, decoder_model)
# Optimizer and loss function.
optimizer = optimizers.Adam(learning_rate=1e-5)
losser = losses.BinaryCrossentropy()
# Training.
epochs = 30
# Log files: log1 gets one line per epoch, log2 one line per step; both are
# closed when training finishes or fails.
with open('./log/epoch_file_form', 'w', encoding='utf-8') as log1, \
        open('./log/step_file_form', 'w', encoding='utf-8') as log2:
    s1 = time()
    for i in range(epochs):
        s2 = time()
        loss_sum = 0.0
        step = 0
        datagener = iter(DataGenerator(image_paths, label_paths, None, 1, [0, 1, 2, 3]))
        for batch in datagener:
            s3 = time()
            step += 1
            x = batch[0]
            y = batch[1]
            with tf.GradientTape() as tape:
                # training=True puts the nested BatchNormalization layers in
                # training mode; otherwise a custom loop runs them in
                # inference mode and their moving statistics never update.
                out = unet(x, training=True)
                loss = losser(y_pred=out, y_true=y)
            grads = tape.gradient(loss, unet.trainable_variables)
            # Without this the weights are never updated.
            optimizer.apply_gradients(zip(grads, unet.trainable_variables))
            e3 = time()
            loss_value = float(loss)
            loss_sum += loss_value
            info_step = f'step:{step:03}\tloss:{loss_value}\t running time: {e3-s3:.3f} s'
            print(info_step, end='\r')
            log2.write(info_step + '\n')
        e2 = time()
        avg_loss = loss_sum / step if step != 0 else 'non samples'
        info_epoch = f'epoch {i+1:02}\t average loss {avg_loss}\t running time {e2-s2:.3f} s'
        print(info_epoch)
        log1.write(info_epoch + '\n')
    e1 = time()
    all_time = f'Training time {e1-s1:.3f} s'
    # all_time already ends with the unit " s"; do not append it twice.
    log1.write(all_time + '\n')
    log2.write(all_time + '\n')
# Save the model — TODO: not implemented in this snippet.
使用两块GPU,将 encoder 部分放置到GPU0上,decoder部分放置到GPU1上。
from tensorflow.keras.layers import BatchNormalization,Conv3D,MaxPooling3D,Conv3DTranspose,UpSampling3D
import tensorflow as tf
def copy_tensor_to_gpu(tensor, gpu_id):
    """Return a copy of ``tensor`` placed on GPU ``gpu_id``.

    ``tf.identity`` executed under a device scope produces its output on that
    device; it is differentiable, so gradients flow back to the source.
    The device string uses the same "/gpu:N" form as the rest of this module.
    """
    with tf.device(f'/gpu:{gpu_id}'):
        return tf.identity(tensor)
def copy_tensor_to_cpu(tensor, cpu_id):
    """Return a copy of ``tensor`` placed on CPU ``cpu_id``.

    The previous version called ``tf.zeros_like(tensor, cpu_id)``, which
    passes the CPU index as the ``dtype`` argument; ``tf.identity`` under the
    device scope performs the copy without that mistake.
    """
    with tf.device(f'/cpu:{cpu_id}'):
        return tf.identity(tensor)
class unet_encoder(tf.keras.Model):
    """Encoder of the 3-D U-Net for the two-GPU variant.

    Same layers as the single-GPU encoder; additionally copies each of the
    three skip-connection maps to the decoder's GPU as soon as it is produced.
    """

    def __init__(self):
        # tf.keras.Model must be initialised before any layer is assigned.
        super().__init__()
        self.b1 = BatchNormalization()
        self.conv1 = Conv3D(8, 3, activation='relu', padding='same')
        self.b2 = BatchNormalization()
        self.conv2 = Conv3D(16, 3, activation='relu', padding='same')
        self.b3 = BatchNormalization()
        self.conv3 = Conv3D(16, (3, 3, 2), activation='relu')
        self.b4 = BatchNormalization()
        self.conv4 = Conv3D(32, (3, 3, 1), activation='relu', strides=2)
        self.b5 = BatchNormalization()
        self.conv5 = Conv3D(64, (3, 3, 1), activation='relu', strides=2)
        self.b6 = BatchNormalization()
        self.maxpool1 = MaxPooling3D((2, 2, 1))

    def call(self, x, features, gpu_id):
        """Encode ``x``, writing the skip maps into ``features``.

        Args:
            x: 5-D input batch.
            features: mutable list with at least three slots; slots 0, 1 and 2
                are overwritten with the skip maps, shallowest first.
            gpu_id: index of the GPU the decoder runs on; the skip maps are
                copied there.

        Returns:
            The max-pooled bottleneck tensor.
        """
        x = self.b1(x)
        x = self.conv1(x)
        x = self.b2(x)
        x = self.conv2(x)
        x = self.b3(x)
        # First skip-connection feature map.
        x = self.conv3(x)
        x = self.b4(x)
        features[0] = copy_tensor_to_gpu(x, gpu_id)
        # Second skip-connection feature map.
        x = self.conv4(x)
        x = self.b5(x)
        features[1] = copy_tensor_to_gpu(x, gpu_id)
        # Third skip-connection feature map.
        x = self.conv5(x)
        x = self.b6(x)
        features[2] = copy_tensor_to_gpu(x, gpu_id)
        # Output variable.
        outputs = self.maxpool1(x)
        return outputs
class unet_decoder(tf.keras.Model):
    """Decoder (expanding path) of the 3-D U-Net, two-GPU variant.

    Upsamples the encoder bottleneck back to the input resolution and
    concatenates the three encoder skip maps on the channel axis. The kernel
    and padding choices make the shapes line up for the (240, 240, 155, 4)
    input in the article's table; other input sizes may fail at the
    concatenations — TODO confirm before changing the input size.
    """

    def __init__(self):
        # tf.keras.Model must be initialised before any layer is assigned.
        super().__init__()
        self.b1 = BatchNormalization()
        self.up1 = UpSampling3D((2, 2, 1))
        self.conv1tp = Conv3DTranspose(64, (3, 3, 1), activation='relu', padding='same')
        self.b2 = BatchNormalization()
        self.up2 = UpSampling3D((2, 2, 2))
        self.conv2tp = Conv3DTranspose(32, (3, 3, 1), activation='relu')
        self.conv2 = Conv3D(32, 3, activation='relu')
        self.b3 = BatchNormalization()
        self.conv3tp = Conv3DTranspose(16, (3, 3, 2), activation='relu')
        self.b4 = BatchNormalization()
        self.up4 = UpSampling3D((2, 2, 2))
        self.conv4tp = Conv3DTranspose(16, (3, 3, 1), activation='relu')
        self.b5 = BatchNormalization()
        self.conv5tp = Conv3DTranspose(8, (3, 3, 1), activation='relu')
        self.conv6tp = Conv3DTranspose(4, (1, 1, 5), activation='relu')
        # NOTE(review): ReLU output is unbounded above; with a
        # BinaryCrossentropy(from_logits=False) loss a sigmoid output is
        # usually expected — confirm.
        self.conv_out = Conv3D(1, (1, 1, 4), activation='relu')

    def call(self, x, features):
        """Decode the bottleneck ``x`` using the encoder skip maps.

        Args:
            x: bottleneck tensor produced by the encoder.
            features: sequence whose last three entries are the encoder skip
                maps, shallowest first; ``features[-1]`` (deepest) is
                concatenated first.

        Returns:
            Tensor with a single output channel (``conv_out`` has 1 filter).
        """
        x = self.b1(x)
        x = self.up1(x)
        x = self.conv1tp(x)
        x = tf.concat((features[-1], x), axis=-1)
        x = self.b2(x)
        x = self.up2(x)
        x = self.conv2tp(x)
        x = self.conv2(x)
        x = self.b3(x)
        x = self.conv3tp(x)
        x = tf.concat((features[-2], x), axis=-1)
        x = self.b4(x)
        x = self.up4(x)
        x = self.conv4tp(x)
        x = tf.concat((features[-3], x), axis=-1)
        x = self.b5(x)
        x = self.conv5tp(x)
        x = self.conv6tp(x)
        outputs = self.conv_out(x)
        return outputs
class Unet3DParallel(tf.keras.Model):
    """3-D U-Net split across two GPUs (model parallelism).

    The encoder runs on ``gpu_group[0]`` and the decoder on ``gpu_group[1]``;
    the encoder copies its skip maps to the decoder's GPU.
    """

    def __init__(self, gpu_group):
        # Required before assigning sub-models as attributes of a Keras Model.
        super().__init__()
        self.gpus = gpu_group
        # Keras layers create their variables lazily on the first call
        # (build), so wrapping the constructors in tf.device places nothing.
        # Placement is done around the calls in ``call`` instead.
        self.encoder = unet_encoder()
        self.decoder = unet_decoder()

    def call(self, x):
        """Run the encoder on the first GPU and the decoder on the second."""
        # Fresh per-call slots for the three skip maps; keeping them local
        # avoids holding tensors between calls.
        features = [None, None, None]
        with tf.device(f'/gpu:{self.gpus[0]}'):
            x = self.encoder(x, features, self.gpus[1])
        # Ops (and, on the first call, the decoder's variables) are created on
        # the second GPU; the bottleneck tensor is copied there automatically.
        with tf.device(f'/gpu:{self.gpus[1]}'):
            outputs = self.decoder(x, features)
        return outputs
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving almost
# all of it on every visible GPU up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    # NOTE(review): the loop body is missing in the original snippet (a `for`
    # with no body is a SyntaxError). Memory growth is assumed to be what was
    # intended: in the nvidia-smi output GPU1 uses only ~4.5 GiB, consistent
    # with on-demand allocation rather than pre-allocation — confirm.
    tf.config.experimental.set_memory_growth(gpu, True)
# 模型可训练参数统计
Model: "unet3d_parallel"
Layer (type) Output Shape Param #
unet_encoder (unet_encoder) multiple 32664
unet_decoder (unet_decoder) multiple 121373
Total params: 154,037
Trainable params: 153,149
Non-trainable params: 888
# 显卡使用情况监控
| NVIDIA-SMI 470.94 Driver Version: 470.94 CUDA Version: 11.4 |
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
| 0 NVIDIA GeForce ... On | 00000000:3E:00.0 Off | N/A |
| 59% 61C P2 211W / 350W | 23746MiB / 24268MiB | 67% Default |
| | | N/A |
| 1 NVIDIA GeForce ... On | 00000000:88:00.0 Off | N/A |
| 46% 56C P2 120W / 350W | 4504MiB / 24268MiB | 22% Default |
| | | N/A |