为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法: 1.
__getitem__()
该方法定义用索引(0
到len(self)
)获取一条数据或一个样本 2.__len__()
该方法返回数据集的总长度
import torch
from torch.utils.data import Dataset
import pandas as pd
#自定义一个数据集类,继承Dataset
class BluebookDataset(Dataset):
'''数据集演示'''
def __init__(self,csv_file):
'''初始化时将数据载入'''
self.df=pd.read_csv(csv_file)
def __len__(self):
return len(self.df)#获取长度
def __getitem__(self,idx):
#iloc[ : , : ],冒号前面的取行数,后面的取列数,左闭右开原则.
return self.df.iloc[idx].SalePrice #读取第idx行,SalePrice列的数据
ds_demo = BluebookDataset('F:\Desktop\median_benchmark.csv') #先下载对应的.CSV文件
print(len(ds_demo))
ds_demo[0]
结果:
11573
24000.0
DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、 shuffle(是否进行shuffle操作)、 num_workers(加载数据的时候使用几个子进程)。
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None)
参数解释:
#返回一个可迭代的对象
dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)
for i, data in enumerate(dl):
print(i,data)
# 为了节约空间,这里只循环三遍
if(i==2):
break
结果:
0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
1 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
2 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
torchvision.datasets 为PyTorch自定义的dataset,包括: - MNIST - COCO - Captions - Detection - LSUN - ImageFolder - Imagenet-12 - CIFAR - STL10 - SVHN - PhotoTour。
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
train=True, # 表示是否加载数据库的训练集,false的时候加载测试集
download=True, # 表示是否自动下载 MNIST 数据集
transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理
加载时遇到的问题:
ImportError: The _imaging extension was built for another version of Pillow or PIL: Core version: "9.2.0" Pillow version: 9.2.0
解决方法:
重新安装一下pillow就可以了
pip uninstall Pillow
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pillow
下载成功
下载常用的模型,包括:- AlexNet - VGG - ResNet - SqueezeNet - DenseNet。
#我们直接可以使用训练好的模型,当然这个与datasets相同,都是需要从服务器下载的
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强。
from torchvision import transforms as transforms
transform = transforms.Compose([ #串联多个图片变换的操作,即想要执行的transform操作。
transforms.RandomCrop(32, padding=4), #先四周填充4层0,在把图像随机裁剪成32*32像素大小
transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
transforms.RandomRotation((-45,45)), #随机旋转
transforms.ToTensor(), #把图像转换为Tensor
transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.229, 0.224, 0.225),inplace=False), #R,G,B每层的归一化用到的均值和方差,把图片3个通道中的数据整理到[-1, 1]区间,可以加快模型的收敛,mean和std可自己设定.
])
操作 | 功能 |
---|---|
transforms.CenterCrop(size)#size为裁剪大小,超过原图大小自动补0 | 中心裁剪 |
transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=‘constant’) #fill填充的像素大小,padding_mode:填充的模式(constant:填充fill设定的值, edge:填充边界的值, reflect or symmetric) | 随机裁剪 |
transforms.RandomResizedCrop(size, scale=(0.01, 1.0), ratio=(0.75, 1.4), interpolation=2)#scale:面积随机在(0.01, 1.0)之间的比例缩放,ratio:长宽比随机在(0.75, 1.4)之间选取。 | 随机大小、随机宽高比裁剪图片 |
transforms.FiveCrop(size, vertical_flip=False)#size:最后裁剪的图片尺寸,vertical_flip:是否翻转。最后的 tensor 形状是 [5crops, c, h, w] | 在图像的上下左右以及中心裁剪出尺寸为 size 的 5 张图片 |
transforms.RandomVerticalFlip(p=1) #p为翻转概率 | 水平或者垂直方向翻转图片 |
transforms.RandomRotation(degrees, resample=False, expand=False, center=None, fill=None) #degrees:旋转角度如(-45,60),resample是否重采样,expand:是否扩大矩形框,会改变图片的尺寸,center:旋转中心,默认是图片的中心。 | 随机旋转 |
transforms.Pad(padding=(16,12), fill=0, padding_mode=‘constant’)#padding:填充的大小,(16,12)表示上下填充16,左右填充12. | 图片填充 |
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)#参数分别为亮度,对比度,饱和度和色相。 | 调整亮度、对比度、饱和度、色相。 |
transforms.Grayscale(p=0.1, num_output_channels=3) #p:转为灰度图的概率,num_output_channels:输出通道数。 | 转灰度图 |
transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0) #分别为旋转,平移,缩放,填充颜色,错切角及采样设置( NEAREST、BILINEAR、BICUBIC。)。 | 仿射变换 |
transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False) | 图像随机遮挡 |
transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5) | 根据概率执行或不执行一组 transforms 操作 |
transforms.RandomChoice([transforms1, transforms2, transforms3]) | 随机选一个执行 |
transforms.RandomOrder([transforms1, transforms2, transforms3]) | 打乱顺序执行一组 transforms 操作 |
参考资料:
https://handbook.pytorch.wiki/chapter2/2.1.4-pytorch-basics-data-loader.html
未完待续!
欢迎关注个人公众号【智能建造小硕】(分享计算机编程、人工智能、智能建造、日常学习和科研经验等,欢迎大家关注交流。)