实现第一个神经网络之Dataset和Dataloader

构建机器学习解决方案的主要分解活动如下:
准备数据:get_data函数准备输入和输出张量。
创建学习参数:get_weights提供以随机值初始化的张量,网络通过优化这些参数来解决问题。
网络模型:simple_network函数应用线性规则为输入数据生成输出,计算时先用权重w乘以输入数据,再加上偏差b。即y=wx+b。
损失函数:loss_fn函数提供了评估模型优劣的信息。
优化器:optimize函数用于调整初始的随机权重,并帮助模型更准确地计算目标值。

1 准备数据

Pytorch提供了两种类型的数据抽象,称为张量和变量。张量类似于Numpy中的数组,它们可以在GPU上使用,并能改善性能。张量的具体概念可参考我的另一篇博客pytorch框架学习总结(一)
为深度学习算法准备数据本身就可能是一件复杂的事情,pytorch提供了很多工具类,工具类通过多线程、数据增强和批处理抽象出了如数据并行化等复杂性。本博客主要介绍两个重要的工具类:Dataset类和Dataloader类。从Kaggle网站Kaggle
上拿到Dog vs.Cats数据集,并创建可以生成pytorch张量形式的批图片的数据管道。
数据的相关处理主要保存在 data/dataset.py中。关于数据加载的相关操作,其基本原理就是使用Dataset进行数据集的封装,再使用Dataloader实现数据并行加载。

1.1 Dataset类

任何一个自定义的数据集类,都要继承自pytorch的数据集类,自定义的类必须实现两个函数:len(self)和__getitem__(self,idx)。任何和Dataset类表现类似的自定义类都应和下面的代码类似:

from torch.utils.data import Dataset
class DogsAndCatsDataset(Dataset):
    def __init__(self,):
        pass
    def __len__(self):
        pass
    def __getietm__(self,idx):
        pass

在init方法中,将进行任何需要的初始化。例如在本例中,读取索引和图片的文件名。len(self)运算负责返回数据集中的最大元素个数。getitem(self,idx)运算根据每次调用的idx返回对应元素。下面的代码实现了DogsAndCatsDataset类。

from torch.utils.data import Dataset
import numpy as np
import glob
#glob模块是最简单的模块之一,内容非常少。用它可以查找符合特定规则的文件路径名。
#跟使用windows下的文件搜索差不多。查找文件只用到三个匹配符:””, “?”, “[]”。
class DogsAndCatsDataset(Dataset):
    def __init__(self,root_dir,size=(224,225)):
        self.files=glob(root_dir)
        self.size=size
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        img=np.asarray(Image.open(self.files[idx]).resize(self.size))
        label=self.files[idx].split('/')[-2]
        return img,label

在定义了DogsAndCatsDataset类之后,可以创建一个对象并在其上进行迭代,如下代码所示:

for image,label in dogsdset
#在数据集上应用深度学习算法

在单个的数据实例上应用深度学习算法并不理想。我们需要一批数据,现在的GPU都对批数据的执行进行了性能优化。Dataloader类通过提取出大部分复杂度来帮助创建批数据

1.2 Dataloader类

Dataloader类位于Pytorch的utils类中,它将数据集对象和不同的取样器联合,如SequentialSampler和RandomSampler,并使用单线程或者多线程的迭代器,为我们提供批量图片。取样器是为算法提供数据的不同策略。下面是使用Dataloader处理Dogs vs. Cats数据集的例子。

dataloader=Dataloader(dogsdset,batch_size=32,num_workers=2)
for imgs,labels in dataloader:
    #在数据集上应用深度学习
    pass
    #imgs包含一个形状为(32,224,224,3)的张量,其中32表示批尺寸

DataLoader是一个比较重要的类,它为我们提供的常用操作有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)。

https://blog.csdn.net/zw__chen/article/details/82806900
《Pytorch深度学习》 [印度] Vishu Subramanian

你可能感兴趣的:(pytorch学习)