output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]


When I downloaded the MNIST dataset and tried to look at one batch of images, the error above was raised:

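import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    # Problem: MNIST images are grayscale (1 channel), but 3-channel mean/std are given
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])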
data_train = datasets.MNIST(
    root="./data/",
    transform=transform,
    train=True,
    download=True
)

data_test = datasets.MNIST(
    root="./data/",
    transform=transform,
    train=False,
    download=True
)
data_loader_train = torch.utils.data.DataLoader(
    dataset=data_train,
    shuffle=True,
    batch_size=64
)
data_loader_test = torch.utils.data.DataLoader(
    dataset=data_test,
    shuffle=True,
    batch_size=64
)
images, labels = next(iter(data_loader_train))  # the error is raised here, when the transform runs on the first batch
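To see where the message comes from, here is a minimal sketch using only torch. The helper normalize_like_torchvision is hypothetical: it is a simplified imitation of how transforms.Normalize treats mean and std (reshaping them to [C, 1, 1], then subtracting and dividing in place on a copy of the image), not torchvision's actual source.

```python
import torch


def normalize_like_torchvision(img, mean, std):
    # Hypothetical, simplified imitation of transforms.Normalize:
    # mean/std become shape [C, 1, 1] and are applied in place on a clone.
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # In-place ops cannot grow the tensor, so [1, 28, 28] minus [3, 1, 1]
    # (which would broadcast to [3, 28, 28]) raises a RuntimeError.
    return img.clone().sub_(mean).div_(std)


img = torch.rand(1, 28, 28)  # stand-in for one MNIST image after ToTensor()

try:
    normalize_like_torchvision(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
except RuntimeError as err:
    print("3-channel mean/std fail:", err)

out = normalize_like_torchvision(img, [0.5], [0.5])
print(out.shape)  # torch.Size([1, 28, 28])
```

With a single mean/std value the stats have shape [1, 1, 1], which broadcasts against [1, 28, 28] without changing its shape, so the in-place operation succeeds.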

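Solution: MNIST images are single-channel grayscale, so after ToTensor() each image is a tensor of shape [1, 28, 28]. Normalize reshapes the three-element mean and std to [3, 1, 1] and subtracts/divides in place, which would require the result to broadcast to [3, 28, 28]; an in-place operation cannot change the tensor's shape, hence the error. Pass Normalize one value per channel, i.e. a single mean and a single std: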

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # one mean/std value for the single grayscale channel
])
