RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

  • While running the PyTorch MNIST handwritten-digit recognition example, I hit a shape-mismatch error pointing at this line:
images, labels = next(iter(data_loader_train)) 
(Figure 1: screenshot of the traceback ending in RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28])
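  • After several attempts, I found that this line is not the real cause. The transforms only run when the DataLoader fetches a batch, so the error merely surfaces here. The actual problem is that MNIST images are grayscale with a single channel, while Normalize was given three-channel mean and std values. I fixed it by expanding each image to three channels (RGB-style), which meant modifying the transform: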
  • Before:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) 
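The failure can be reproduced without MNIST at all. torchvision's Normalize subtracts the mean and divides by the std using in-place operations (on a copy of the input by default). An in-place op has to write its result back into a tensor of the original shape, so a [1, 28, 28] tensor cannot be broadcast up to [3, 28, 28]. This is a minimal sketch; the random tensor is an assumed stand-in for one MNIST image after ToTensor().

```python
import torch

# One MNIST image after ToTensor(): a single grayscale channel, 1x28x28.
x = torch.rand(1, 28, 28)
# Three per-channel means, shaped (C, 1, 1) so they broadcast over H and W.
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

# Out-of-place subtraction may broadcast to a bigger result: [3, 28, 28].
broadcast_result = x - mean
print(broadcast_result.shape)  # torch.Size([3, 28, 28])

# In-place subtraction must store the result in x itself (shape [1, 28, 28]),
# so broadcasting up to [3, 28, 28] is rejected with a RuntimeError.
error_message = None
try:
    x.sub_(mean)
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

There are two ways out. You can make the input three channels, which is what the fix in this post does. Or you can give Normalize one value per existing channel.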
  • After:
# Import libraries
import torch
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST is grayscale (1x28x28): copy the single channel three times -> 3x28x28
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])  # the modified part

data_train = datasets.MNIST(root="./data",
                            transform=transform,
                            train=True,
                            download=True)
data_test = datasets.MNIST(root="./data",
                           transform=transform,
                           train=False)
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size=64,
                                                shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size=64,
                                               shuffle=True)

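# Fetch one batch; the transforms are applied here, so transform errors surface at this line
images, labels = next(iter(data_loader_train))
# Tile the 64 images into one 3-channel grid image
img = torchvision.utils.make_grid(images)

# CHW -> HWC for matplotlib, then undo Normalize: x = y * std + mean
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print([labels[i] for i in range(64)])
plt.imshow(img)
plt.show()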
  • Result: the output first lists the labels of the 64 images, followed by a preview grid of the 64 images:
[tensor(8), tensor(1), tensor(7), tensor(1), tensor(8), tensor(0), tensor(6), tensor(7), tensor(1), tensor(7), tensor(1), tensor(2), tensor(5), tensor(8), tensor(5), tensor(4), tensor(3), tensor(7), tensor(8), tensor(5), tensor(1), tensor(8), tensor(3), tensor(0), tensor(8), tensor(4), tensor(2), tensor(0), tensor(9), tensor(0), tensor(6), tensor(3), tensor(9), tensor(3), tensor(6), tensor(1), tensor(1), tensor(5), tensor(2), tensor(7), tensor(0), tensor(7), tensor(4), tensor(0), tensor(1), tensor(4), tensor(8), tensor(8), tensor(7), tensor(4), tensor(5), tensor(1), tensor(2), tensor(7), tensor(3), tensor(5), tensor(1), tensor(2), tensor(7), tensor(8), tensor(2), tensor(8), tensor(4), tensor(4)]
(Figure 2: preview grid of the 64 images in the batch)
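The last lines of the script undo the normalization before plotting. Normalize maps each channel to (x - mean) / std, so x = y * std + mean recovers the original [0, 1] pixel values. The `transpose(1, 2, 0)` has to come first, because a 3-element mean or std broadcasts per channel only once the channel axis is last. This is a NumPy-only sketch; a hypothetical random 4x4 RGB array stands in for the real image grid.

```python
import numpy as np

mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])

# Hypothetical H x W x C image with values in [0, 1] (channel axis last).
rng = np.random.default_rng(0)
original = rng.random((4, 4, 3))

# What Normalize computes per channel: values end up in [-1, 1].
normalized = (original - mean) / std

# The inverse used in the script: x = y * std + mean.
restored = normalized * std + mean

# Clipping guards against tiny float overshoot before plt.imshow.
restored = np.clip(restored, 0.0, 1.0)
print(np.allclose(restored, original))  # True
```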

Note: you can also try the solution suggested by fellow coder Victor_Gui: https://blog.csdn.net/qq_31829611/article/details/90200694
