Understanding tf.nn.conv2d and tf.nn.conv2d_transpose

TensorFlow conv2d API

conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)

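input: the convolution input, a Tensor (e.g. tf.constant, tf.Variable, tf.placeholder), shape [batch, in_height, in_width, in_channel]
filter: the convolution kernel, a Tensor, shape [filter_height, filter_width, in_channel, out_channel]
strides: the step of the kernel along each dimension of input, a list of 4 ints, [stride_batch, stride_height, stride_width, stride_channel]; stride_batch and stride_channel must be 1
padding: how input is padded, either 'SAME' or 'VALID'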

TensorFlow Convolution API

When padding is 'SAME':
1. Output Tensor size: [batch, ceil(in_height / strides[1]), ceil(in_width / strides[2]), out_channel]
2. Padding on the top, bottom, left and right:

if (in_height % strides[1] == 0):
  pad_along_height = max(filter_height - strides[1], 0)
else:
  pad_along_height = max(filter_height - (in_height % strides[1]), 0)
if (in_width % strides[2] == 0):
  pad_along_width = max(filter_width - strides[2], 0)
else:
  pad_along_width = max(filter_width - (in_width % strides[2]), 0)

pad_top = pad_along_height // 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width // 2
pad_right = pad_along_width - pad_left
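
The 'SAME' rules can be checked with a small pure-Python helper. This is only a sketch for one spatial dimension: the function name same_padding is made up for illustration and is not part of TensorFlow. It assumes the output size is ceil(in_size / stride) and the total padding follows the pseudocode above.

```python
import math

def same_padding(in_size, filter_size, stride):
    """Return (out_size, pad_before, pad_after) for one spatial dimension.

    Hypothetical helper, not a TensorFlow API.
    """
    out_size = math.ceil(in_size / stride)
    if in_size % stride == 0:
        pad_total = max(filter_size - stride, 0)
    else:
        pad_total = max(filter_size - (in_size % stride), 0)
    pad_before = pad_total // 2          # pad_top or pad_left
    pad_after = pad_total - pad_before   # pad_bottom or pad_right
    return out_size, pad_before, pad_after

print(same_padding(5, 3, 2))  # (3, 1, 1)
print(same_padding(4, 3, 2))  # (2, 0, 1)
```

When the total padding is odd, the extra row or column goes to the bottom or right, as the second call shows.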

When padding is 'VALID':
1. Output Tensor size: [batch, ceil((in_height - filter_height + 1) / strides[1]), ceil((in_width - filter_width + 1) / strides[2]), out_channel]
2. No padding is added
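
For 'VALID' padding, the output size along one dimension is ceil((in_size - filter_size + 1) / stride), which equals floor((in_size - filter_size) / stride) + 1 whenever in_size >= filter_size. A minimal sketch, with valid_output_size as a hypothetical helper name:

```python
import math

def valid_output_size(in_size, filter_size, stride):
    """Output size of one spatial dimension under 'VALID' padding (illustrative helper)."""
    return math.ceil((in_size - filter_size + 1) / stride)

for in_size in (5, 6, 7):
    size = valid_output_size(in_size, 3, 2)
    # The two common ways of writing the formula agree.
    assert size == (in_size - 3) // 2 + 1
    print(in_size, size)
```

With a 3-wide kernel and stride 2, inputs of size 5 and 6 both give an output of size 2: the last column of the size-6 input is simply never covered by the kernel.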

TensorFlow conv2d_transpose API

conv2d_transpose(input, filter, output_shape, strides, padding='SAME', data_format='NHWC', name=None)

input: the input to the transposed convolution, a Tensor, shape [batch, in_height, in_width, in_channel]
filter: the convolution kernel, a Tensor, shape [filter_height, filter_width, out_channel, in_channel]
output_shape: the shape of the output, a 1-D Tensor or a list, for example [batch, out_height, out_width, out_channel]
strides: the strides, a list of 4 ints, [stride_batch, stride_height, stride_width, stride_channel]
padding: the padding scheme, either 'SAME' or 'VALID'

TensorFlow Convolution API

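1. conv2d_transpose computes a shape from output_shape, strides and padding, then compares it with the shape of input; if the two differ, it raises an error.
2. In a transposed convolution, input is usually smaller than output_shape, so TensorFlow conceptually first expands input to the size of output_shape and then pads it according to the padding argument.
When stride == 1, zeros are padded around the border;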

[Figure 1: transposed convolution with stride == 1]

When stride > 1, zeros are inserted into the gaps between input elements.

[Figure 2: transposed convolution with stride > 1]
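
The border filling and gap filling described above can be sketched in one dimension with plain Python lists. This is an illustration only, assuming a single channel and 'VALID'-style output of length (n - 1) * stride + k; it is not how TensorFlow implements the op, and both function names are hypothetical. The first function uses the scatter definition of a transposed convolution, the second gets the same result by inserting zeros and then running an ordinary stride-1 convolution.

```python
def transpose_conv1d_direct(x, w, stride):
    """Scatter definition: each input element adds a scaled copy of w to the output."""
    out = [0] * ((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            out[i * stride + j] += xi * wj
    return out

def transpose_conv1d_zero_insert(x, w, stride):
    """Same result via zero insertion followed by a stride-1 convolution."""
    k = len(w)
    # 1) Gap filling: place (stride - 1) zeros between neighbouring elements.
    dilated = [0] * ((len(x) - 1) * stride + 1)
    dilated[::stride] = x
    # 2) Border filling: pad (k - 1) zeros on each side.
    padded = [0] * (k - 1) + dilated + [0] * (k - 1)
    # 3) Slide the flipped kernel over the padded signal with stride 1.
    flipped = w[::-1]
    return [sum(padded[t + j] * flipped[j] for j in range(k))
            for t in range(len(padded) - k + 1)]

x, w = [1, 2, 3], [1, 10, 100]
print(transpose_conv1d_direct(x, w, 2))       # [1, 10, 102, 20, 203, 30, 300]
print(transpose_conv1d_zero_insert(x, w, 2))  # [1, 10, 102, 20, 203, 30, 300]
```

With stride 1 step 1 inserts nothing, so only the border padding remains, which matches the stride == 1 case.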

A few questions:
1. Why does a transposed convolution need output_shape to be specified?
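
One way to approach this question is to note that the forward convolution maps several input sizes to the same output size once the stride exceeds 1, so the size to "go back" to is ambiguous. The sketch below enumerates, for one spatial dimension, every size whose forward convolution would produce the given input size of the transposed convolution. The helper names and the search_limit parameter are assumptions made for illustration, not TensorFlow APIs.

```python
import math

def conv_output_size(in_size, filter_size, stride, padding):
    """Forward conv2d output size along one spatial dimension."""
    if padding == 'SAME':
        return math.ceil(in_size / stride)
    if padding == 'VALID':
        return math.ceil((in_size - filter_size + 1) / stride)
    raise ValueError("padding must be 'SAME' or 'VALID'")

def compatible_output_sizes(in_size, filter_size, stride, padding, search_limit=64):
    """All sizes h below search_limit whose forward convolution yields in_size."""
    return [h for h in range(1, search_limit)
            if conv_output_size(h, filter_size, stride, padding) == in_size]

print(compatible_output_sizes(4, 3, 2, 'SAME'))   # [7, 8]
print(compatible_output_sizes(4, 3, 2, 'VALID'))  # [9, 10]
print(compatible_output_sizes(4, 3, 1, 'SAME'))   # [4]
```

With stride 2, an input of height 4 is consistent with two different output heights, so the caller has to pick one through output_shape. The same comparison also illustrates the shape check mentioned earlier: an output_shape outside these lists does not match input and is rejected.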
