pytorch与jittor计图的数据集读写比较

pytorch与jittor计图的数据集读写比较

  • pytorch
  • jittor

pytorch

pytorch分为三级结构:分别是 Dataset,Dataloader,train/test_loader.
其中第一级别Dataset:是提供一种方式来获取数据和对应的label,Dataset 类用于加载和处理数据集,并将其转换为可供模型使用的格式。如果是非内置的数据集则需要我们继承Dataset类,实现自己的数据集类:

from torch.utils.data import Dataset

# 一般是继承Dataset类,然后设置这三个基本的函数。
class mydata(Dataset):

    #设置全局参数,类的构造函数,用于初始化数据集的相关参数。
    # 函数内容非固定,可自由设置
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        
        
    #获取每一个图片
    # 根据索引 index 返回数据集中的一个样本。通常该函数会对样本进行预处理,并将其转换为模型所需的格式。同时,如果有标签,一般也返回其对应的标签。
    def __getitem__(self, index):
        #获取单张图片名称
        img_name=self.img_path[index]
        #获取单张图片相对路径
        # 获取图片路径
        return img_tensor, label

    #返回数据集中的样本数量。
    def __len__(self):
        return len(self.img_path)

其中第二级别是Dataloader:DataLoader 是一个用于数据加载和批量处理的工具类。DataLoader 可以将 Dataset 中的样本按照指定的批量大小进行分组,并在迭代过程中自动加载和处理数据,以提高训练效率。

  1. dataset:要加载的数据集。
  2. batch_size:每个批次中样本的数量
  3. shuffle:是否在每个 epoch 开始前打乱数据集中的样本顺序。
  4. sampler:用于指定样本采样的策略。如果 sampler 不为 None,则 shuffle 参数将被忽略。
  5. num_workers:用于数据加载的子进程数量。可以根据系统的配置和数据集的大小来调整该参数以提高数据加载效率。
  6. drop_last:如果数据集的样本数量不能被批次大小整除,是否舍弃最后一批次中的剩余样本。
from torch.utils.data import DataLoader
test_loader=DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

第三级别为train/test_loader(名字可自拟):
DataLoader 的返回值是一个可迭代对象,每次迭代返回一个批次的数据。具体来说,每个批次返回一个长度为 batch_size 的元组,其中包含 batch_size 个样本和对应的标签。这个元组可以进一步分解为输入数据和标签。
因此train/test_loader是一个可迭代对象,用于逐个batch训练数据

jittor

jittor分为两级结构,为Dataset和train/test_loader

其中第一级别 Dataset同pytorch类似,自定义数据集也是对jittor中Dataset的继承,并编写__init__函数和__getitem__函数,以及__len__函数,作用同上。
第二级别为train/test_loader
相较于pytorch少了一个Dataloader的环节,也就是jittor中的Dataset就包含了Dataloader的功能。

你可能感兴趣的:(笔记,python)