https://github.com/zonghaofan/pig-seg/tree/master/disk_segmentation
网络架构:
# coding:utf-8
import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt
# Load the demo image, resize it to the network's fixed 1024x1024 input and add
# a leading batch axis so it can be fed to ``x_input``.
img = cv2.imread('./data/test.png')
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising; fail here with a clear message rather than deep inside cv2.resize.
    raise FileNotFoundError("could not read image './data/test.png'")
img = cv2.resize(img, (1024, 1024))  # dsize is (width, height)
img = img.astype(np.float32)[np.newaxis, ...]  # add batch axis -> (1, 1024, 1024, 3)
print(img.shape)

# Graph input. unet() rescales raw [0, 255] pixel values to [-1, 1] itself.
x_input = tf.placeholder(shape=[None, 1024, 1024, 3], dtype=tf.float32)

# NOTE(review): not referenced anywhere in this file (conv2d receives its own
# ``n_filters`` argument); kept only in case an external importer relies on it.
n_filters = [8, 8]
def conv2d(x, n_filters, training, name, pool=True, activation=tf.nn.relu):
    """Stack conv -> batch-norm -> activation layers, optionally followed by a 2x2 max-pool.

    All layers are created inside ``tf.variable_scope('layer{name}')``.

    Args:
        x: input feature map (presumably 4-D NHWC -- TODO confirm for new callers).
        n_filters: sequence of output-channel counts, one entry per conv layer.
        training: forwarded to ``tf.layers.batch_normalization``.
        name: suffix used for the variable scope and the op names.
        pool: if falsy, return only the last activation; otherwise also return
            its 2x2 / stride-2 max-pooled version.
        activation: callable applied after each batch-norm; it must accept a
            ``name`` keyword argument.

    Returns:
        ``conv`` when ``pool`` is falsy, otherwise the tuple ``(conv, pooled)``.

    Raises:
        ValueError: if ``n_filters`` is empty (there would be no layer to return).
    """
    if not n_filters:
        raise ValueError('n_filters must contain at least one filter count')
    with tf.variable_scope('layer{}'.format(name)):
        conv = x
        for index, n_out in enumerate(n_filters, start=1):
            # Chain the layers: each conv consumes the previous activation.
            # (Previously every conv read ``x``, so all but the last conv were
            # disconnected from the returned tensor.)
            conv = tf.layers.conv2d(conv, n_out, (3, 3), strides=1, padding='same',
                                    activation=None, name='conv_{}'.format(index))
            conv = tf.layers.batch_normalization(conv, training=training,
                                                 name='bn_{}'.format(index))
            conv = activation(conv, name='relu{}_{}'.format(name, index))
        if not pool:
            return conv
        pooled = tf.layers.max_pooling2d(conv, pool_size=(2, 2), strides=2,
                                         name='pool_{}'.format(name))
        return conv, pooled
def upsampling_2d(tensor, name, size=(2, 2)):
    """Nearest-neighbour upsample a tensor's height/width by integer factors.

    Args:
        tensor: 4-D tensor whose axes 1 and 2 are treated as height and width.
        name: suffix for the op name ``upsample_{name}``.
        size: ``(h_multi, w_multi)`` integer scale factors for height and width.

    Returns:
        The upsampled tensor, produced by ``tf.image.resize_nearest_neighbor``.
    """
    h_multi, w_multi = size
    height, width = tensor.get_shape().as_list()[1:3]
    if height is not None and width is not None:
        # Static size known: pass Python ints so the result keeps a fully
        # defined static shape (unet() prints these shapes).
        new_size = (h_multi * height, w_multi * width)
    else:
        # Height/width unknown at graph-build time (e.g. a placeholder with
        # None spatial dims): compute the target size at run time instead of
        # failing on ``int * None``.
        new_size = tf.shape(tensor)[1:3] * tf.constant([h_multi, w_multi], dtype=tf.int32)
    return tf.image.resize_nearest_neighbor(tensor, size=new_size,
                                            name='upsample_{}'.format(name))
def upsampling_concat(input_A, input_B, name):
    """Upsample ``input_A`` 2x and join it with ``input_B`` on the last (channel) axis.

    Args:
        input_A: decoder tensor to upsample via ``upsampling_2d``.
        input_B: skip-connection tensor; its spatial size must equal the
            upsampled ``input_A`` for the concat to succeed.
        name: suffix for the ``upsample_{name}`` and ``up_concat_{name}`` ops.

    Returns:
        The channel-wise concatenation ``[upsampled input_A, input_B]``.
    """
    concat_name = 'up_concat_{}'.format(name)
    doubled = upsampling_2d(input_A, name=name, size=(2, 2))
    return tf.concat(values=[doubled, input_B], axis=-1, name=concat_name)
def unet(input, training=True, verbose=True):
    """Build a 4-level U-Net that maps an image to a one-channel sigmoid map.

    Layers are created in the order: encoder blocks 1-4, bottleneck 5, then
    decoder blocks 6-9 (each decoder block upsamples, concatenates the matching
    encoder output, and convolves), followed by a 1x1 sigmoid conv ``final``.

    Args:
        input: image tensor with raw pixel values presumably in [0, 255]; it is
            rescaled to [-1, 1] here. Height and width should be divisible by
            16 so every 2x upsampling lines up with its skip connection.
        training: forwarded to every batch-normalization layer. Defaults to
            ``True`` (the previously hard-coded value); pass ``False`` for
            inference with BN moving statistics.
        verbose: print the shape of each intermediate tensor (the original
            behaviour, kept as the default).

    Returns:
        The sigmoid-activated, single-channel output tensor.
    """
    def _log(*args):
        if verbose:
            print(*args)

    # Normalize to [-1, 1].
    x = (input - 127.5) / 127.5

    # Encoder (layers 1-4): keep each pre-pool output for the skip connections.
    skips = []
    for layer, filters in enumerate((8, 16, 32, 64), start=1):
        conv, x = conv2d(x, [filters, filters], training=training, name=layer)
        _log(conv.shape)
        _log(x.shape)
        skips.append(conv)

    # Bottleneck (layer 5): no pooling.
    x = conv2d(x, [128, 128], training=training, pool=False, name=5)
    _log(x.shape)

    # Decoder (layers 6-9): pair each level with the encoder output of the
    # same resolution, deepest first.
    for layer, (filters, skip) in enumerate(zip((64, 32, 16, 8), reversed(skips)), start=6):
        up = upsampling_concat(x, skip, name=layer)
        _log('up{}'.format(layer), up.shape)
        x = conv2d(up, [filters, filters], training=training, pool=False, name=layer)
        _log(x.shape)

    final = tf.layers.conv2d(x, 1, (1, 1), name='final', activation=tf.nn.sigmoid, padding='same')
    _log('final', final.shape)
    return final
if __name__ == '__main__':
    # Build the graph on the module-level placeholder and run one forward pass
    # on the preprocessed demo image. Variables are freshly initialised here
    # (no checkpoint is restored), so the output reflects untrained weights.
    final = unet(x_input)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_final = sess.run(final, feed_dict={x_input: img})
    # Session is closed before blocking on the display window below.
    result = y_final[0, ...]  # drop the batch axis
    print(result.shape)
    # NOTE(review): the last axis has a single channel, so ``[..., :10]`` trims
    # nothing and numpy prints its summarised view of the whole map.
    print(result[..., :10])
    # The sigmoid output is float32 in [0, 1]; cv2.imshow multiplies float
    # images by 255 for display. To save it with cv2.imwrite, convert first,
    # e.g. (result * 255).astype(np.uint8), or the pixels become 0/1.
    cv2.imshow('1.jpg', result)
    cv2.waitKey(0)
    cv2.destroyAllWindows()  # release the HighGUI window once a key is pressed
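打印结果：网络最后一层是 sigmoid，输出值是 0~1 之间的小数。cv2.imshow 显示浮点图像时会自动把像素值乘以 255，所以直接 imshow 就能看到输出图；而如果直接用 cv2.imwrite 保存，浮点值会被取整成 8 位整数，查看保存的图片会发现像素值全是 0 或 1，轮廓虽然也能看出来，但不是很清晰。要正常保存，应先乘以 255 并转换为 uint8 再调用 imwrite。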
输入:
输出：（截图不完整）