PyTorch | DataLoader with multiple workers: every worker gets the same NumPy random seed, and how to fix it

Problem description

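  • PyTorch's DataLoader loads data in worker processes (separate processes, not threads). With num_workers > 1,
    numpy.random produces the same random numbers in every worker, i.e. every worker starts from the same NumPy
    random state. Python's random module and torch are not affected, because DataLoader seeds them separately in
    each worker. The DataLoader signature (newer PyTorch releases add further parameters, e.g. persistent_workers):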
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
  • Code to reproduce the problem
import numpy as np
from torch.utils.data import DataLoader, Dataset


class NthreadDateset(Dataset):
    """Toy dataset: returns each index together with one np.random draw."""

    def __init__(self):
        self.datas = np.arange(100)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, index):
        data = self.datas[index]
        # Drawn from NumPy's global RNG inside whichever worker process loads this item
        random_data = np.random.uniform(0.0, 1.0)
        return data, random_data


if __name__ == '__main__':
    datasets = NthreadDateset()
    data_loader = DataLoader(datasets,
                             num_workers=4,
                             shuffle=True)
    for i, data in enumerate(data_loader):
        print(data)

Output

[tensor([5]), tensor([0.8464], dtype=torch.float64)]
[tensor([92]), tensor([0.8464], dtype=torch.float64)]
[tensor([46]), tensor([0.8464], dtype=torch.float64)]
[tensor([44]), tensor([0.8464], dtype=torch.float64)]
[tensor([69]), tensor([0.9780], dtype=torch.float64)]
[tensor([60]), tensor([0.9780], dtype=torch.float64)]
[tensor([53]), tensor([0.9780], dtype=torch.float64)]
[tensor([12]), tensor([0.9780], dtype=torch.float64)]
[tensor([0]), tensor([0.6385], dtype=torch.float64)]
[tensor([33]), tensor([0.6385], dtype=torch.float64)]
[tensor([85]), tensor([0.6385], dtype=torch.float64)]
[tensor([96]), tensor([0.6385], dtype=torch.float64)]

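As the output shows, the four workers produce identical random numbers: each run of four consecutive samples (one per worker) shares the same value, 0.8464, 0.9780, …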
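Why does this happen? DataLoader runs each worker in its own process. On Linux, where multiprocessing has traditionally used the fork start method, every worker begins as a copy of the main process, including NumPy's global random state. DataLoader reseeds torch and Python's random in each worker, but older releases left NumPy alone (as far as I know, PyTorch 1.9 and later also seed NumPy per worker, so the reproduction may not show the problem on a recent version). The sketch below reproduces the mechanism without PyTorch: it forks four plain multiprocessing children, each takes one draw from np.random, and all four values come back identical. It assumes a POSIX system where the "fork" start method is available; _draw and draws_from_forked_children are hypothetical helper names.

```python
import multiprocessing as mp

import numpy as np


def _draw(queue):
    # Runs in a forked child, which inherited a copy of the parent's NumPy global RNG state.
    queue.put(float(np.random.uniform(0.0, 1.0)))


def draws_from_forked_children(n_children=4, seed=None):
    """Fork n_children processes; each returns its first np.random.uniform draw."""
    ctx = mp.get_context("fork")  # POSIX only; raises ValueError where fork is unavailable
    if seed is not None:
        np.random.seed(seed)  # optional: fixes the parent state, and therefore every child's
    queue = ctx.Queue()
    procs = [ctx.Process(target=_draw, args=(queue,)) for _ in range(n_children)]
    for p in procs:
        p.start()
    values = [queue.get() for _ in procs]  # collect results before joining
    for p in procs:
        p.join()
    return values


if __name__ == "__main__":
    print(draws_from_forked_children())  # four identical numbers
```

Because every child copies the same state, the numbers match whether or not the parent was seeded; seeding only makes the shared value predictable.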

Solution

  • Official recommendation: pass a worker_init_fn that gives each worker its own seed. From the PyTorch DataLoader documentation:
By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by main process using its RNG (thereby, consuming a RNG state mandatorily). However, seeds for other libraries may be duplicated upon initializing workers (e.g., NumPy), causing each worker to return identical random numbers. (See the corresponding section in the FAQ.)

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
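The advice quoted above can be sketched as a worker_init_fn in the style of the example in PyTorch's reproducibility notes. Inside a worker, torch.initial_seed() returns the per-worker seed (base_seed + worker_id) that DataLoader has already set, so it differs across workers and, because each new iterator draws a fresh base_seed, across epochs as well. np.random.seed only accepts values below 2**32, hence the modulo. seed_worker and make_loader are illustrative names; the optional torch.Generator passed as generator= is an assumption about how you want to make base_seed itself reproducible.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # In a DataLoader worker, torch.initial_seed() == base_seed + worker_id,
    # so it already differs between workers and between epochs.
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed needs a 32-bit value
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, num_workers=4, seed=0):
    # Optional: a dedicated generator makes the per-epoch base_seed reproducible.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, num_workers=num_workers, shuffle=True,
                      worker_init_fn=seed_worker, generator=g)
```

Compared with a hard-coded seed, nothing here needs to know the worker count or the epoch number: DataLoader already derives a distinct seed for each worker in each epoch, and the function simply forwards it to NumPy and random.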
  • Implementation
    • Define a worker_init_fn so that numpy.random gets a different seed in each worker.
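# Continues the reproduction script above (same imports and NthreadDateset).
def worker_init_fn_seed(worker_id):
    # Called once in each worker process right after it starts;
    # worker_id runs from 0 to num_workers - 1, so every worker gets its own seed.
    seed = 10 + worker_id
    np.random.seed(seed)
    print(worker_id)  # shows which worker was initialized (omitted from the output below)


if __name__ == '__main__':
    datasets = NthreadDateset()
    data_loader = DataLoader(datasets,
                             num_workers=4,
                             shuffle=True,
                             worker_init_fn=worker_init_fn_seed)
    for i, data in enumerate(data_loader):
        print(data)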
  • Result: each worker now produces different random numbers
[tensor([94]), tensor([0.7713], dtype=torch.float64)]
[tensor([87]), tensor([0.1803], dtype=torch.float64)]
[tensor([47]), tensor([0.1542], dtype=torch.float64)]
[tensor([90]), tensor([0.7777], dtype=torch.float64)]
[tensor([66]), tensor([0.0208], dtype=torch.float64)]
[tensor([65]), tensor([0.0195], dtype=torch.float64)]
[tensor([97]), tensor([0.7400], dtype=torch.float64)]
[tensor([57]), tensor([0.2375], dtype=torch.float64)]
[tensor([59]), tensor([0.6336], dtype=torch.float64)]
[tensor([21]), tensor([0.4632], dtype=torch.float64)]
[tensor([7]), tensor([0.2633], dtype=torch.float64)]
[tensor([35]), tensor([0.8243], dtype=torch.float64)]
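One caveat with the fixed seed 10 + worker_id: unless persistent_workers=True, DataLoader starts fresh worker processes every time it is iterated (i.e. every epoch) and calls worker_init_fn again with the same worker_id, so every epoch replays exactly the same NumPy random numbers. The output above starts with 0.7713, which is the first value NumPy returns after np.random.seed(10). For random data augmentation that is usually not what you want. The sketch below compares two passes over a loader under that fixed seed and under a seed derived from torch.initial_seed(). RandomOnly, fixed_seed, torch_derived_seed and collect_epoch are hypothetical names; it assumes torch and NumPy are installed and that worker processes can be started (e.g. on Linux).

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomOnly(Dataset):
    """Hypothetical dataset: every item is a single np.random draw."""

    def __len__(self):
        return 8

    def __getitem__(self, index):
        return np.random.uniform(0.0, 1.0)


def fixed_seed(worker_id):
    np.random.seed(10 + worker_id)  # identical seed every epoch


def torch_derived_seed(worker_id):
    np.random.seed(torch.initial_seed() % 2**32)  # follows each epoch's new base_seed


def collect_epoch(loader):
    # batch_size defaults to 1, so each batch is a one-element float64 tensor
    return [batch.item() for batch in loader]


if __name__ == "__main__":
    for init_fn in (fixed_seed, torch_derived_seed):
        loader = DataLoader(RandomOnly(), num_workers=2, worker_init_fn=init_fn)
        same = collect_epoch(loader) == collect_epoch(loader)
        print(init_fn.__name__, "epochs identical:", same)
```

Expect the fixed-seed loader to give identical lists on both passes and the torch-derived one to give different lists, barring an astronomically unlikely seed collision.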
