问题背景:
假如我现在有一个矩阵为image,卷积核为weight,卷积时不填充,则卷积后的结果为conv
image =
[[[[ 1. 2. 3. 4. 5.]
[ 6. 7. 8. 9. 10.]
[11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20.]
[21. 22. 23. 24. 25.]]
[[26. 27. 28. 29. 30.]
[31. 32. 33. 34. 35.]
[36. 37. 38. 39. 40.]
[41. 42. 43. 44. 45.]
[46. 47. 48. 49. 50.]]]]
weight =
[[[[1 0 0]
[0 1 0]
[1 0 0]]
[[0 1 0]
[0 0 1]
[0 0 0]]]]
conv=
[[[[ 79. 84. 89.]
[104. 109. 114.]
[129. 134. 139.]]]]
那么现在问题来了,初始化好imgae,weight这两个变量后,这么用tf.nn.conv2d
来计算得到这个结果呢?为了说清楚这个问题,先来做两个铺垫。
我们知道在tensorflow
中,表示一张图片需要用4个维度,即[batch_size,width,high,channel]
。但是呢,如果这样表示的话输出来的结果对于我们人来说却一点都不直观,我们肉眼最直观的表示方式是[batch_size,channel,width,high]
(例如cifar数据集的表示方式就是后者),举个例子:
假如我现在有一张5乘5的2通道图片,即上面的image,如下:
image_in_man = np.linspace(1, 50, 50).reshape(1, 2, 5, 5)
[[[[ 1. 2. 3. 4. 5.]
[ 6. 7. 8. 9. 10.]
[11. 12. 13. 14. 15.]
[16. 17. 18. 19. 20.]
[21. 22. 23. 24. 25.]]
[[26. 27. 28. 29. 30.]
[31. 32. 33. 34. 35.]
[36. 37. 38. 39. 40.]
[41. 42. 43. 44. 45.]
[46. 47. 48. 49. 50.]]]]
它的形状是[batch_size,channel,width,high]
,即[1,2,5,5]
我称为人类视角;可我们在使用tf.nn.con2d
时必须将其转换为[batch_size,width,high,channel]
,即[1,5,5,2]
即tf视角,但这样的结果对于我们来说却不直观:
很明显两者的形状都是[1,5,5,2]
,可是结果却大相径庭。当然,转换这种视角用的是transpose
.
做了以上的铺垫,那么我们接下来看看最初的问题怎来实现:
import tensorflow as tf
import numpy as np
image_in_man = np.linspace(1, 50, 50).reshape(1, 2, 5, 5)
image_in_tf = image_in_man.transpose(0, 2, 3, 1)
#
weight_in_man = np.array([1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]).reshape(1, 2, 3, 3)
weight_in_tf = weight_in_man.transpose(2, 3, 1, 0)
print('image in man:')
print(image_in_man)
# print(image_in_tf)
print('weight in man:')
print(weight_in_man)
# #
x = tf.placeholder(dtype=tf.float32, shape=[1, 5, 5, 2], name='x')
w = tf.placeholder(dtype=tf.float32, shape=[3, 3, 2, 1], name='w')
conv = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='VALID')
with tf.Session() as sess:
r_in_tf = sess.run(conv, feed_dict={x: image_in_tf, w: weight_in_tf})
r_in_man = r_in_tf.transpose(0, 3, 1, 2)
print(r_in_man)
结果就是最上面所给出的问题。不过在实际问题中我们不需要怎么去关注权重,因为权重都是随机初始化的;但是当我们输入特定类图片数据的时候,一定要注意是用transpose
将人类视角转换为tf视角。不然虽然reshape
后也满足tf.nn.conv2d
的要求,可以进行卷积操作,但是这样计算出来的结果却是错的。