Pytorch-怎么让两个dataloader打乱顺序相同,自定义一个sampler

思路:自定义一个sampler,将采样方式传进不同的dataloader,则取出的数据一致
代码:

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class MyDataset_v1(Dataset):
    """Toy integer dataset (four items) used to demonstrate a shared sampler."""

    def __init__(self):
        # Fixed contents so the printed batches are easy to eyeball.
        self.data = [1, 2, 3, 4]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class MyDataset_v2(Dataset):
    """Toy float dataset, item-for-item parallel to MyDataset_v1.

    Element i here corresponds to element i of MyDataset_v1, so sampling
    both datasets with the same index order yields matching pairs.
    """

    def __init__(self):
        self.data = [1.1, 2.2, 3.3, 4.4]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

class SamplerDef(object):
    """Sampler yielding a fixed, pre-computed index order.

    Passing the *same* instance to several DataLoaders makes all of them
    draw items in the identical (shuffled) order.

    Args:
        data_source: the Dataset being sampled (kept for reference only).
        indices: iterable of integer indices defining the sampling order.
    """

    def __init__(self, data_source, indices):
        self.data_source = data_source
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        # Report the number of indices actually yielded, not the dataset
        # size: the two can differ, and DataLoader derives its number of
        # samples (and hence batch count) from the sampler's __len__.
        return len(self.indices)

if __name__ == "__main__":
    myDataset1 = MyDataset_v1()
    myDataset2 = MyDataset_v2()
    dataloader = {}

    # One shared permutation: both loaders consume the same sampler
    # instance, so they visit corresponding items in the same order.
    n = len(myDataset1)
    indices = torch.randperm(n)
    mySampler = SamplerDef(data_source=myDataset1, indices=indices)

    # shuffle must stay False when a sampler is supplied (DataLoader
    # raises otherwise); the sampler alone controls the ordering.
    dataloader['v1'] = DataLoader(dataset=myDataset1, batch_size=2, shuffle=False, pin_memory=True, sampler=mySampler)
    dataloader['v2'] = DataLoader(dataset=myDataset2, batch_size=2, shuffle=False, pin_memory=True, sampler=mySampler)

    epoch = 2
    for i in range(epoch):
        # NOTE: indices were generated once, so every epoch replays the
        # same order; regenerate indices (and the sampler) per epoch if
        # a fresh shuffle is wanted.
        for batch_ind, (d1, d2) in enumerate(zip(dataloader['v1'], dataloader['v2'])):
            print("Epoch: {} Batch_ind: {} data in Dataset1: {} data in Dataset2: {}".format(i, batch_ind, d1, d2))

结果:
Epoch: 0 Batch_ind: 0 data in Dataset1: tensor([3, 1]) data in Dataset2: tensor([3.3000, 1.1000], dtype=torch.float64)
Epoch: 0 Batch_ind: 1 data in Dataset1: tensor([4, 2]) data in Dataset2: tensor([4.4000, 2.2000], dtype=torch.float64)
Epoch: 1 Batch_ind: 0 data in Dataset1: tensor([3, 1]) data in Dataset2: tensor([3.3000, 1.1000], dtype=torch.float64)
Epoch: 1 Batch_ind: 1 data in Dataset1: tensor([4, 2]) data in Dataset2: tensor([4.4000, 2.2000], dtype=torch.float64)
分析:
打乱后,每个 epoch 的采样顺序都是一致的,因为 indices 只在训练开始前生成了一次。若希望每个 epoch 重新打乱、同时让两个 dataloader 仍然保持一致,只需在每个 epoch 开始前重新调用 torch.randperm 生成新的 indices 并据此重建 sampler。请结合上一篇使用。

你可能感兴趣的:(Pytorch,python,深度学习)