mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
手动下载地址:https://github.com/zalandoresearch/fashion-mnist/blob/master/README.zh-CN.md
大无语事件,虽然说官方下载地址这里明说了这个数据集是集成在pytorch里了,而且查到的torchvision.dataset,也说有这么个函数,不过这个函数里没有定义任何功能函数,这可能就是加载失败的问题所在:
事实胜于雄辩,自动下载数据集失败了:
AttributeError: module 'torchvision.datasets' has no attribute 'FashionMNIST'
好吧,那就是没有了吧。数据集也不大,手动下载:
加载方式就要相应改变,参考代码(https://blog.csdn.net/CBCZJL/article/details/104414904):
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("..")
# 打开读取压缩文件
import gzip
import os
import numpy as np
def data_load(path, kind):
images_path = os.path.join(path,'%s-images-idx3-ubyte.gz' % kind)
labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz' % kind)
with gzip.open(labels_path,'rb') as lbpath:
labels = np.frombuffer(lbpath.read(),dtype=np.uint8, offset=8)
with gzip.open(images_path,'rb') as impath:
images = np.frombuffer(impath.read(),dtype=np.uint8, offset=16).reshape(len(labels),784)
return images, labels
# 读取转化数据
X_train, y_train = data_load('C:/Users/CSS/Documents/jupyter-file/Fashion_mnist_dataset','train')
X_test, y_test = data_load('C:/Users/CSS/Documents/jupyter-file/Fashion_mnist_dataset','t10k')
X_train_tensor = torch.Tensor(X_train).reshape(-1,1,28,28)*(1/255)
X_test_tensor = torch.from_numpy(X_test).to(torch.float32).view(-1,1,28,28)*(1/255)
y_train_tensor = torch.from_numpy(y_train).to(torch.float32).view(-1,1)
y_test_tensor = torch.from_numpy(y_test).to(torch.float32).view(-1,1)
mnist_train = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)
mnist_test = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)
batch_size = 256
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
输出:
60000 10000
加载过程中唯一有与参考代码不同的地方是这行:X_train_tensor = torch.Tensor(X_train).reshape(-1,1,28,28)*(1/255)
原文是:X_train_tensor = torch.from_numpy(X_train).to(torch.float32).view(-1,1,28,28)*(1/255)
当时运行时报错了,大概就是numpy与tensor的转换问题。
为啥X_test_tensor没报错,我没看明白,如有大佬明白,请不吝赐教。