def conv2d(input, # 张量输入
filter, # 卷积核参数
strides, # 步长参数
padding, # 卷积方式
use_cudnn_on_gpu=None, # 是否是gpu加速
data_format=None, # 数据格式,与步长参数配合,决定移动方式
name=None): # 名字,用于tensorboard图形显示时使用
认真看过后,其实没有那么头大,核心只有2点, 也是你设计一个卷积必须要有的两点
1、 卷积核
你想象一下,我们要定义卷积核,需要定义哪几个东西才能将他确定下来?卷积核长度,一定要有吧!卷积核宽度也是必不可少的吧!还有一个容易忽略的东西,通道数量。图片有单通道的灰度图,有3通道的RGB图片,所以这个也一定要有。
所以通常对于filter, 定义为一个列表或4-D tensor [filter_height, filter_width, in_channels, out_channels]]
2、步长
卷积核相对图片滑动,然后进行卷积,提取特征。细想,问题有两个,怎么相对? 滑动多少? 那么步长(strides)来帮你解决问题。
步长(strides)就是移动方式,图片(张量以下我们以图片,代替,方便理解,即二维数据)数据有通道数,有长宽,卷积核是先按宽度方向按指定步长移动,还是按高度方向?,还是按通道方向?这个得有个说法。我们帖一个数据结构图
移动方式一: 第一个元素是000,第二个元素是沿着w方向的,即001,这样下去002 003,再接着呢就是沿着H方向,即004 005 006 007…这样到09后,沿C方向,轮到了020,之后021 022 …一直到319,然后再沿N方向。————这种方式叫:NCHW (注意顺序,N是n,是数量,C是channel, H是height, W是weight)
移动方式二:第一个元素是000,第二个沿C方向,即020,040, 060…一直到300,之后沿W方向,001 021 041 061…301…到了303后,沿H方向,即004 024 .。。304.。最后到了319,变成N方向,320,340…————这种方式叫:NHWC
所以dataformat参数的取值有两种,NCHW ,NHWC,默认是NHWC。
那么,与此dataformat配合的strides如何取值呢?
当dataformat是默认的NHWC时,strides=[batch, height, width, channels]
当dataformat是NCHW时,strides=[batch, channels, height, width]
那么,是不是有个疑问,为什么两个顺序不一样?这不科学嘛!
其实你认真看,strides的顺序是dataformat数据顺序是一致的,你看NHWC -> [b(n), h, w, c]->[batch,height,wight, channel]。通常我们都没有更改默认设置,所以strides是[batch, height, width, channels],所以当你找到定义函数文档有这样的说明
Must have strides[0] = strides[3] = 1
. For the most common case of the same
horizontal and vertices strides, strides = [1, stride, stride, 1]
.
还有,不光strides如此,连input也是如此,请看input和data_format的说明
Args:
input: A Tensor
. Must be one of the following types: half
, float32
.
A 4-D tensor. The dimension order is interpreted according to the value
of data_format
, see below for details.
data_format: An optional string
from: "NHWC", "NCHW"
. Defaults to "NHWC"
.
Specify the data format of the input and output data. With the
default format “NHWC”, the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be “NCHW”, the data storage order of:
[batch, channels, height, width].
到此,相信大家已经十分清楚tf.nn.conved()参数的来龙去脉了,如有疑问,欢迎指正。