上上篇博客介绍了使用tensorflow实现2D小波变换(DWT)和小波逆变换(IDWT),但是当时的实现方法在速度和资源占用上实在堪忧,特别是在channel比较大的情况下。因此本人对上次的代码进行了优化。
这里所说的两种操作,是指用单次的切分/拼接操作代替循环的tf.concat,以及用一次3D卷积代替循环的卷积。这两种操作之所以能够节省计算资源、提升速度,原因在于,tensorflow为了进行反向传播,会保存下来每一个tensor操作的结果。例如,for循环执行64次tf.concat,那么tensorflow就会保存64个concat的反向梯度图,分别为tf.concat_1…tf.concat_64(表述可能不严谨),保存的这些结果都会占用大量的计算资源,而它们对于计算本身并不是必要的。因此,要节省计算资源,就要使用尽量少的tensor操作来实现功能。tensorflow提供的tf.split命令(即下文代码中使用的操作,配合一次tf.concat)就可以完全替代原来循环的tf.concat结构,而反向传播中只占用原来循环一次的资源。循环的卷积也是同样的道理:虽然3D卷积本身也消耗资源,但相比之下仍然优于循环结构。
另外:此次的代码和上次还有一个小的区别,即调整了卷积核的尺寸,在实现DWT的同时进一步加速。原来默认的基为db3,卷积核的尺寸为6;调整后的默认基为haar,卷积核尺寸为2。读者可以根据自己的需要,通过wave参数指定其他的基。
# -*- coding: utf-8 -*-
# @Author : Cmy
# @time : 2018/12/5 20:37
# @File : tf_dwt_3d_v2.py
# @Software : PyCharm
import numpy as np
import tensorflow as tf
from PIL import Image
import pywt
import time
import matplotlib.pyplot as plt
# C denotes the number of channels. Only J=1 is supported
# (presumably J is the number of decomposition levels -- TODO confirm).
def tf_dwt(yl, wave='haar'):
    """Apply a one-level 2D wavelet decomposition using a single 3D convolution.

    Args:
        yl: 4-D tensor. Axes 1 and 2 are reflect-padded and convolved with
            stride 2; the last axis is processed one slice at a time and
            must have a statically known size (the file treats it as the
            channel axis).
        wave: wavelet name understood by ``pywt.Wavelet``.

    Returns:
        4-D tensor whose last axis is four times the input's last axis.
        For every input channel the four consecutive output channels hold
        the responses of the ll, lh, hl and hh kernels, in that order.
    """
    wavelet = pywt.Wavelet(wave)
    lo, hi = wavelet.dec_lo, wavelet.dec_hi

    # 2D analysis kernels as outer products, kept in the order ll, lh, hl, hh.
    subbands = (
        np.outer(lo, lo),
        np.outer(hi, lo),
        np.outer(lo, hi),
        np.outer(hi, hi),
    )
    k_rows, k_cols = np.shape(subbands[0])
    kernel = np.zeros((k_rows, k_cols, 1, 4))
    for idx, band in enumerate(subbands):
        # Each kernel is stored flipped along both spatial axes.
        kernel[::-1, ::-1, 0, idx] = band
    # Prepend a depth axis of size 1 so the kernel fits tf.nn.conv3d.
    kernel = kernel.astype('float32')[None, :, :, :, :]
    kernel_t = tf.convert_to_tensor(kernel)

    # Reflect padding applied to each side of axes 1 and 2 (0 for 'haar').
    pad = 2 * (len(lo) // 2 - 1)

    with tf.variable_scope('DWT'):
        # NOTE: odd spatial sizes get no extra padding; an earlier draft of
        # asymmetric padding for that case was left disabled in the original.
        padded = tf.pad(
            yl,
            tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]),
            mode='reflect',
        )

        # Move the channels onto the depth axis: (B, 1, H, W, C) -> (B, C, H, W, 1).
        expanded = tf.expand_dims(padded, 1)
        n_channels = int(expanded.shape.dims[4])
        per_channel = tf.split(expanded, [1] * n_channels, 4)
        stacked = tf.concat(per_channel, 1)

        # One convolution filters every channel with the same four kernels.
        outputs_3d = tf.nn.conv3d(stacked, kernel_t, padding='VALID',
                                  strides=[1, 1, 2, 2, 1])

        # Fold the depth axis back into channels: (B, C, H', W', 4) -> (B, 1, H', W', 4C).
        n_depth = int(outputs_3d.shape.dims[1])
        per_depth = tf.split(outputs_3d, [1] * n_depth, 1)
        merged = tf.concat(per_depth, 4)

        # Drop the singleton depth axis.
        dyn_shape = tf.shape(merged)
        outputs = tf.reshape(
            merged, (dyn_shape[0], dyn_shape[2], dyn_shape[3], dyn_shape[4]))
    return outputs
def tf_idwt(y, wave='haar'):
    """Invert a one-level 2D wavelet decomposition using one transposed 3D convolution.

    Height and width are sized and cropped independently, so non-square
    inputs are handled (the previous version reused the height for both
    spatial axes).

    Args:
        y: 4-D tensor whose last axis has a statically known size that is a
            multiple of 4. Each consecutive group of four channels is
            treated as the ll, lh, hl, hh coefficients of one output
            channel.
        wave: wavelet name understood by ``pywt.Wavelet``.

    Returns:
        4-D tensor with one channel per group of four input channels.

    Raises:
        ValueError: if the size of the last axis is not a multiple of 4.
    """
    wavelet = pywt.Wavelet(wave)
    rec_lo, rec_hi = wavelet.rec_lo, wavelet.rec_hi

    # 2D synthesis kernels as outer products, in the order ll, lh, hl, hh.
    ll = np.outer(rec_lo, rec_lo)
    lh = np.outer(rec_hi, rec_lo)
    hl = np.outer(rec_lo, rec_hi)
    hh = np.outer(rec_hi, rec_hi)
    k_rows, k_cols = np.shape(ll)
    kernel = np.zeros((k_rows, k_cols, 1, 4))
    for idx, band in enumerate((ll, lh, hl, hh)):
        kernel[:, :, 0, idx] = band
    # Prepend a depth axis of size 1 so the kernel fits conv3d_transpose.
    kernel_t = tf.convert_to_tensor(kernel.astype('float32')[None, :, :, :, :])

    # Border removed from each side after reconstruction (0 for 'haar').
    crop = 2 * (len(wavelet.dec_lo) // 2 - 1)

    # Size of a stride-2 'VALID' transposed convolution, computed per axis.
    in_h = tf.shape(y)[1]
    in_w = tf.shape(y)[2]
    full_h = 2 * (in_h - 1) + k_rows
    full_w = 2 * (in_w - 1) + k_cols

    with tf.variable_scope('IWT'):
        y = tf.expand_dims(y, 1)
        n_coeff = int(y.shape.dims[4])
        if n_coeff % 4 != 0:
            raise ValueError(
                'tf_idwt expects a channel count that is a multiple of 4, '
                'got %d' % n_coeff)
        groups = n_coeff // 4

        # Move each group of 4 sub-bands onto the depth axis:
        # (B, 1, H, W, 4C) -> (B, C, H, W, 4).
        inputs = tf.concat(tf.split(y, [4] * groups, 4), 1)

        outputs_3d = tf.nn.conv3d_transpose(
            inputs, kernel_t,
            output_shape=[tf.shape(y)[0], tf.shape(inputs)[1], full_h, full_w, 1],
            padding='VALID', strides=[1, 1, 2, 2, 1])

        # Fold the depth axis back into channels: (B, C, H', W', 1) -> (B, 1, H', W', C).
        outputs = tf.concat(tf.split(outputs_3d, [1] * groups, 1), 4)

        # Drop the singleton depth axis.
        dyn_shape = tf.shape(outputs)
        outputs = tf.reshape(
            outputs, (dyn_shape[0], dyn_shape[2], dyn_shape[3], dyn_shape[4]))

        # Trim the border on both spatial axes.
        outputs = outputs[:, crop:full_h - crop, crop:full_w - crop, :]
    return outputs
# Demo: decompose and reconstruct one image, time the round trip, and compare
# against pywt. The guard previously tested for '__dwt__', which never matches,
# so this script body never executed.
if __name__ == '__main__':
    # Load the image, scale to [0, 1] and keep the top-left 256x256 crop.
    a = Image.open('12074.jpg')
    X_n = np.array(a).astype('float32')
    X_n = X_n / 255
    X_n = X_n[0:256, 0:256, :]
    # Add a leading batch axis of size 1.
    X_t = np.zeros((1, 256, 256, 3), dtype='float32')
    X_t[0, :, :, :] = X_n[:, :, :]

    # Build the graph: forward transform on `inputs`, inverse on `outputs_in`.
    inputs = tf.placeholder(tf.float32, [None, None, None, 3], name='inputs')
    outputs_in = tf.placeholder(tf.float32, [None, None, None, 12], name='outputs')
    outputs = tf_dwt(inputs)
    recon_op = tf_idwt(outputs_in)

    # The context manager closes the session once both results are fetched.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        time_start = time.time()
        outputs_dwt = sess.run(outputs, feed_dict={inputs: X_t})
        # Kept under a separate name so the fetched array does not overwrite
        # the graph tensor it came from.
        outputs_mex = sess.run(recon_op, feed_dict={outputs_in: outputs_dwt})
        time_end = time.time()
    print('totally cost', time_end - time_start)

    # Show output channel 0 of the decomposition and of the reconstruction.
    plt.figure()
    plt.imshow(outputs_dwt[0, :, :, 0], cmap='gray')
    plt.figure()
    plt.imshow(outputs_mex[0, :, :, 0], cmap='gray')

    # Reference decomposition of image channel 0 computed by pywt.
    cA, (cH, cV, cD) = pywt.dwt2(X_n[:, :, 0], 'haar')

    # Absolute difference between pywt's cH and output channel 1.
    plt.figure()
    plt.imshow(np.abs(cH - outputs_dwt[0, :, :, 1]), cmap='gray')
    # Reconstruction error on image channel 1.
    plt.figure()
    plt.imshow(np.abs(X_n[:, :, 1] - outputs_mex[0, :, :, 1]), cmap='gray')
    plt.show()