# 读取 torchvision 已有数据集并创建 DataLoader
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
# Compose chains several image transforms into a single callable.
# ToTensor converts an ndarray or PIL image into a (C, H, W) tensor and
# rescales the pixel values to [0, 1] (division by 255).
# Normalize then stretches [0, 1] to [-1, 1] via image = (image - mean) / std,
# with mean and std both given as (0.5, 0.5, 0.5).
transform_for_train = transforms.Compose([
    # Data augmentation: random horizontal flip and random grayscale.
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# The test pipeline applies no augmentation, only tensor conversion and the
# same normalization as training.
transform_for_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# download=False: nothing is fetched, so the CIFAR-10 files should already be
# present under ./data.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=False,
    transform=transform_for_train,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=False,
    transform=transform_for_test,
)

# DataLoader serves the dataset in batches.
#   batch_size:  number of samples per batch
#   num_workers: number of subprocesses used for loading
#                (0 means loading happens in the main process)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)
# 使用 DataLoader 传入数据
# Iterate over the training loader; every batch is an (images, labels) pair.
for index, (X_train, labels) in enumerate(trainloader):
    # X_train is a torch.Tensor; convert it when the model expects a NumPy
    # array instead.
    imgss = X_train.numpy()
    # The loader yields (batch_size, channel, height, width). Reorder the axes
    # to (batch_size, height, width, channel) when the model needs
    # channels-last input, e.g. (4, 32, 32, 3) for CIFAR-10 with batch_size=4.
    # ndarray.transpose is used so the loop does not depend on a module-level
    # `np` name, which this file never imports.
    imgs = imgss.transpose(0, 2, 3, 1)
# 通过 ImageFolder 创建数据集,再通过 DataLoader 输出数据
import torch
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,datasets
def _label_to_float_tensor(label: int) -> torch.Tensor:
    """Wrap an integer class index in a float tensor of shape (1,).

    Defined as a module-level function rather than a lambda so that the
    datasets using it can be pickled for DataLoader worker processes.
    """
    return torch.tensor([label]).float()


# Training and validation currently use the same pipeline (tensor conversion
# only), but each split keeps its own transform so they can diverge later.
transform_train = transforms.Compose([transforms.ToTensor()])
transform_valid = transforms.Compose([transforms.ToTensor()])

# datasets.ImageFolder(root, transform, target_transform, loader):
#   root:             directory searched for images
#   transform:        applied to the PIL image; its input is whatever the
#                     loader returns after reading the file
#   target_transform: applied to the label
#   loader:           function used to read an image file; the default
#                     returns a PIL image
# Useful attributes / indexing:
#   dataset.class_to_idx: label assigned to each class folder
#   dataset.imgs:         list of (path, label) tuples for every image
#   dataset[0][0]:        the image of the first sample
#   dataset[0][1]:        the label of the first sample
ds_train = datasets.ImageFolder(
    "/home/kesci/input/data6936/data/cifar2/train/",
    transform=transform_train,
    target_transform=_label_to_float_tensor,
)
# The validation split uses transform_valid; the original passed
# transform_train here and left transform_valid unused.
ds_valid = datasets.ImageFolder(
    "/home/kesci/input/data6936/data/cifar2/test/",
    transform=transform_valid,
    target_transform=_label_to_float_tensor,
)

dl_train = DataLoader(ds_train, batch_size=50, shuffle=True, num_workers=3)
# Validation data is served in a fixed order so evaluation is reproducible,
# matching the test loader configured earlier in this file.
dl_valid = DataLoader(ds_valid, batch_size=50, shuffle=False, num_workers=3)