pytorch中的Dataset和DataLoader

"""
    训练时,每次计算一批次的数据,然后更新一次神经网络的参数
    在代码实现中,会将数据设置为一个迭代器,每次循环给出一批次的数据
    pytorch中给出了Dataset,DataLoader两个接口,帮助我们实现数据迭代器
"""
import torch
from torch.utils.data import Dataset, DataLoader

test_list = [(1, 'dog'), (2, 'cat'), (3, 'pig'), (4, 'bird')]
BATCH_SIZE = 2

'''
Dataset 将原始数据转换为python可以识别的数据结构
1.是一个抽象类,所有自写的dataset都必须继承它
2.重写方法_len_用于返回数据的数量
3.子类必须重写方法_getitem_,用于获取数据的索引
'''


class DataSet(Dataset):
    def __init__(self, datalist):
        self.x = datalist

    def __len__(self):  # 该方法提供了dataset的大小
        return len(self.x)

    def __getitem__(self, item):  # 该方法支持从 0 到 len(self)的索引。在此方法中对数据进行处理,将数据中的整数变成tensor
        return torch.tensor(self.x[item][0]), self.x[item][1]


torch_dataset = DataSet(test_list)
for i in range(len(torch_dataset)):
    print(torch_dataset[i])

test_loader = DataLoader(
    # 从数据集中每次抽出batch size个样本
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=BATCH_SIZE,  # mini batch size
    shuffle=True,  # 要不要打乱数据 (打乱比较好)
    # num_workers=1,               # 多线程来读数据
)


def show_batch():
    for epoch in range(3):
        print(f'epoch:{epoch}')
        for step, (batch_x, batch_y) in enumerate(test_loader, start=0):
            print("step:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))


show_batch()

输出:

(tensor(1), 'dog')
(tensor(2), 'cat')
(tensor(3), 'pig')
(tensor(4), 'bird')
epoch: 0
step:0, batch_x:tensor([3, 2]), batch_y:('pig', 'cat')
step:1, batch_x:tensor([4, 1]), batch_y:('bird', 'dog')
epoch: 1
step:0, batch_x:tensor([3, 2]), batch_y:('pig', 'cat')
step:1, batch_x:tensor([1, 4]), batch_y:('dog', 'bird')
epoch: 2
step:0, batch_x:tensor([2, 4]), batch_y:('cat', 'bird')
step:1, batch_x:tensor([3, 1]), batch_y:('pig', 'dog')

Process finished with exit code 0

你可能感兴趣的:(pytorch,深度学习,python)