PyTorch 读取图片数据的规范操作

读取 torchvision 内置数据集并创建 DataLoader

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
# Build the CIFAR10 train/test datasets and wrap each one in a DataLoader.
#
# Compose chains several image transforms into a single callable.
# ToTensor converts an ndarray or PIL image into a (C, H, W) tensor and
# scales the values by 1/255.
# Normalize computes image = (image - mean) / std per channel; with mean and
# std both set to (0.5, 0.5, 0.5) the range [0, 1] is stretched to [-1, 1].
transform_for_train = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     # Data augmentation: random horizontal flip plus random grayscale.
     transforms.RandomGrayscale(),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# The test pipeline skips augmentation and only converts and normalizes.
transform_for_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# download=False: the dataset is expected to be present under ./data already.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform_for_train)
# DataLoader feeds the data in batches.
# batch_size: number of samples per batch.
# num_workers: number of subprocesses used for loading; 0 loads in the main
# process.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform_for_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

使用 DataLoader 迭代读取数据

# Iterate over the training loader and convert each batch to a channels-last
# NumPy array.
# Local import: the loop calls np.transpose, and this snippet has no numpy
# import of its own.
import numpy as np

for index, i in enumerate(trainloader):
    X_train, labels = i
    # X_train is a torch.Tensor; convert it when the model expects an ndarray.
    imgss = X_train
    imgss = imgss.numpy()
    # The loader yields (batch_size, channel, height, width); reorder the axes
    # when the model expects a different layout.
    imgs = np.transpose(imgss, (0, 2, 3, 1))  # (batch_size, height, width, channel)

通过 ImageFolder 创建数据集,再通过 DataLoader 输出数据

import torch 
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,datasets 

# Build train/validation datasets from image folders and wrap them in
# DataLoaders.
transform_train = transforms.Compose(
    [transforms.ToTensor()])
transform_valid = transforms.Compose(
    [transforms.ToTensor()])

# datasets.ImageFolder(root, transform, target_transform, loader):
#   root:             directory that is searched for images
#   transform:        applied to the PIL image; its input is the object
#                     returned by `loader`
#   target_transform: applied to the label
#   loader:           function used to load an image; the default reads a
#                     PIL image object
# On the resulting dataset:
#   dataset.class_to_idx: label assigned to each folder
#   dataset.imgs:         list of (path, label) tuples for every image
#   dataset[0][0]:        second index 0 selects the image
#   dataset[0][1]:        second index 1 selects the label

# target_transform wraps the integer label in a one-element float tensor.
ds_train = datasets.ImageFolder("/home/kesci/input/data6936/data/cifar2/train/",
            transform = transform_train,target_transform= lambda t:torch.tensor([t]).float())
# Use transform_valid here so the validation pipeline stays independent of any
# augmentation later added to transform_train.
ds_valid = datasets.ImageFolder("/home/kesci/input/data6936/data/cifar2/test/",
            transform = transform_valid,target_transform= lambda t:torch.tensor([t]).float())
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True,num_workers=3)
# NOTE(review): shuffling the validation loader is kept as in the original;
# confirm whether a fixed order is preferred for evaluation.
dl_valid = DataLoader(ds_valid,batch_size = 50,shuffle = True,num_workers=3)

你可能感兴趣的:(Pytorch读取图片数据的规范操作)