Batched training on the sklearn handwritten digits dataset with DataLoader

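from sklearn.datasets import load_digits
from torch.utils.data import DataLoader
import numpy as np

digits = load_digits()
img = digits['images']              # (1797, 8, 8): 1797 grayscale 8x8 images
img = img[:, np.newaxis, :, :]      # (1797, 1, 8, 8): add a channel dimension
# drop_last=True discards the final incomplete batch (1797 = 4 * 449 + 1)
dataloader = DataLoader(img, batch_size=4, shuffle=True, num_workers=0, drop_last=True)
for i, batch in enumerate(dataloader):
    # The default collate function turns the NumPy arrays into a float64 tensor
    print(batch.size())  # torch.Size([4, 1, 8, 8])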
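The loader above yields only image batches and no labels, so it cannot drive a training step on its own. The following is a minimal sketch, assuming PyTorch is installed, that pairs the images with `digits['target']` through `TensorDataset` and runs a short training loop. The two-layer model, the SGD learning rate, the 2 epochs, and the scaling by 16 are hypothetical choices for illustration, not the author's method.

```python
import numpy as np
import torch
from torch import nn
from sklearn.datasets import load_digits
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

digits = load_digits()
# Pixel values in load_digits are 0..16; scale to [0, 1] and convert to float32
# (the default dtype of PyTorch layers). Shape: (1797, 1, 8, 8)
x = torch.from_numpy(digits['images'][:, np.newaxis, :, :] / 16.0).float()
# Class labels 0-9 as int64, which nn.CrossEntropyLoss expects
y = torch.from_numpy(digits['target']).long()

dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=True)

# Hypothetical small model: flatten each 1x8x8 image, then two linear layers
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # assumed learning rate

for epoch in range(2):  # assumed number of epochs
    total_loss = 0.0
    n_batches = 0
    for batch_x, batch_y in dataloader:
        # batch_x: (4, 1, 8, 8) images, batch_y: (4,) labels
        optimizer.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    print(f"epoch {epoch}: {n_batches} batches, mean loss {total_loss / n_batches:.4f}")
```

With `batch_size=4` and `drop_last=True`, each epoch runs 449 batches, and the single leftover sample is skipped. Because `shuffle=True`, a different sample is dropped in each epoch.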
