RuntimeError: mat1 and mat2 shapes cannot be multiplied (3584x7 and 25088x4096)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3584x7 and 25088x4096)

使用VGG19提取图像特征时出现该问题

报错代码

output = self.features(x)  # 输出维度为(512*7*7)
output = self.avgpool(output)
output = self.classifier(output)

原因分析

卷积层的输入为四维[batch_size,channels,H,W] ,而全连接层接受维度为2的输入,通常为[batch_size, size]

解决方案

在全连接层前加入维度变化
使用torch.flatten()

output = self.features(x)  # 输出维度为(512*7*7)
output = self.avgpool(output)
output = torch.flatten(output, 1)
output = self.classifier(output)

还看到一种解决方案

x.view(-1,7* 7* 1024) 

你可能感兴趣的:(bug记录,深度学习,人工智能,机器学习)