pytorch之创建数据集

import torch
import torchvision
from torchvision import datasets,transforms
dataroot = "data/celeba"  # 数据集所在文件夹
# 创建数据集
dataset = datasets.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

1)torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:

root:在root指定的路径下寻找图片
transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
target_transform:对label的转换
loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

原文链接:https://blog.csdn.net/weixin_40123108/article/details/85099449

Pytorch的torchvision模块中提供了一个dataset 包,它包含了一些基本的数据集如mnist、coco、imagenet和一个通用的数据加载器ImageFolder。不同文件夹下的图,会被当作不同的类,天生就用于图像分类任务。

imagefolder有3个成员变量:用绿色表示x为train时的image_datasets的属性

  1. self.classes:用一个list保存类名,就是文件夹的名字。如['green', 'normal', 'out', 'right']
  2. self.class_to_idx:类名对应的索引,可以理解为 0、1、2、3 等。如{'out': 2, 'green': 0, 'right': 3, 'normal': 1}
  3. self.imgs:保存(imgpath,class),是图片和类别的数组。如[('datasets/test_True_TrainTest/train/green/0000000012200roi_.jpg', 0), ... , ('datasets/test_True_TrainTest/train/right/0000000012980roi_.jpg', 3)]

2)torchvision.transforms

torchvision.transforms 模块提供了一般的图像转换操作类。

  • class torchvision.transforms.ToTensor

把shape=(H x W x C) 的像素值为 [0, 255] 的 PIL.Image 和 numpy.ndarray
转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]的 torch.FloatTensor。

  • class torchvision.transforms.Normalize(mean, std)

此转换类作用于torch.*Tensor。给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。

原文链接:https://blog.csdn.net/wsp_1138886114/article/details/83620869

 

3)torch.utils.data.DataLoader

将数据按照batch_size封装成Tensor

  • 1.dataset(Dataset),数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。
  • 2.batch_size (int, optional),批训练数据量的大小,根据具体情况设置即可。(默认:1)
  • 3.shuffle (bool, optional),打乱数据,一般在训练数据中会采用。(默认:False)
  • 4.sampler (Sampler, optional),从数据集中提取样本的策略。如果指定,“shuffle”必须为false。一般默认即可。
  • 5.batch_sampler (Sampler, optional),和batch_size、shuffle等参数互斥,一般用默认。
  • 6.num_workers,这个参数必须大于等于0,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。(默认:0)
  • 7.collate_fn (callable, optional),合并样本清单以形成小批量。用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。
  • 8.pin_memory (bool, optional):数据加载器将把张量复制到CUDA内存中,然后返回它们。也就是一个数据拷贝的问题。
  • 9.drop_last (bool, optional): 如果数据集大小不能被批大小整除,则设置为“true”以除去最后一个未完成的批。如果“false”那么最后一批将更小。(默认:false)
  • 10.timeout(numeric, optional):设置数据读取超时,但超过这个时间还没读取到数据的话就会报错。(默认:0)
  • 11.worker_init_fn (callable, optional)如果不是:none,则在种子设定之后和数据加载之前,将在每个工作进程上调用它,并输入工作进程ID([0,num_workers-1)(默认:None)

原文链接:https://blog.csdn.net/wsp_1138886114/article/details/84146704

你可能感兴趣的:(pytorch)