# 论文地址：https://arxiv.org/pdf/1505.04597.pdf
# (U-Net: Convolutional Networks for Biomedical Image Segmentation, Ronneberger et al., 2015)
# U-Net 网络是分割领域的经典之作，大家可以尝试一下。废话少说，上代码。
import tensorflow as tf
def convolutional(input_data, filters_shape, trainable, name, downsample=False, activate=True, bn=True):
    """Build a 2-D convolution with optional stride-2 downsampling, batch norm and LeakyReLU.

    Args:
        input_data: input feature map. The strides/paddings below index it as
            4-D ``[batch, height, width, channels]`` (NHWC, the ``tf.nn.conv2d``
            default layout).
        filters_shape: ``[kernel_h, kernel_w, in_channels, out_channels]``; used
            directly as the shape of the ``weight`` variable.
        trainable: despite its name, this is passed as ``training=`` to batch
            normalization (batch statistics vs. moving averages). The conv
            weight and bias are always created with ``trainable=True``.
        name: variable scope wrapping every variable created here.
        downsample: if True, zero-pad explicitly and convolve with stride 2 and
            ``VALID`` padding; otherwise convolve with stride 1 and ``SAME``
            padding.
        activate: if truthy, apply LeakyReLU (alpha=0.1) to the output.
        bn: if True, apply batch normalization; otherwise add a learned bias.

    Returns:
        The output tensor of the layer.
    """
    with tf.variable_scope(name):
        if downsample:
            # Explicit symmetric zero padding followed by a VALID stride-2 conv.
            # For a 3x3 kernel this pads 1 pixel per side, so the output is
            # ceil(H/2) x ceil(W/2).
            pad_h, pad_w = (filters_shape[0] - 2) // 2 + 1, (filters_shape[1] - 2) // 2 + 1
            paddings = tf.constant([[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]])
            input_data = tf.pad(input_data, paddings, 'CONSTANT')
            strides = (1, 2, 2, 1)
            padding = 'VALID'
        else:
            strides = (1, 1, 1, 1)
            padding = "SAME"

        weight = tf.get_variable(name='weight', dtype=tf.float32, trainable=True,
                                 shape=filters_shape, initializer=tf.random_normal_initializer(stddev=0.01))
        conv = tf.nn.conv2d(input=input_data, filter=weight, strides=strides, padding=padding)

        if bn:
            # NOTE: per the TF1 batch_normalization docs, the moving-average
            # update ops are added to tf.GraphKeys.UPDATE_OPS. The training op
            # must depend on them, or the moving statistics never change.
            conv = tf.layers.batch_normalization(conv, beta_initializer=tf.zeros_initializer(),
                                                 gamma_initializer=tf.ones_initializer(),
                                                 moving_mean_initializer=tf.zeros_initializer(),
                                                 moving_variance_initializer=tf.ones_initializer(),
                                                 training=trainable)
        else:
            # One bias per output channel.
            bias = tf.get_variable(name='bias', shape=[filters_shape[-1]], trainable=True,
                                   dtype=tf.float32, initializer=tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, bias)

        # Truthiness test: the old `activate is True` silently skipped the
        # activation for truthy non-bool values such as 1.
        if activate:
            conv = tf.nn.leaky_relu(conv, alpha=0.1)

    return conv
def upsample(input_data, name, method="deconv"):
    """Double the spatial resolution of a feature map.

    Args:
        input_data: 4-D feature map. The "resize" branch reads height and width
            from dims 1 and 2. The "deconv" branch reads the channel count from
            the last dim, which must therefore be statically known.
        name: variable scope used by the "resize" branch only. The "deconv"
            branch does not open this scope, so its layer gets ``tf.layers``'
            auto-generated default name inside the caller's current scope.
            This is left as-is so existing checkpoint variable names don't change.
        method: "resize" for nearest-neighbour upsampling (channel count
            unchanged) or "deconv" for a learned 4x4, stride-2 transposed
            convolution that outputs half as many channels as its input.

    Returns:
        The upsampled tensor.

    Raises:
        ValueError: if ``method`` is neither "resize" nor "deconv".
    """
    if method not in ("resize", "deconv"):
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let a bad method fall through to an
        # UnboundLocalError on `output`.
        raise ValueError('method must be "resize" or "deconv", got %r' % (method,))

    if method == "resize":
        with tf.variable_scope(name):
            input_shape = tf.shape(input_data)
            output = tf.image.resize_nearest_neighbor(input_data, (input_shape[1] * 2, input_shape[2] * 2))
    else:
        # Replace resize_nearest_neighbor with conv2d_transpose to support TensorRT optimization.
        num_filters = input_data.shape.as_list()[-1]
        # NOTE(review): tf.random_normal_initializer() defaults to stddev=1.0,
        # far larger than the stddev=0.01 used for conv weights elsewhere in
        # this file. Confirm this is intended.
        output = tf.layers.conv2d_transpose(input_data, num_filters // 2, kernel_size=4, padding='same',
                                            strides=(2, 2), kernel_initializer=tf.random_normal_initializer())
    return output
def Unet(images, filters=8, name='unet', in_channels=3, out_channels=1):
    """Build a U-Net-style encoder/decoder and return per-pixel outputs.

    Encoder:
        - A full-resolution stem of two 3x3 convs (scopes ``conv1``, ``conv2``).
        - Six stages. Each stage is a stride-2 downsampling conv followed by two
          3x3 convs that double the channel count (scopes ``conv3`` .. ``conv20``).
        - The deepest features therefore have ``filters * 64`` channels at
          1/64 of the input resolution.

    Decoder: six ``Ronghe%d`` ("fusion") stages, i = 7 .. 2. Each stage:
        - upsamples with a transposed conv that halves the channels;
        - concatenates the encoder skip feature ``C(i-1)``;
        - applies two 3x3 convs that halve the concatenated channel count.

    A final 3x3 conv (scope ``out``, bias, no BN, no activation) maps to
    ``out_channels``.

    Args:
        images: input batch, NHWC with ``in_channels`` channels. Height and
            width should be multiples of 64: downsampling produces ceil(H/2)
            while the transposed conv produces exactly 2x, so any odd size at
            any level makes the skip concatenation shapes mismatch.
        filters: channel count of the first stage; later stages scale it by
            powers of two.
        name: outer variable scope.
        in_channels: channel count of ``images`` (default 3, as before).
        out_channels: channel count of the returned tensor (default 1, as before).

    Returns:
        The output tensor (linear, no activation applied).
    """
    with tf.variable_scope(name):
        endpoints = {}

        # Stem at full resolution.
        conv = convolutional(images, [3, 3, in_channels, filters], trainable=True, name='conv1')
        conv = convolutional(conv, [3, 3, filters, filters], trainable=True, name='conv2')
        endpoints['C1'] = conv

        # Encoder: six downsampling stages. Scope names conv3..conv20 match the
        # original unrolled code exactly, so checkpoint variable names are unchanged.
        channels = filters
        for stage in range(1, 7):
            base = 3 * stage
            conv = convolutional(conv, [3, 3, channels, channels], trainable=True,
                                 name='conv%d' % base, downsample=True)
            conv = convolutional(conv, [3, 3, channels, channels * 2], trainable=True,
                                 name='conv%d' % (base + 1))
            channels *= 2
            conv = convolutional(conv, [3, 3, channels, channels], trainable=True,
                                 name='conv%d' % (base + 2))
            endpoints['C%d' % (stage + 1)] = conv

        # Decoder: upsample, fuse with the matching skip connection, refine.
        for i in range(7, 1, -1):
            with tf.variable_scope('Ronghe%d' % i):
                uplayer = upsample(conv, 'deconv%d' % (8 - i), method="deconv")
                concat = tf.concat([endpoints['C%d' % (i - 1)], uplayer], axis=-1)
                dim = concat.get_shape()[-1].value
                conv = convolutional(concat, [3, 3, dim, dim // 2], trainable=True, name='conv1')
                conv = convolutional(conv, [3, 3, dim // 2, dim // 2], trainable=True, name='conv2')

        # Prediction head: plain conv + bias, no normalization or activation.
        out = convolutional(conv, [3, 3, dim // 2, out_channels], trainable=True, name='out',
                            activate=False, bn=False)
        return out