如何使用pytorch加载数据

首先打开pycharm,将环境设置成python3.7(pytorch)ps:我使用的是3.7版本,可以根据自己的选择去选择不同的版本。

在正式开始之前我们先了解一下pytorch中的数据集:

PyTorch 数据集(Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch提供了很多方法来完成数据的读取和预处理。有 DatasetTensorDatasetDataLoaderImageFolder,在本文我们将使用Dataset来进行数据的加载操作。

from torch.utils.data import Dataset

torch.utils.data 是代表这一数据的抽象类,你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__init__、__len____getitem__这个三个函数(ps:都是python中的魔法函数):

class Mydata(Dataset):
    #数据类初始化
    def __init__(self, root_dir, label_dir):
        #定义数据主文件夹地址
        self.root_dir = root_dir
        #定义标签地址
        self.label_dir = label_dir
        #地址的拼接
        self.path = os.path.join(self.root_dir, self.label_dir)
        #数据列表
        self.img_path = os.listdir(self.path)


    #得到单个数据
    def __getitem__(self, idx):
        #得到单一数据的图片名(这里是使用列表)
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        #使用PIL中的Image库打开文件
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

其中__init__():

__init__() 方法可以包含多个参数,但必须包含一个名为 self 的参数,且必须作为第一个参数。也就是说,类的构造方法最少也要有一个 self 参数,仅包含 self 参数的 __init__() 构造方法,又称为类的默认构造方法。在这里我们设置了两个参数:root_dir(为所有数据的根目录) 和 label_dir(标签目录) 

__getitem__():

凡是在类中定义了这个__getitem__ 方法,那么它的实例对象(假定为p),可以像这样

p[key] 取值,当实例对象做p[key] 运算时,会调用类中的方法__getitem__。

一般如果想使用索引访问元素时,就可以在类中定义这个方法(__getitem__(self, key) )。

__len__():

__len__():的作用是返回容器中元素的个数。

在明白以上这些知识以后,我们进行数据的读取:

首先先实例化数据集:

#创建两个实例
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)

然后进行数据的读取:

如何使用pytorch加载数据_第1张图片

ants
124

你可能感兴趣的:(pytorch,python,深度学习)