【torch】pytorch label的one-hot转化

最近使用了一下pytorch,在将label转化成向量时总是报错

        train_data = torchvision.datasets.MNIST(
            root='./mnist/',  # 保存或者提取位置
            train=True,  # this is training data
            transform=torchvision.transforms.ToTensor(
            ),  # 转换 PIL.Image or numpy.ndarray 成
            # torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
            download=self.DOWNLOAD_MNIST,  # 没下载就下载, 下载了就不用再下了
        )

        idx = torch.LongTensor(train_data.train_labels)
        print(idx)
        train_data.train_data = train_data.train_data.reshape(60000, 784)
        one_hot_label = torch.FloatTensor(60000,
                                          10).zero_().scatter_(1, idx, 1)
        train_data.train_label = one_hot_label

报错: Index tensor must either be empty or have same dimensions as output tensor at

scatter_()的参数index的维度应该与转化后的一致,所以使用view函数处理idx

        train_data = torchvision.datasets.MNIST(
            root='./mnist/',  # 保存或者提取位置
            train=True,  # this is training data
            transform=torchvision.transforms.ToTensor(
            ),  # 转换 PIL.Image or numpy.ndarray 成
            # torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
            download=self.DOWNLOAD_MNIST,  # 没下载就下载, 下载了就不用再下了
        )

        idx = torch.LongTensor(train_data.train_labels).view(-1, 1)
        print(idx)
        train_data.train_data = train_data.train_data.reshape(60000, 784)
        one_hot_label = torch.FloatTensor(60000,
                                          10).zero_().scatter_(1, idx, 1)
        train_data.train_label = one_hot_label

你可能感兴趣的:(python,深度学习,python)