tensorflow中二维卷积函数tf.nn.conv2d()定义:
def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", dilations=[1, 1, 1, 1], name=None)
第一个参数input:指需要做卷积的输入图像,它要求是一个Tensor,具有[batch, in_height, in_width, in_channels]这样的shape,具体含义是[训练时一个batch的图片数量, 图片高度, 图片宽度, 图像通道数],注意这是一个4维的Tensor,要求类型为float32和float64其中之一
第二个参数filter:相当于CNN中的卷积核,它要求是一个Tensor,具有[filter_height, filter_width, in_channels, out_channels]这样的shape,具体含义是[卷积核的高度,卷积核的宽度,图像通道数,卷积核个数],要求类型与参数input相同,filter的通道数要求与input的in_channels一致,有一个地方需要注意,第三维in_channels,就是参数input的第四维
第三个参数strides:卷积时在图像每一维的步长,这是一个一维的向量,长度4,strides[0]=strides[3]=1 ,例如strides = [1,2,2,1]
第四个参数padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式,关于两者的具体区别可以参考: https://blog.csdn.net/dcrmg/article/details/82317096
第五个参数:use_cudnn_on_gpu:bool类型,是否使用cudnn加速,默认为true
第六个参数data_format:表示输入的tensor的格式,默认是data_format="NHWC",4个字母的含义如下:
所以data_format="NHWC",则tensor的格式是[batch, in_height, in_width, in_channels];
相对的还有data_format="NCHW",则tensor的格式是[batch, in_channels, in_height, in_width]
第七个参数dilations:卷积扩张因子,默认值是[1, 1, 1, 1],如果设置k大于1,则卷积的时候会跳过k-1个元素卷积,相当于扩张了卷积面积?
第八个参数name:操作的名称
结果返回一个Tensor,这个输出,就是特征图feature map
tf.nn.conv2d()函数使用示例:
# coding: utf-8
import tensorflow as tf
# 定义输入图像size
img_size = 256
# 卷积核大小
kernel_size = 7
# 卷积步长
stride_size = 1
# 读入图像文件
image_value = tf.read_file('./661.jpg')
# 图像编码
img = tf.image.decode_jpeg(image_value, channels=3)
# 格式转换
img = tf.to_float(img, name='ToFloat')
# 调整图像大小到定义尺寸
img = tf.image.resize_images(img, [img_size,img_size],method=0)
# 第一个参数1是输入图片数量,最后一个3个RGB3个维度
batch_shape = (1,img_size,img_size,3)
# 维度转换,为卷积做准备(卷积的输入特征图的rank必须是4)
img = tf.reshape(img,batch_shape)
# 卷积核大小5×5,深度是3(跟RGB3个维度保持一致),特征图(卷积核)数量是7
filter = tf.Variable(tf.random_normal([kernel_size,kernel_size,3,7]))
# 步长1
strides_shape=[1,stride_size,stride_size,1]
# 定义卷积操作
op_conv2d = tf.nn.conv2d(img, filter, strides_shape, padding='SAME')
# 创建运行sess
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
out_img= sess.run(op_conv2d)
print('输入图像维度: {}'.format(img.shape))
print('输出图像维度: {}'.format(out_img.shape))
# 输入图像维度: (1, 256, 256, 3)
# 输出特征图维度: (1, 256, 256, 7)