完成一个深度学习的项目需要的四个步骤:
准备数据将遵循ETL过程
我们从数据源提取数据,再将数据转换为理想的格式,然后将数据加载到一个合适的结构中进行查询和分析,需要引入的包如下:
import torch
import torchvision
import torchvision.transforms as transforms
torchvision包让我们可以接触以下资源:
在准备数据时,最终的目标是遵循ETL过程:
在完成上述内容时,pytorch为我们提供了两个类:
Dataset类是表示数据集的抽象类,数据加载器封装数据集并提供对底层数据的访问。
由于需要使用数据集,所以我们需要一个继承了Dataset类的新类(非必须)来实现其中的抽象方法:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
import pandas as pd
import numpy as np
class OHLC(data.Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
def __getitem__(self, index):
r = self.data.iloc[index]
label = torch.tensor(r.is_up_day, dtype=torch.long)
sample = self.normailze(torch.tensor([r.open, r.high, r.low, r.close]))
return sample, label
def __len__(self):
return len(self.data)
以torchvision为基础的Fashion-MNIST数据集类继承了Dataset类并实现了其中的抽象方法。因此,在实际操作中,我们通常使用Fashion-MNIST数据集类即可。
MNIST数据集是经修改的国家标准与技术研究所数据库,是一个著名的手写数字数据集,通常用于机器学习的图像处理系统的培训。NIST代表国家标准与技术研究所。
MNIST中的M代表modified,这是因为有一个原始的NIST数字数据集被修改为MNIST。
MNIST因数据集的使用频率而闻名,其主要有两个原因:
该数据集由70000张手写数字图像组成,图像分割如下:
这些图片最初是由美国人口普查局员工和美国高中生创作的。
Fashion MNIST顾名思义就是一个时装项目的数据集。具体而言,数据集包含以下十类时尚项目:
数据集的示例如下所示:
Fashion-MNIST数据集名称中包含MNIST的原因是创建者试图用Fashion-MNIST替换MNIST。MNIST已经得到了广泛的应用,图像识别技术也有了很大的改进,以至于人们认为数据集过于简单。这就是创建Fashion MNIST数据集的原因。Fashion-MNIST的存在是为了让向PyTorch这样的框架可以通过改变获取数据的URL来添加Fashion-NMIST数据集,可以说PyTorch的Fashion-MNIST只是扩展了MNIST数据集并覆盖了其URL。
(1)提取和转换数据
import torch
import torchvision
import torchvision.transforms as transforms
train_set = torchvision.datasets.FashionMNIST(
root='./data/FashionMNIST' # 提取
,train=True
,download=True
,transform=transforms.Compose([ # 转换
transforms.ToTensor()
])
)
其中,第一个参数root是路径,这是数据所在磁盘的位置,第二个参数train设置为true,表明将数据用作训练集,第三个参数download设置为true,表明如果数据没有出现在指定的文件路径,则进行下载,最后一个参数transform,这里传递了一个转换组合,这些转换应该在数据集元素上执行,因为我们想把图像转换成张量,所以,我们用内置的ToTensor变换。
(2)在数据加载器对象中封装数据
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=10 # 加载
)
其中,batch_size指定批处理大小。
(3)访问训练集中的数据
import matplotlib.pyplot as plt
sample = next(iter(train_set))
image, label = sample
print(image.shape)
plt.imshow(image.squeeze(), cmap='gray')
显示结果:
torch.Size([1, 28, 28])
如果在jupyter notebook中使用plt.imshow函数,出现内核挂掉(内存不足),导致图片无法显示,可加入如下代码:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
上面是处理单个数据的情况,接下来是利用数据加载器,处理批量数据:
batch = next(iter(train_loader))
print(len(batch))
print(type(batch))
images, labels =batch
print(images.shape)
print(labels.shape)
显示结果:
2
torch.Size([10, 1, 28, 28])
torch.Size([10])
接着,可以使用torchvision.utils.make_grid函数一次性画出整批图像:
grid = torchvision.utils.make_grid(images, nrow=10) # nrow指定显示在每行的图片数(这个根据batch的大小来设置)
plt.figure(figsize=(15,15))
plt.imshow(np.transpose(grid, (1,2,0)))
print('labels:',labels)
显示结果: