在设计网络是,前面几层是去噪网络,后边几层是分类网络,因为之前没有接触过分类任务,对全连接层输入维度不太理解,出现错误RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x256 and 8x256)
解决方法:查看上一层卷积的输出值大小,发现
原因:
卷积层的输入为四维[batch_size,channels,H,W] ,而全连接接受维度为2的输入,通常为[batch_size, size]。
所以需要进行变换
添加以下语句:
x = x.view(x.shape[0], -1)
得到大小为([8, 256])
而对于fc层需要根据上面的输出更改输入,及将下面语句的8改为256,跑通
self.fc1 = nn.Linear(8, 256)