# -*- coding: utf-8 -*-
"""
@Project: triple_path_networks
@File : UNet.py
@Author : panjq
@E-mail : [email protected]
@Date : 2019-01-24 11:18:15
"""
import tensorflow as tf
import tensorflow.contrib.slim as slim
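# Note: this implementation targets TensorFlow 1.x; tf.contrib.slim is not available in TensorFlow 2.x.
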
def lrelu(x):
    # Leaky ReLU with a negative slope of 0.2.
    return tf.maximum(x * 0.2, x)


activation_fn = lrelu

def UNet(inputs, reg):
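    """Builds a U-Net: a 5-level encoder-decoder with skip connections.

    Each encoder level applies two 3x3 convolutions followed by 2x2 max pooling;
    the decoder mirrors this with learned 2x upsampling and concatenation of the
    corresponding encoder features.
    """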
    # Encoder: each level applies two 3x3 convolutions, then 2x2 max pooling.
    conv1 = slim.conv2d(inputs, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv1_1', weights_regularizer=reg)
    conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv1_2', weights_regularizer=reg)
    pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')

    conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv2_1', weights_regularizer=reg)
    conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv2_2', weights_regularizer=reg)
    pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')

    conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv3_1', weights_regularizer=reg)
    conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv3_2', weights_regularizer=reg)
    pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')

    conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv4_1', weights_regularizer=reg)
    conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv4_2', weights_regularizer=reg)
    pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')

    # Bottleneck.
    conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv5_1', weights_regularizer=reg)
    conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv5_2', weights_regularizer=reg)
    # Decoder: upsample, concatenate the matching encoder features (skip
    # connection), then apply two 3x3 convolutions at each level.
    up6 = upsample_and_concat(conv5, conv4, 256, 512)
    conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv6_1', weights_regularizer=reg)
    conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv6_2', weights_regularizer=reg)

    up7 = upsample_and_concat(conv6, conv3, 128, 256)
    conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv7_1', weights_regularizer=reg)
    conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv7_2', weights_regularizer=reg)

    up8 = upsample_and_concat(conv7, conv2, 64, 128)
    conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv8_1', weights_regularizer=reg)
    conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv8_2', weights_regularizer=reg)

    up9 = upsample_and_concat(conv8, conv1, 32, 64)
    conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv9_1', weights_regularizer=reg)
    conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv9_2', weights_regularizer=reg)
    print("conv9.shape:{}".format(conv9.get_shape()))
    # Output head; net_type is fixed here, 'UNet_1X' keeps the input resolution.
    net_type = 'UNet_1X'
    with tf.variable_scope(name_or_scope="output"):
        if net_type == 'UNet_3X':  # upscale the output 3x via depth_to_space
            conv10 = slim.conv2d(conv9, 27, [1, 1], rate=1, activation_fn=None, scope='g_conv10', weights_regularizer=reg)
            out = tf.depth_to_space(conv10, 3)
        elif net_type == 'UNet_1X':  # output has the same spatial size as the input
            out = slim.conv2d(conv9, 6, [1, 1], rate=1, activation_fn=None, scope='g_conv10', weights_regularizer=reg)
    return out

def upsample_and_concat(x1, x2, output_channels, in_channels):
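    """Upsamples x1 2x with a learned transposed convolution to match x2's
    spatial size, then concatenates it with x2 along the channel axis."""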
    pool_size = 2
    # Learnable 2x2 deconvolution filter: [height, width, output_channels, in_channels].
    deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    # tf.shape(x2) makes the upsampled map match the skip feature's dynamic shape.
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])
    deconv_output = tf.concat([deconv, x2], 3)
    deconv_output.set_shape([None, None, None, output_channels * 2])
    return deconv_output


if __name__ == "__main__":
    weight_decay = 0.001
    reg = slim.l2_regularizer(scale=weight_decay)
    inputs = tf.ones(shape=[4, 100, 200, 3])
    out = UNet(inputs, reg)
    print("inputs.shape:{}".format(inputs.get_shape()))
    print("out.shape:{}".format(out.get_shape()))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
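        # Optional sanity check: evaluate the output and confirm its spatial
        # size matches the input (expected [4, 100, 200, 6] for 'UNet_1X').
        out_val = sess.run(out)
        print("evaluated out.shape:{}".format(out_val.shape))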