PyTorch 之 Dataset 类入门学习

PyTorch 之 Dataset 类入门学习

Dataset 类简介

  • PyTorch 中的 Dataset 类是一个抽象类,用来表示数据集。通过继承 Dataset 类可以进行自定义数据集的格式、大小和其它属性,供后续使用;
    PyTorch 之 Dataset 类入门学习_第1张图片

  • 可以看到官方封装好的数据集也是直接或间接的继承自 Dataset
    PyTorch 之 Dataset 类入门学习_第2张图片

自定义数据集逻辑

  • 继承 Dataset 类;
  • 重写 init():构造函数,可自定义数据读取方法以及进行数据预处理;
  • 重写 len():返回数据集大小;
  • 重写 getitem_():索引数据集中的某一个数据

代码实现

import torch
from torch.utils.data import Dataset


# 自定义数据集继承 pytorch 内置的 Dataset 类

class GreenDataset(Dataset):
    """
      重写构造函数
    Args:
       data_tensor 数据或数据集合
       target_tensor 数据标签或数据标签集合
    """

    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 重写 len 方法: return 数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 重写 getitem 方法:基于索引,return 对应的数据及其标签,组合成 1 个元组返回

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]


def test_data_set():
    """
        自定义数据集测试
    """
    # 生成数据集和标签集 (数据元素长度=标签元素长度)

    # 10 行 3 列数据,可以理解为 10 个元素,每个元素是一维的 3个元素列表
    data_tensor = torch.randn(10, 3)

    # 对应方法 torch.randint(low, high, size)标签是 0 或 1 的 10 个元素
    # low ( int , optional ) – 要从分布中提取的最小整数。默认值:0
    # high ( int ) – 高于要从分布中提取的最高整数
    # size ( tuple ) – 定义输出张量形状的元组
    # 以下示例中 low 取默认值 0
    target_tensor = torch.randint(2, (10,))
    # 将数据封装成自定义数据集的 Dataset
    my_dataset = GreenDataset(data_tensor, target_tensor)
    # 调用方法:查看数据集大小
    print('dataset size info:', len(my_dataset))

    # 根据索引获取数据
    print('tensor_data[0]: ', my_dataset[0])
    # 打印数据集
    for i, my_dataset in enumerate(my_dataset):
        print('索引值:%s 数据:%s' % (i, my_dataset))


if __name__ == '__main__':
    test_data_set()

重点函数

  • torch.randn()
    在这里插入图片描述

  • torch.randint()

执行结果

PyTorch 之 Dataset 类入门学习_第3张图片

你可能感兴趣的:(Python,PyTorch,pytorch,学习,人工智能)