tensorflow 卷积、反卷积形式的去噪自编码器

tensorflow 卷积、反卷积形式的去噪自编码器

对于去噪自编码器,网上好多都是利用全连接神经网络进行构建,我自己写了一个卷积、反卷积形式的去噪自编码器,其中的参数调优如果有兴趣的话,可以自行修改查看结果。


数据集我使用最简单的mnist:



网络结构:

mnist输入(28*28=784向量) => 28*28*1矩阵 => 卷积层1 => 14*14*64 => 卷积层2 => 7*7*64 => 卷积层3 => 4*4*32 => 反卷积层1 => 7×7*32 => 反卷积层2 => 14*14*64 => 反卷积层3 => 28*28*64 => 卷积层X => 28×28*1


训练:

我用train集训练train_epochs轮,然后用test集对训练好的模型进行评测,同时保存加噪图像及对应的去噪图像。


Code:

  1. #! -*- coding: utf-8 -*-  
  2.   
  3. ## by Colie (lijixiang)  
  4.   
  5. import tensorflow as tf  
  6. from tensorflow.examples.tutorials.mnist import input_data  
  7. import numpy as np  
  8. from PIL import Image  
  9.   
  10. train_epochs = 35  ## int(1e5+1)  
  11.   
  12. INPUT_HEIGHT = 28  
  13. INPUT_WIDTH = 28  
  14.   
  15. batch_size = 256  
  16.   
  17. noise_factor = 0.5  ## (0~1)  
  18.   
  19. ## 原始输入是28×28*3  
  20. input_x = tf.placeholder(tf.float32, [None, INPUT_HEIGHT * INPUT_WIDTH], name='input_with_noise')  
  21. input_matrix = tf.reshape(input_x, shape=[-1, INPUT_HEIGHT, INPUT_WIDTH, 1])  
  22. input_raw = tf.placeholder(tf.float32, shape=[None, INPUT_HEIGHT * INPUT_WIDTH], name='input_without_noise')  
  23.   
  24. ## 1 conv layer  
  25. ## 输入28*28*3  
  26. ## 经过卷积、激活、池化,输出14*14*64  
  27. weight_1 = tf.Variable(tf.truncated_normal(shape=[33164], stddev=0.1, name = 'weight_1'))  
  28. bias_1 = tf.Variable(tf.constant(0.0, shape=[64], name='bias_1'))  
  29. conv1 = tf.nn.conv2d(input=input_matrix, filter=weight_1, strides=[1111], padding='SAME')  
  30. conv1 = tf.nn.bias_add(conv1, bias_1, name='conv_1')  
  31. acti1 = tf.nn.relu(conv1, name='acti_1')  
  32. pool1 = tf.nn.max_pool(value=acti1, ksize=[1221], strides=[1221], padding='SAME', name='max_pool_1')  
  33.   
  34. ## 2 conv layer  
  35. ## 输入14*14*64  
  36. ## 经过卷积、激活、池化,输出7×7×64  
  37. weight_2 = tf.Variable(tf.truncated_normal(shape=[336464], stddev=0.1, name='weight_2'))  
  38. bias_2 = tf.Variable(tf.constant(0.0, shape=[64], name='bias_2'))  
  39. conv2 = tf.nn.conv2d(input=pool1, filter=weight_2, strides=[1111], padding='SAME')  
  40. conv2 = tf.nn.bias_add(conv2, bias_2, name='conv_2')  
  41. acti2 = tf.nn.relu(conv2, name='acti_2')  
  42. pool2 = tf.nn.max_pool(value=acti2, ksize=[1221], strides=[1221], padding='SAME', name='max_pool_2')  
  43.   
  44. ## 3 conv layer  
  45. ## 输入7*7*64  
  46. ## 经过卷积、激活、池化,输出4×4×32  
  47. ## 原始输入是28*28*3=2352,转化为4*4*32=512,大量噪声会在网络中过滤掉  
  48. weight_3 = tf.Variable(tf.truncated_normal(shape=[336432], stddev=0.1, name='weight_3'))  
  49. bias_3 = tf.Variable(tf.constant(0.0, shape=[32]))  
  50. conv3 = tf.nn.conv2d(input=pool2, filter=weight_3, strides=[1111], padding='SAME')  
  51. conv3 = tf.nn.bias_add(conv3, bias_3)  
  52. acti3 = tf.nn.relu(conv3, name='acti_3')  
  53. pool3 = tf.nn.max_pool(value=acti3, ksize=[1221], strides=[1221], padding='SAME', name='max_pool_3')  
  54.   
  55. ## 1 deconv layer  
  56. ## 输入4*4*32  
  57. ## 经过反卷积,输出7*7*32  
  58. deconv_weight_1 = tf.Variable(tf.truncated_normal(shape=[333232], stddev=0.1), name='deconv_weight_1')  
  59. deconv1 = tf.nn.conv2d_transpose(value=pool3, filter=deconv_weight_1, output_shape=[batch_size, 7732], strides=[1221], padding='SAME', name='deconv_1')  
  60.   
  61. ## 2 deconv layer  
  62. ## 输入7*7*32  
  63. ## 经过反卷积,输出14*14*64  
  64. deconv_weight_2 = tf.Variable(tf.truncated_normal(shape=[336432], stddev=0.1), name='deconv_weight_2')  
  65. deconv2 = tf.nn.conv2d_transpose(value=deconv1, filter=deconv_weight_2, output_shape=[batch_size, 141464], strides=[1221], padding='SAME', name='deconv_2')  
  66.   
  67. ## 3 deconv layer  
  68. ## 输入14*14*64  
  69. ## 经过反卷积,输出28*28*64  
  70. deconv_weight_3 = tf.Variable(tf.truncated_normal(shape=[336464], stddev=0.1, name='deconv_weight_3'))  
  71. deconv3 = tf.nn.conv2d_transpose(value=deconv2, filter=deconv_weight_3, output_shape=[batch_size, 282864], strides=[1221], padding='SAME', name='deconv_3')  
  72.   
  73. ## conv layer  
  74. ## 输入28*28*64  
  75. ## 经过卷积,输出为28*28*1  
  76. weight_final = tf.Variable(tf.truncated_normal(shape=[33641], stddev=0.1, name = 'weight_final'))  
  77. bias_final = tf.Variable(tf.constant(0.0, shape=[1], name='bias_final'))  
  78. conv_final = tf.nn.conv2d(input=deconv3, filter=weight_final, strides=[1111], padding='SAME')  
  79. conv_final = tf.nn.bias_add(conv_final, bias_final, name='conv_final')  
  80.   
  81. ## output  
  82. ## 输入28*28*1  
  83. ## reshape为28*28  
  84. output = tf.reshape(conv_final, shape=[-1, INPUT_HEIGHT * INPUT_WIDTH])  
  85.   
  86. ## loss and optimizer  
  87. loss = tf.reduce_mean(tf.pow(tf.subtract(output, input_raw), 2.0))  
  88. optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)  
  89.   
  90.   
  91. with tf.Session() as sess:  
  92.   
  93.     mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  
  94.     n_samples = int(mnist.train.num_examples)  
  95.     print('train samples: %d' % n_samples)  
  96.     print('batch size: %d' % batch_size)  
  97.     total_batch = int(n_samples / batch_size)  
  98.     print('total batchs: %d' % total_batch)  
  99.     init = tf.global_variables_initializer()  
  100.     sess.run(init)  
  101.     for epoch in range(train_epochs):  
  102.         for batch_index in range(total_batch):  
  103.             batch_x, _ = mnist.train.next_batch(batch_size)  
  104.             noise_x = batch_x + noise_factor * np.random.randn(*batch_x.shape)  
  105.             noise_x = np.clip(noise_x, 0.1.)  
  106.             _, train_loss = sess.run([optimizer, loss], feed_dict={input_x: noise_x, input_raw: batch_x})  
  107.             print('epoch: %04d\tbatch: %04d\ttrain loss: %.9f' % (epoch + 1, batch_index + 1, train_loss))  
  108.   
  109.     ## 训练结束后,用测试集测试,并保存加噪图像、去噪图像  
  110.     n_test_samples = int(mnist.test.num_examples)  
  111.     test_total_batch = int(n_test_samples / batch_size)  
  112.     for i in range(test_total_batch):  
  113.         batch_test_x, _ = mnist.test.next_batch(batch_size)  
  114.         noise_test_x = batch_test_x + noise_factor * np.random.randn(*batch_test_x.shape)  
  115.         noise_test_x = np.clip(noise_test_x, 0.1.)  
  116.         test_loss, pred_result = sess.run([loss, conv_final], feed_dict={input_x: noise_test_x, input_raw: batch_test_x})  
  117.         print('test batch index: %d\ttest loss: %.9f' % (i + 1, test_loss))  
  118.         for index in range(batch_size):  
  119.             array = np.reshape(pred_result[index], newshape=[INPUT_HEIGHT, INPUT_WIDTH])  
  120.             array = array * 255  
  121.             image = Image.fromarray(array)  
  122.             if image.mode != 'L':  
  123.                 image = image.convert('L')  
  124.             image.save('./pred/' + str(i * batch_size + index) + '.png')  
  125.             array_raw = np.reshape(noise_test_x[index], newshape=[INPUT_HEIGHT, INPUT_WIDTH])  
  126.             array_raw = array_raw * 255  
  127.             image_raw = Image.fromarray(array_raw)  
  128.             if image_raw.mode != 'L':  
  129.                 image_raw = image_raw.convert('L')  
  130.             image_raw.save('./pred/' + str(i * batch_size + index) + '_raw.png')  
  131.         #break  

去噪效果:

你可能感兴趣的:(tensorflow)