原文链接: TensorFlow 池化操作
上一篇: TensorFlow 卷积操作模拟sobel算子提取图像轮廓
下一篇: cifar 数据集下载使用
函数原型
def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
value:一般池化层接在卷积层后面,输入通常是feature map, 【batch, height, width,channels】
ksize: 池化窗口大小,四维向量,一般【1,height,width,1】,因为一般情况下不对batch和channel上做池化
strides:窗口在每一个维度上滑动的步长,一般【1,stride,stride,1】
padding:和卷积参数含义和计算类似,取值VALID和SAME
范湖一个Tensor,类型不变,shape依然是【batch,height,width,channels】形式
均值池化时补0的情况,计算和卷积一样
pooling2 = tf.nn.avg_pool(img, [1, 4, 4, 1], [1, 1, 1, 1], padding='SAME')
reslut2:
[[[[10. 11.]
[11. 12.]
[12. 13.]
[13. 14.]]
[[14. 15.]
[15. 16.]
[16. 17.]
[17. 18.]]
[[18. 19.]
[19. 20.]
[20. 21.]
[21. 22.]]
[[22. 23.]
[23. 24.]
[24. 25.]
[25. 26.]]]]
最大池化和均值池化的比较,均值不会计算padding补的0
import tensorflow as tf
import numpy as np
img = np.arange(0, 32, dtype=np.float32)
img = tf.reshape(img, [1, 4, 4, 2])
pooling = tf.nn.max_pool(img, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID')
pooling1 = tf.nn.max_pool(img, [1, 2, 2, 1], [1, 1, 1, 1], padding='VALID')
pooling2 = tf.nn.avg_pool(img, [1, 4, 4, 1], [1, 1, 1, 1], padding='SAME')
pooling3 = tf.nn.avg_pool(img, [1, 4, 4, 1], [1, 4, 4, 1], padding='SAME')
nt_hpool2_flat = tf.reshape(tf.transpose(img), [-1, 16])
pooling4 = tf.reduce_mean(nt_hpool2_flat, 1) # 1对行求均值(1表示轴是列) 0 对列求均值
with tf.Session() as sess:
print("image:")
image = sess.run(img)
print(image)
result = sess.run(pooling)
print("reslut:\n", result)
result = sess.run(pooling1)
print("reslut1:\n", result)
result = sess.run(pooling2)
print("reslut2:\n", result)
result = sess.run(pooling3)
print("reslut3:\n", result)
flat, result = sess.run([nt_hpool2_flat, pooling4])
print("reslut4:\n", result)
print("flat:\n", flat)
image:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]]
[[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]]
[[16. 17.]
[18. 19.]
[20. 21.]
[22. 23.]]
[[24. 25.]
[26. 27.]
[28. 29.]
[30. 31.]]]]
reslut:
[[[[10. 11.]
[14. 15.]]
[[26. 27.]
[30. 31.]]]]
reslut1:
[[[[10. 11.]
[12. 13.]
[14. 15.]]
[[18. 19.]
[20. 21.]
[22. 23.]]
[[26. 27.]
[28. 29.]
[30. 31.]]]]
reslut2:
[[[[10. 11.]
[11. 12.]
[12. 13.]
[13. 14.]]
[[14. 15.]
[15. 16.]
[16. 17.]
[17. 18.]]
[[18. 19.]
[19. 20.]
[20. 21.]
[21. 22.]]
[[22. 23.]
[23. 24.]
[24. 25.]
[25. 26.]]]]
reslut3:
[[[[15. 16.]]]]
reslut4:
[15. 16.]
flat:
[[ 0. 8. 16. 24. 2. 10. 18. 26. 4. 12. 20. 28. 6. 14. 22. 30.]
[ 1. 9. 17. 25. 3. 11. 19. 27. 5. 13. 21. 29. 7. 15. 23. 31.]]