错误修正:tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor

在训练GAN网络时,提示以下报错:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 7500 values, but the requested shape requires a multiple of 27

网络部分代码如下:

    def G(self):
        with tf.name_scope("Gen") as sc:
            output1 = self.fully_con(self.y, 25, sc + "_1")
            output2 = self.fully_con(output1, 100, sc + "_2")
            output3 = self.fully_con(output2, 500, sc + "_3")
            output4 = self.fully_con(output3, 100, sc + "_4")
            output5 = self.fully_con(output4, self.shape_2[-1] * 25, sc + "_5")
            return tf.reshape(output5, [-1, PATCH_SIZE, PATCH_SIZE, self.shape_2[-1]])

    def A(self):
        with tf.name_scope("App") as sc:
            output1 = self.fully_con(self.x, 25, sc + "_1")
            output2 = self.fully_con(output1, 100, sc + "_2")
            output3 = self.fully_con(output2, 500, sc + "_3")
            output4 = self.fully_con(output3, 100, sc + "_4")
            output5 = self.fully_con(output4, self.shape_2[-1] * 25, sc + "_5")
            return tf.reshape(output5, [-1, PATCH_SIZE, PATCH_SIZE, self.shape_2[-1]])

    def D(self):
        with tf.name_scope("Dis") as sc:
            self.d = tf.concat([self.x, self.G], 0)
            output1 = self.fully_con(self.d, 25, sc + "_1")
            output2 = self.fully_con(output1, 100, sc + "_2")
            output3 = self.fully_con(output2, 200, sc + "_3")
            output4 = self.fully_con(output3, 50, sc + "_4")
            output5 = self.fully_con(output4, 1, sc + "_5", tf.nn.sigmoid)
            self.p1, self.p2 = tf.split(output5, 2, 0)

分析是由于输入reshape的Tensor维度与所需维度不一致而导致的,因此检查reshape的输入与输出维度,逐步往前推进,找到问题根源。
我这里是因为之前定义了PATCH_SIZE=3,而G、A、D的输入维度均为25导致的,应该改为PATCH_SIZE*PATCH_SIZE
改正以后就没有报错了,至此,解决了以上问题。

你可能感兴趣的:(问题解决,python,tensorflow,人工智能,矩阵,深度学习)