关于torchvision加载数据集的小问题

现有以下完整程序可以成功加载数据集,使用ImageFolder函数:

import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms,utils
import numpy as np
# 使用ImageFolder需要保证数据集以下列形式组织:
'''
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
'''
img_data = torchvision.datasets.ImageFolder(
    root = r'E:\机器学习数据集\flower_photos',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()])
        )
print('数据集类别:',img_data.classes)
print('数据集大小:',len(img_data))

# 使用torch.utils.data.DataLoader加载,形成一个DataLoader类实例
data_loader = torch.utils.data.DataLoader(img_data,batch_size=36, shuffle=True)
print(len(data_loader))

def imshow(img):
#    img = img / 2 + 0.5     # unnormalize
    img = torchvision.utils.make_grid(img, nrow=6)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title('Batch from dataloader')
    plt.xticks([])
    plt.yticks([])
    plt.show()

# get some random training images
dataiter = iter(data_loader)
images, labels = dataiter.next()
print(images.shape, labels)
# show images
imshow(images)

上面程序用了三种变换:Resize,Crop和ToTensor,问题就出现在这里了,

  1. 问题1:如果去掉Crop,只留下Resize和ToTensor,程序报错:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 408 and 341 in dimension 3 at ..\aten\src\TH/generic/THTensor.cpp:711
  1. 问题2 :去掉ToTensor或者将ToTensor放到前面而不是最后一项,程序报错:
TypeError: batch must contain tensors, numbers, dicts or lists; found 

其他情况不好有问题,先留着这两个问题,以后研究深入了再解答。

你可能感兴趣的:(机器学习)