DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
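Two of these arguments matter for what follows: num_workers, which spawns that many worker processes to load samples in parallel, and worker_init_fn, which is called once in each worker before it starts loading. To see how the workers interact with NumPy's global RNG, the dataset below returns a fresh NumPy random number alongside each item: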
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

class NThreadDataset(Dataset):
    def __init__(self):
        self.datas = np.arange(100)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, index):
        data = self.datas[index]
        # Draw from NumPy's global RNG; each worker process inherits
        # this RNG state from the main process when it is forked.
        random_data = np.random.uniform(0.0, 1.0)
        return data, random_data

if __name__ == '__main__':
    dataset = NThreadDataset()
    data_loader = DataLoader(dataset,
                             num_workers=4,
                             shuffle=True)
    for i, data in enumerate(data_loader):
        print(data)
Output:
[tensor([5]), tensor([0.8464], dtype=torch.float64)]
[tensor([92]), tensor([0.8464], dtype=torch.float64)]
[tensor([46]), tensor([0.8464], dtype=torch.float64)]
[tensor([44]), tensor([0.8464], dtype=torch.float64)]
[tensor([69]), tensor([0.9780], dtype=torch.float64)]
[tensor([60]), tensor([0.9780], dtype=torch.float64)]
[tensor([53]), tensor([0.9780], dtype=torch.float64)]
[tensor([12]), tensor([0.9780], dtype=torch.float64)]
[tensor([0]), tensor([0.6385], dtype=torch.float64)]
[tensor([33]), tensor([0.6385], dtype=torch.float64)]
[tensor([85]), tensor([0.6385], dtype=torch.float64)]
[tensor([96]), tensor([0.6385], dtype=torch.float64)]
As the output shows, all four worker processes produce the same random number for their round of items: 0.8464 for the first four, 0.9780 for the next four, and so on. Each worker is forked from the main process and inherits an identical copy of NumPy's global RNG state, so all of them draw the same sequence. The PyTorch documentation describes this behavior:
By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby, consuming a RNG state mandatorily). However, seeds for other libraries may be duplicated upon initializing workers (e.g., NumPy), causing each worker to return identical random numbers. (See the FAQ section "My data loader workers return identical random numbers".)
In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
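Following that recommendation, here is a minimal sketch of a worker_init_fn that reseeds NumPy from the per-worker PyTorch seed; the modulo is needed because NumPy seeds must fit in 32 bits (the name seed_numpy_from_torch is just an illustrative choice):

import numpy as np
import torch

def seed_numpy_from_torch(worker_id):
    # torch.initial_seed() already differs per worker (base_seed + worker_id),
    # so reusing it gives each worker a distinct NumPy seed.
    np.random.seed(torch.initial_seed() % 2**32)

The example below takes a simpler route and derives each worker's NumPy seed from a fixed base plus worker_id, which makes the per-worker sequences reproducible across runs: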
def worker_init_fn_seed(worker_id):
    # Give each worker a distinct NumPy seed derived from its id.
    seed = 10
    seed += worker_id
    np.random.seed(seed)
    print(worker_id)  # show which worker is being initialized
if __name__ == '__main__':
    dataset = NThreadDataset()
    data_loader = DataLoader(dataset,
                             num_workers=4,
                             shuffle=True,
                             worker_init_fn=worker_init_fn_seed)
    for i, data in enumerate(data_loader):
        print(data)

Output:
[tensor([94]), tensor([0.7713], dtype=torch.float64)]
[tensor([87]), tensor([0.1803], dtype=torch.float64)]
[tensor([47]), tensor([0.1542], dtype=torch.float64)]
[tensor([90]), tensor([0.7777], dtype=torch.float64)]
[tensor([66]), tensor([0.0208], dtype=torch.float64)]
[tensor([65]), tensor([0.0195], dtype=torch.float64)]
[tensor([97]), tensor([0.7400], dtype=torch.float64)]
[tensor([57]), tensor([0.2375], dtype=torch.float64)]
[tensor([59]), tensor([0.6336], dtype=torch.float64)]
[tensor([21]), tensor([0.4632], dtype=torch.float64)]
[tensor([7]), tensor([0.2633], dtype=torch.float64)]
[tensor([35]), tensor([0.8243], dtype=torch.float64)]
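With each worker seeded independently, the random numbers now differ from item to item. One caveat of the fixed base seed: the workers are re-created at the start of every epoch with the same seeds (10 through 13), so each epoch replays exactly the same NumPy sequence. Seeding from torch.initial_seed() as sketched above avoids this, since base_seed is drawn anew each time the loader is iterated.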