网络训练报错RuntimeError:size mismatch, m1: [256 x 1600], m2: [1024 x 256]的解决办法

最近在用fashionMNIST跑一个网络,结果如下图

网络训练报错RuntimeError:size mismatch, m1: [256 x 1600], m2: [1024 x 256]的解决办法_第1张图片

想要看看用同一个网络,不同数据集之间对准确率差别,因此我把数据集换成了cifar10.除了要用到transforms.Grayscale(1)

test_dataset = CIFAR10('../data/CIFAR10', train=False, download=True, transform=transforms.Compose([
    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize((mean,),(std,))
]))

将3通道的彩色图片转成单通道图片外,还报了一个错误:

RuntimeError:size mismatch, m1: [256 x 1600], m2: [1024 x 256]

网络训练报错RuntimeError:size mismatch, m1: [256 x 1600], m2: [1024 x 256]的解决办法_第2张图片

经过排查代码后发现是由于在全连接层时矩阵维度大小不匹配造成的,将维度64*4*4修改成1600:

网络训练报错RuntimeError:size mismatch, m1: [256 x 1600], m2: [1024 x 256]的解决办法_第3张图片

最后是cifar10经过20个epoch训练的结果,准确率并没有fashionMNIST高。猜测或许是因为cifar10数据量没有fashionMNIST大,还有一部分原因是因为cifar10有些数据本身难以识别,数据集的质量不高,因此需要通过一些技巧去提高准确率。如果有更加合理的解释,欢迎评论多多指教~~

fashionMNIST 训练集包含 60000 个样例,测试集包 含 10000 个样例,分为 10 类,每一类的样本训练样本数量和测试样
本数量相同。样本都来自日常穿着的衣裤鞋包,每个都是 28 × 28 的 灰度图像,其中总共有 10 类标签,每张图像都有各自的标签
CIFAR-10 数据集是一个接近普适物体的彩色图像数据集,该数 据集共有 60000 张彩色图像,其中有 50000 张训练图片和 10000 张 测试图片,图片尺寸为 32*32 ,分为 10 个类,每类 6000 张图片。)
 

网络训练报错RuntimeError:size mismatch, m1: [256 x 1600], m2: [1024 x 256]的解决办法_第4张图片

你可能感兴趣的:(人工智能)