PyTorch wraps the common data-loading workflow in torch.utils.data, which makes it easy to implement multi-worker data prefetching and batched loading. torchvision also ships ready-made implementations of common image datasets, including CIFAR-10 (used earlier), ImageNet, COCO, MNIST, LSUN, and others, all conveniently accessible through torchvision.datasets.
import torch
torch.__version__
'1.2.0'
Dataset
Dataset is an abstract class; to make data easy to read, the data you want to use needs to be wrapped in a Dataset. A custom Dataset must inherit from it and implement two member methods:
__getitem__() fetches one sample by index (from 0 to len(self) - 1)
__len__() returns the total length of the dataset
Below we build a custom dataset from the Kaggle competition Blue Book for Bulldozers; for ease of presentation we use its data dictionary, since it has only a few rows.
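As a minimal sketch before the CSV example, a dataset backed by a plain Python list (a hypothetical SquaresDataset, made up for illustration) needs nothing beyond these two methods:

```python
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset of the first n square numbers."""
    def __init__(self, n):
        self.values = [i * i for i in range(n)]

    def __len__(self):
        # total number of samples
        return len(self.values)

    def __getitem__(self, idx):
        # one sample by index
        return self.values[idx]

ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # 9
```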
from torch.utils.data import Dataset
import pandas as pd
class BulldozerDataset(Dataset):  # subclass Dataset
    """Dataset demo"""
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)

    def __len__(self):
        # return the length of the DataFrame
        return len(self.df)

    def __getitem__(self, idx):
        # return one sample by idx
        return self.df.iloc[idx]['x1']  # read the row at index idx; appending a column selector like ['x1'] returns just that column's value, otherwise the whole row is returned
data = BulldozerDataset(r'data/datatest_1.csv')
# because __len__ is implemented, len() returns the total number of samples
len(data)
25
# indexing accesses the corresponding sample directly, via the __getitem__ method:
for i in range(len(data)):
print(data[i])
# x1 , x2 = data[i]
# print("x1 = ",x1,"x2 = ",x2)
0.232991543
0.449915356
0.840922298
0.20727367
0.541869015
0.36092917399999996
0.668949803
0.15037799400000001
0.898436358
0.302533521
0.286306281
0.8299904779999999
0.69587861
0.848728081
0.527168802
0.231401747
0.5379995989999999
0.6874250009999999
0.5887086970000001
0.252984848
0.067723107
0.220923151
0.49844191299999996
0.886522632
0.5958730729999999
DataLoader provides read operations over a Dataset. Its common parameters are batch_size (the size of each batch), shuffle (whether to shuffle the data), and num_workers (how many subprocesses to use for loading). A simple example:
dl = torch.utils.data.DataLoader(data, batch_size=5, shuffle=True, num_workers=0)
print(dl)
DataLoader returns an iterable object, so we can use an iterator to fetch data batch by batch
idata = iter(dl)
print(next(idata))
tensor([0.2073, 0.6689, 0.2330, 0.5272, 0.3025])
A common way to use a DataLoader is to iterate over it with a for loop
for i, value in enumerate(dl):
    print(i, ": ", value)
0 : tensor([0.8984, 0.6689, 0.0677, 0.5380, 0.2530])
1 : tensor([0.8300, 0.5959, 0.4984, 0.4499, 0.3025])
2 : tensor([0.5887, 0.2863, 0.2330, 0.8409, 0.8865])
3 : tensor([0.6874, 0.1504, 0.2209, 0.8487, 0.6959])
4 : tensor([0.2073, 0.5272, 0.5419, 0.3609, 0.2314])
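The same batching behavior can be checked with a synthetic stand-in for the CSV data; here TensorDataset (a ready-made Dataset shipped in torch.utils.data) replaces the file:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 25 synthetic samples stand in for the 25 rows of the CSV above
ds = TensorDataset(torch.arange(25, dtype=torch.float32))
dl = DataLoader(ds, batch_size=5, shuffle=True, num_workers=0)

print(len(dl))  # number of batches: 25 / 5 = 5
for i, (batch,) in enumerate(dl):
    print(i, ":", batch)  # each batch is a tensor of 5 shuffled values
```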
torchvision is the PyTorch library dedicated to image processing; the pip install torchvision at the end of the official PyTorch installation instructions installs this package.
torchvision.datasets
torchvision.datasets can be understood as datasets predefined by the PyTorch team; they take care of preparing many image datasets for us, ready to use directly:
MNIST
COCO
Captions
Detection
LSUN
ImageFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour
An example of direct use:
import torchvision.datasets as datasets
train_set = datasets.MNIST(
    root='./data',     # directory where the data is stored
    train=True,        # load the training split; False loads the test split
    download=True,     # download the dataset automatically if it is not present
    transform=None,    # preprocessing (e.g. normalization); None means no preprocessing
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!
torchvision.models
torchvision provides not only common image datasets but also pretrained models that can be loaded and used directly, or serve as a starting point for transfer learning. The torchvision.models module contains the following model architectures:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\Lowry/.cache\torch\checkpoints\resnet18-5c106cde.pth
100%|█████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:38<00:00, 1.22MB/s]
torchvision.transforms
The transforms module provides common image-transformation classes, used for data preprocessing and data augmentation
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),         # pad 4 pixels of zeros on each side, then randomly crop back to 32 x 32
    transforms.RandomHorizontalFlip(),            # flip the image horizontally with probability 0.5
    transforms.RandomRotation((-45, 45)),         # random rotation within [-45, 45] degrees
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)),  # per-channel (R, G, B) mean and std used for normalization
])