tensorflow进行反卷积操作时出现错误 IndexError: list index out of range

出现错误的代码片段为:

#以步长为2的滤波器对incoming特征图进行反卷积
def upsample(incoming, filter_size, stride=2, name='upsample'):
    x = incoming
    input_shape = x.get_shape().as_list()
    strides = [1, stride, stride, 1]
    output_shape = (input_shape[0],
                    input_shape[1] * strides[1],
                    input_shape[2] * strides[2],
                    input_shape[3])                  #输出图变成原图的两倍
    filter_shape = (filter_size, filter_size, input_shape[3] , input_shape[3])  #滤波器尺寸
    with tf.name_scope(name) as scope:
        resized = tf.nn.conv2d_transpose(x,  filter_shape, output_shape, strides, padding="SAME")
    return resized

出现错误的原因:

代码其他部分都没有问题,但在进行 tf.nn.conv2d_transpose函数调用时总是出现错误 IndexError: list index out of range,错误解答说是有两种情况:
(1) list[index] index超出范围
(2)list是一个空的 没有一个元素,进行list[0]就会出现该错误

再查看conv2d_transpose卷积函数

conv2d_transpose(value, filter, output_shape, strides, padding=“SAME”, data_format=“NHWC”, name=None)

value:指需要做反卷积的输入图像,它要求是一个Tensor
filter:卷积核,Tensor,shape为[filter_height, filter_width, out_channels, in_channels],具体含义是[卷积核的高度,卷积核的宽度,卷积核个数,图像通道数]
output_shape:反卷积操作输出的shape
strides:反卷积时在图像每一维的步长
padding:string类型的量,只能是"SAME","VALID"其中之一
data_format:string类型的量,'NHWC’和’NCHW’其中之一,这是tensorflow新版本中新加的参数,它说明了value参数的数据格式。'NHWC’指tensorflow标准的数据格式[batch, height, width, in_channels],‘NCHW’指Theano的数据格式,[batch, in_channels,height, width],当然默认值是’NHWC’

在参考其他代码的反卷积操作后,发现tf.nn.conv2d_transpose函数出现错误的原因应该是第二种情况。只把filter的大小尺寸filter_shape传入函数中,但代码并没有给出该filter的卷积核的值,导致list为空,因此需先对filter初始化。

修改后的代码片段为:

#定义函数get_deconv_filter初始化反卷积滤波器
def get_deconv_filter(f_shape):
   width = f_shape[0]
   heigh = f_shape[0]
   f = ceil(width/2.0)
   c = (2 * f - 1 - f % 2) / (2.0 * f)
   bilinear = np.zeros([f_shape[0], f_shape[1]])
   for x in range(width):
       for y in range(heigh):
           value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
           bilinear[x, y] = value
   weights = np.zeros(f_shape)
   for i in range(f_shape[2]):
       weights[:, :, i, i] = bilinear

   init = tf.constant_initializer(value=weights,
                                  dtype=tf.float32)
   return tf.get_variable(name="up_filter", initializer=init,
                          shape=weights.shape)

再修改上述错误代码为:

def upsample(incoming, filter_size, stride=2, name='upsample'):#, filter_size, stride=2,
    x = incoming
    print(x)

    input_shape = x.get_shape().as_list()
    strides = [1, stride, stride, 1]

    output_shape = (input_shape[0],
                    input_shape[1] * strides[1],
                    input_shape[2] * strides[2],
                    input_shape[3])                  #输出图变成原图的两倍
    print(output_shape)
    filter_shape = (filter_size, filter_size, input_shape[3] , input_shape[3])  #滤波器尺寸
    weights = get_deconv_filter(filter_shape)
    print(weights)
    with tf.name_scope(name) as scope:
        resized = tf.nn.conv2d_transpose(x, weights, output_shape, strides, padding="SAME")#(x, filter_shape, output_shape, strides, padding="SAME")
 
    return resized

之后再运行就没有报错了

你可能感兴趣的:(error)