First, let's look at the API documentation:
tf.nn.max_pool(value, ksize, strides, padding, data_format='NHWC', name=None)
Performs the max pooling on the input.
value: A 4-D Tensor with shape [batch, height, width, channels] and type tf.float32.
ksize: A list of ints that has length >= 4. The size of the window for each dimension of the input tensor.
strides: A list of ints that has length >= 4. The stride of the sliding window for each dimension of the input tensor.
padding: A string, either 'VALID' or 'SAME'. The padding algorithm.
data_format: A string. 'NHWC' and 'NCHW' are supported.
name: Optional name for the operation.
Returns: A Tensor with type tf.float32. The max pooled output tensor.
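To make "max pooling" concrete, here is a minimal pure-Python sketch of the operation for a single 2-D channel with 'VALID' padding. It is an illustration only, not TensorFlow's implementation, and the function name max_pool_2d is hypothetical. Each output cell is the maximum of one k x k window, and the window moves by stride cells at a time.

```python
def max_pool_2d(grid, k, stride):
    """Max pooling over one 2-D channel (list of rows) with 'VALID' padding."""
    h, w = len(grid), len(grid[0])
    # 'VALID': only windows that fit entirely inside the input are used
    out_h = (h - k) // stride + 1
    out_w = (w - k) // stride + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # k x k window whose top-left corner is (i * stride, j * stride)
            window = [grid[i * stride + di][j * stride + dj]
                      for di in range(k) for dj in range(k)]
            row.append(max(window))
        out.append(row)
    return out


x = [[1, 3, 2, 4],
     [5, 6, 7, 8],
     [3, 2, 1, 0],
     [1, 2, 3, 4]]
print(max_pool_2d(x, k=2, stride=2))  # [[6, 8], [3, 4]]
```

With a 2 x 2 window and stride 2 the 4 x 4 input shrinks to 2 x 2, and each output value is the largest number in its quarter of the input.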
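The first parameter, value, is the input to be pooled. A pooling layer usually follows a convolutional layer, so the input is typically a feature map, again with shape [batch, height, width, channels].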
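The second parameter, ksize, is the size of the pooling window, given as a 4-element list, usually [1, height, width, 1]. We do not want to pool across the batch or channels dimensions, so those two entries are set to 1.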
The third parameter, strides, works as in convolution: it is the step the window moves along each dimension, usually [1, stride, stride, 1].
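The fourth parameter, padding, also works as in convolution and is either 'VALID' or 'SAME'. 'VALID' uses only windows that lie entirely inside the input, while 'SAME' pads the input so that the output height and width are ceil(input size / stride).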
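The return value is a Tensor of the same type, still laid out as [batch, height, width, channels] (with the default data_format='NHWC'). The batch and channels sizes are unchanged; height and width are reduced according to ksize, strides and padding.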
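The effect of padding on the output size can be computed directly. The sketch below follows TensorFlow's documented padding rules: for 'VALID' the output length is ceil((in - k + 1) / stride), for 'SAME' it is ceil(in / stride), and 'SAME' splits the needed padding so that any odd extra cell goes at the end. The helper names pool_output_size and same_padding are hypothetical, and integer arithmetic is used in place of ceil.

```python
def pool_output_size(in_size, k, stride, padding):
    """Output length along one spatial dimension (height or width)."""
    if padding == 'VALID':
        # Equals ceil((in_size - k + 1) / stride) when in_size >= k
        return (in_size - k) // stride + 1
    if padding == 'SAME':
        # Equals ceil(in_size / stride); the window size does not matter here
        return -(-in_size // stride)
    raise ValueError("padding must be 'VALID' or 'SAME'")


def same_padding(in_size, k, stride):
    """(pad_before, pad_after) added by 'SAME'; an odd extra cell goes after."""
    out = pool_output_size(in_size, k, stride, 'SAME')
    total = max((out - 1) * stride + k - in_size, 0)
    return total // 2, total - total // 2


print(pool_output_size(936, 2, 2, 'SAME'), pool_output_size(764, 2, 2, 'SAME'))  # 468 382
print(pool_output_size(5, 2, 2, 'VALID'), pool_output_size(5, 2, 2, 'SAME'))     # 2 3
print(same_padding(5, 2, 2))  # (0, 1)
```

For instance, a 936 x 764 input pooled with a 2 x 2 window, stride 2 and 'SAME' needs no padding and gives 468 x 382, while an odd length such as 5 gets one padded cell at the end so the last window still has something to cover.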
Example:
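```python
#python 3.5.3
#2017-03-09 蔡军生 http://blog.csdn.net/caimouse
# Uses the TensorFlow 1.x graph API (tf.read_file, tf.Session)
import tensorflow as tf
import numpy as np
from PIL import Image

fpath = './test2.jpg'
jpg = tf.read_file(fpath)
# Decode to a uint8 tensor of shape [height, width, 3]
img_arr = tf.image.decode_jpeg(jpg, channels=3)
# Add a batch dimension -> [1, height, width, 3]; max_pool needs float32
img_4d = tf.cast(tf.expand_dims(img_arr, 0), tf.float32)
# 2x2 window, stride 2 along height and width, 'SAME' padding
pool = tf.nn.max_pool(img_4d, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')

with tf.Session() as sess:
    img, pooled = sess.run([img_arr, pool])
    print(img.shape)     # (height, width, 3)
    print(pooled.shape)  # (1, ceil(height / 2), ceil(width / 2), 3)
    # Drop the batch dimension and save the pooled result as an image
    Image.fromarray(np.uint8(pooled[0])).save('./maxpool2.jpg')
```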
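The call in the example, with ksize [1, 2, 2, 1], strides [1, 2, 2, 1] and 'SAME', can be mimicked on a tiny tensor without TensorFlow. This is a hypothetical pure-Python sketch on nested lists in NHWC layout, not TensorFlow's kernel. It skips padded cells, which matches max pooling treating padding as -inf, and it pools each channel separately, which is why the channels dimension survives unchanged.

```python
def max_pool_nhwc_same(x, k, stride):
    """Max pooling on nested lists shaped [batch, height, width, channels]
    with a k x k window, the same stride in height and width, and 'SAME'
    padding. Padded cells are skipped (equivalent to padding with -inf)."""
    n, h, w, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    out_h = -(-h // stride)  # ceil(h / stride)
    out_w = -(-w // stride)  # ceil(w / stride)
    # Padding before the data; any odd extra cell goes after it
    pad_top = max((out_h - 1) * stride + k - h, 0) // 2
    pad_left = max((out_w - 1) * stride + k - w, 0) // 2
    out = []
    for b in range(n):
        rows = []
        for i in range(out_h):
            cols = []
            for j in range(out_w):
                top, left = i * stride - pad_top, j * stride - pad_left
                pixel = []
                for ch in range(c):  # each channel is pooled independently
                    vals = [x[b][r][q][ch]
                            for r in range(max(top, 0), min(top + k, h))
                            for q in range(max(left, 0), min(left + k, w))]
                    pixel.append(max(vals))
                cols.append(pixel)
            rows.append(cols)
        out.append(rows)
    return out


# A 1 x 3 x 3 x 3 "image": channel 0 counts up, channel 1 counts down, channel 2 is constant
x = [[[[r * 3 + q, -(r * 3 + q), 100] for q in range(3)] for r in range(3)]]
y = max_pool_nhwc_same(x, k=2, stride=2)
print(len(y), len(y[0]), len(y[0][0]), len(y[0][0][0]))  # 1 2 2 3
print(y[0][1][1])  # [8, -8, 100]
```

The 3 x 3 spatial size becomes 2 x 2 (ceil(3 / 2) = 2) while batch and channels stay at 1 and 3, mirroring how the real image keeps its 3 color channels while its height and width are roughly halved.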
Input image and output:
Image after pooling: