从网上找的现成的代码,跑cifar10数据集没有问题,但是跑mnist数据集反而出了问题,是因为图片一个是彩色一个是灰度,在深度学习的时候,输入的channel不一样,找了很多资料,都是说让改成transforms.Lambda(lambda x: x.repeat(1,1,1)),还是不行,后来请教了同学,改成了transforms.Lambda(lambda x: x.repeat(3,1,1)),就OK了
第二种解决的办法相对来说稍微复杂一点,因为这里出错主要是因为模型不匹配,用的是resnet18这个模型,它的模型结构第一层就是
可以看到输入就要求是3通道,所以可以通过改写模型的方式,将输入直接改为1通道:
这个时候再去训练黑白图片也不会出错了。
当时改写的时候,还出了一点点小问题,报错为:Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
这是因为我用的是GPU,改写的时候要在改写层后面加上.cuda()
self.global_model.conv1 = nn.Conv2d(in_channels=1,out_channels=64,kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).cuda()
然后就没有问题了