ValueError: Cannot feed value of shape (256, 9, 129) for Tensor 'model/inputs:0', which has shape

Problem description

  • The Speech-enhancement project [https://github.com/jtkim-kaist/Speech-enhancement]
  • Environment
    - tensorflow 1.7
    - librosa
    - matlab2017a
    - tensorboardX

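When I ran main.py, the error below occurred. The model in use was fnn. When I later examined the fnn model in trnmodel.py, I found the cause of the error. However, I had not modified any parameters or code in datareader.py. The same error occurs when I use the lstm model.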

ValueError: Cannot feed value of shape (256, 9, 129, 1) for Tensor 'model/inputs:0', which has shape '(?, 129, 1, 1)'

The config.py used with the fnn model is shown below:

	import numpy as np
	import os

	mode = "fnn"  # fnn, fcn, lstm, sfnn, irm, tsn

	if mode == 'lstm':
		time_width = int(16)
	else:
		time_width = int(9)

	fs = float(8000)
	win_size = int(0.025 * fs)  # The number of samples in window
	win_step = int(0.010 * fs)
	# nfft = int(2 ** (np.floor(np.log2(win_size) + 1)))
	nfft = int(256)  # np.int was removed in NumPy 1.24; plain int works everywhere

	freq_size = int(nfft/2+1)

	lr = 0.0001
	lrDecayRate = .99  # 0.99
	lrDecayFreq = 2000

	keep_prob = 0.9
	global_std = 1.18

	device = '/gpu:0'

	# logs_dir = os.path.abspath('../logs')

	dist_num = int(4)

	max_epoch = int(1e6)

	batch_size = int(256)

	test_batch_size = 128

	val_step = int(500)
	summary_step = int(1000)  # 3000
	summary_fnum = int(5)

	parallel = False
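
As a quick check on the values config.py derives, the arithmetic can be reproduced in plain Python with NumPy. This is a sketch that copies the constants by hand rather than importing the project's config.py; time_width = 9 is the non-lstm value.

```python
import numpy as np

# Constants copied by hand from config.py (fnn mode)
fs = float(8000)
win_size = int(0.025 * fs)  # 25 ms analysis window, in samples
win_step = int(0.010 * fs)  # 10 ms hop, in samples

# The commented-out formula: next power of two above win_size
nfft_from_formula = int(2 ** (np.floor(np.log2(win_size) + 1)))
nfft = 256                     # the value hard-coded in config.py
freq_size = int(nfft / 2 + 1)  # number of one-sided FFT bins

# Width of the vector the fnn branch flattens each example into
time_width = 9                 # 16 when mode == 'lstm'
flat_size = time_width * freq_size

print(win_size, win_step, nfft_from_formula, freq_size, flat_size)
```

At fs = 8000 the commented-out formula also yields 256, so the hard-coded nfft agrees with it, and each fnn example is flattened to 9 * 129 = 1161 values.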

The relevant part of trnmodel.py at this point is shown below:

	def inference(self, inputs):
	    if config.mode == "fcn":
	        fm = utils.conv_with_bn(inputs, out_channels=12, filter_size=[config.time_width, 13],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_1")

	        fm = utils.conv_with_bn(fm, out_channels=16, filter_size=[config.time_width, 11],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_2")

	        fm = utils.conv_with_bn(fm, out_channels=20, filter_size=[config.time_width, 9],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_3")

	        fm_skip = utils.conv_with_bn(fm, out_channels=24, filter_size=[config.time_width, 7],
	                                     stride=1, act='relu', is_training=self._is_training,
	                                     padding="SAME", name="conv_4")

	        fm = utils.conv_with_bn(fm_skip, out_channels=32, filter_size=[config.time_width, 7],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_5")

	        fm = utils.conv_with_bn(fm, out_channels=24, filter_size=[config.time_width, 7],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_6") + fm_skip

	        fm = utils.conv_with_bn(fm, out_channels=20, filter_size=[config.time_width, 9],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_7")

	        fm = utils.conv_with_bn(fm, out_channels=16, filter_size=[config.time_width, 11],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_8")

	        fm = utils.conv_with_bn(fm, out_channels=12, filter_size=[config.time_width, 13],
	                                stride=1, act='relu', is_training=self._is_training,
	                                padding="SAME", name="conv_9")

	        fm = utils.conv_with_bn(fm, out_channels=1, filter_size=[config.time_width, config.freq_size],
	                                stride=1, act='linear', is_training=self._is_training,
	                                padding="SAME", name="conv_10")  # (batch_size, 1, config.freq_size, 1)

	        # fm = utils.conv_with_bn(fm, out_channels=1, filter_size=[config.time_width, 1],
	        #                         stride=1, act='linear', is_training=self._is_training,
	        #                         padding="VALID", name="conv_last")
	        fm = tf.squeeze(fm, [1, 3])

	        return fm

	    elif config.mode == "fnn":

	        keep_prob = self.keep_prob

	        # inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width*config.freq_size)))
	        # inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)
	        #
	        # h1 = tf.nn.relu(utils.batch_norm_affine_transform(inputs, 2048, name='hidden_1',
	        #                                                         is_training=self._is_training))
	        # # h1 = tf.nn.dropout(h1, keep_prob=keep_prob)
	        #
	        # h2 = tf.nn.relu(utils.batch_norm_affine_transform(h1, 2048, name='hidden_2',
	        #                                                         is_training=self._is_training))
	        # # h2 = tf.nn.dropout(h2, keep_prob=keep_prob)
	        #
	        # # h3 = tf.nn.relu(utils.batch_norm_affine_transform(h2, 2048, name='hidden_3',
	        # #                                                         is_training=self._is_training))
	        # # h3 = tf.nn.dropout(h3, keep_prob=keep_prob)
	        #
	        # fm = utils.affine_transform(h2, config.freq_size, name='logits')

	        # (batch, time_width, freq_size, 1) -> (batch, time_width * freq_size)
	        inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width*config.freq_size)))
	        inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)

	        h1 = tf.nn.selu(utils.affine_transform(inputs, 2048, name='hidden_1'))
	        h1 = tf.nn.dropout(h1, keep_prob=keep_prob)

	        h2 = tf.nn.selu(utils.affine_transform(h1, 2048, name='hidden_2'))
	        h2 = tf.nn.dropout(h2, keep_prob=keep_prob)

	        h3 = tf.nn.selu(utils.affine_transform(h2, 2048, name='hidden_3'))
	        h3 = tf.nn.dropout(h3, keep_prob=keep_prob)

	        fm = utils.affine_transform(h3, config.freq_size, name='logits')

	        return fm
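
To see the shape flow of the fnn branch in isolation, here is a NumPy sketch of the same forward pass. It is not the project's code: it assumes utils.affine_transform is a plain dense layer (x @ W + b), draws random weights, and reimplements tf.nn.selu and TF 1.x tf.nn.dropout (keep with probability keep_prob, scale kept values by 1/keep_prob) by hand. The helper names are hypothetical.

```python
import numpy as np

TIME_WIDTH, FREQ_SIZE, HIDDEN = 9, 129, 2048
rng = np.random.default_rng(0)

def selu(x):
    # Same constants as tf.nn.selu; expm1 on min(x, 0) avoids overflow warnings
    alpha, scale = 1.6732632423543772, 1.0507009873554805
    return scale * np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0)))

def dense(x, out_dim):
    # Hypothetical stand-in for utils.affine_transform: x @ W + b
    w = rng.standard_normal((x.shape[1], out_dim), dtype=np.float32) / np.sqrt(x.shape[1])
    b = np.zeros(out_dim, dtype=np.float32)
    return x @ w + b

def dropout(x, keep_prob):
    # Inverted dropout: keep each value with prob keep_prob, rescale survivors
    mask = rng.random(x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)

def fnn_forward(inputs, keep_prob=0.9):
    # inputs: (batch, TIME_WIDTH, FREQ_SIZE, 1)
    x = np.squeeze(inputs, axis=3)             # (batch, 9, 129)
    x = x.reshape(-1, TIME_WIDTH * FREQ_SIZE)  # (batch, 1161)
    x = dropout(x, keep_prob)
    for _ in range(3):                         # hidden_1 .. hidden_3
        x = dropout(selu(dense(x, HIDDEN)), keep_prob)  # (batch, 2048)
    return dense(x, FREQ_SIZE)                 # logits: (batch, 129)

batch = rng.standard_normal((4, TIME_WIDTH, FREQ_SIZE, 1), dtype=np.float32)
print(fnn_forward(batch).shape)  # (4, 129)
```

A small batch of 4 keeps the sketch light; the dense layers fix the feature widths (1161, 2048, 2048, 2048, 129), while the batch dimension passes through unchanged.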

On analysis, the cause of the error lies in this line:

	inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width*config.freq_size)))

This line calls tf.reshape() with config.time_width = 9 and config.freq_size = 129, but the size of the data sampled from the input dataset is not divisible by 9*129 = 1161.
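
The divisibility condition can be reproduced without TensorFlow. For this case NumPy's squeeze/reshape follow the same rule as tf.squeeze/tf.reshape: reshaping to (-1, 1161) only works when the total element count is a multiple of 1161. The helper name and the "bad" shape below are illustrative assumptions, not taken from the project's data.

```python
import numpy as np

TIME_WIDTH, FREQ_SIZE = 9, 129
FLAT = TIME_WIDTH * FREQ_SIZE  # 1161

def flatten_like_fnn(inputs):
    # NumPy analogue of
    # tf.reshape(tf.squeeze(inputs, [3]), (-1, time_width * freq_size))
    return np.squeeze(inputs, axis=3).reshape(-1, FLAT)

# A batch in (batch, time_width, freq_size, 1) layout always flattens cleanly
ok = np.zeros((256, TIME_WIDTH, FREQ_SIZE, 1), dtype=np.float32)
print(flatten_like_fnn(ok).shape)  # (256, 1161)

# Hypothetical batch with only 8 frames per example: element count % 1161 != 0
bad = np.zeros((256, 8, FREQ_SIZE, 1), dtype=np.float32)
print(bad.size % FLAT)  # 645
try:
    flatten_like_fnn(bad)
except ValueError as err:
    print("reshape failed:", err)
```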

Approach

I am currently checking the state of the input data.
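
When checking the input data, it may help to compare each batch's shape with the shape the placeholder declares, since TensorFlow 1.x raises this "Cannot feed value of shape" ValueError when a fed array is not compatible with the placeholder's static shape (None stands for an unknown dimension such as the batch size). The helper below is a hypothetical plain-Python version of that check; TensorFlow itself offers tf.TensorShape(...).is_compatible_with(...). The (None, 9, 129, 1) shape is inferred from what the fnn branch's squeeze/reshape expects, not read from the project.

```python
from typing import Optional, Sequence

def is_feed_compatible(value_shape: Sequence[int],
                       placeholder_shape: Sequence[Optional[int]]) -> bool:
    # Same rank, and every known placeholder dimension must match exactly
    if len(value_shape) != len(placeholder_shape):
        return False
    return all(p is None or p == v for v, p in zip(value_shape, placeholder_shape))

batch_shape = (256, 9, 129, 1)  # shape reported in the error message
graph_shape = (None, 129, 1, 1)  # placeholder shape reported in the error
fnn_shape = (None, 9, 129, 1)    # layout the fnn squeeze/reshape assumes

print(is_feed_compatible(batch_shape, graph_shape))  # False
print(is_feed_compatible(batch_shape, fnn_shape))    # True
```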
