做中国大学MOOC中“人工智能实践:TensorFlow笔记”的fashion数据集卷积神经网络练习时报如下错误:
ValueError: Input 0 of layer conv2d is incompatible with the layer:
expected ndim=4, found ndim=3. Full shape received: [None, 28, 28]
但是一模一样的模型,只是把数据集换成cifar10数据集,就没有错误。fashion数据集和cifar10数据集导入的代码分别如下:
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# cifar10 = tf.keras.datasets.cifar10
# (x_train, y_train), (x_test, y_test) = cifar10.load_data()
# x_train, x_test = x_train / 255.0, x_test / 255.0
这是为什么呢?先对比看看着两个数据集的区别:
看下来两个数据集的主要区别在于一个是单通道的灰度值图片,一个是三通道的彩色图片。但是这个区别和报的错误有什么关系吗?作为小白,我们再认真看看报的错误,直白地理解下来,大概是说conv2d层(也就是卷积层)的输入和这一层不兼容,期望的维度是4,实际给的是3。
ValueError: Input 0 of layer conv2d is incompatible with the layer:
expected ndim=4, found ndim=3. Full shape received: [None, 28, 28]
可是还是不太清楚和报的错误有什么关系?
我们看看fashion数据集的维度,采用如下命令,运行后结果是x_train.shpae (60000, 28, 28),发现是3个维度的。
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
print("x_train.shpae", x_train.shape)
再运行如下命令,来看看cifar10数据集的维度,运行后的结果是x_train.shpae (50000, 32, 32, 3),发现这个数据集是4个维度的。到这里似乎知道为什么同样的模型用于cifar10数据集训练不会报错,但是用于fashion数据集就会报错,原来两个数据集看似区别不大,但是维度却实实实在在的不一样。
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print("x_train.shpae", x_train.shape)
到这里,我猜测卷积对于输入的维度是不是有要求的呢,3个维度的是不是不行呢?经过查询资料,发现确实如此。卷积计算要求输入的图片必须是4个维度的,第0个维度表示一次喂入几个batch,第1、2、3个维度分别表示输入图片的分辨率和通道数。这一部分详细且简单明了的解释可以参考下面的链接:
理解卷积网络的输入与输出形状(Keras实现)
现在情况越来越明朗,已经清楚知道错误的原因了:fashion数据集是3维的,不符合卷积的输入要求(4维),所以报Input 0 of layer conv2d is incompatible with the layer: expected ndim=4, found ndim=3. Full shape received: [None, 28, 28]。那怎么把fashion数据集变成4维的呢?
采用如下命令将3维的输入扩展维4维的输入,该命令简单明了的解释可以参考下方链接:
x_train = np.expand_dims(x_train, axis=3)
x_test = np.expand_dims(x_test, axis=3)
TensorFlow笔记–用expand_dim()来增加维度
老师课上是以cifar10数据集为例来讲解卷积神经网络的实现,听完也懂了,自己也能像模像样地复现模型。但是举一反三,自己尝试用卷积神经网络训练fashion数据集的时候出现了这个错误,刚开始一直以为是自己的代码有微小的错误没有察觉出来,反复对照了很多次,花费了个把小时,发现同样的模型对于cifar10是可以的,换成fashion就不行了。捣鼓了一整个下午才发现错误的原因并解决。
深度学习没有看起来那么简单,基于Tensorflow这样的深度学习框架,上手并不复杂,但是如果对封装好的这些命令和深度学习理论没有很好的理解的话,是做不好这一块的。
现在只是万里长征的第一步,通过解决这个错误,自己又扎实掌握了一个卷积计算的知识,并在在排查错误的过程中复习了很多其他的知识。更大的收获是不再害怕错误了,通过错误的提示相应的调试和查找资料,是一定可以解决的。
汽车工程师,跨行学习深度学习,坚持不易,点赞鼓励一下吧。