转载文章:原文链接:https://blog.csdn.net/hxxjxw/article/details/119531239
选择最合适的num_workers值
最合适的num_works值与数据集有关
from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
transform = transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
trainset = torchvision.datasets.MNIST(
root='dataset/',
train=True, #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
transform=transform
)
print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
start = time()
for epoch in range(1, 3):
for i, data in enumerate(train_loader, 0):
pass
end = time()
print("Finish with:{} second, num_workers={}".format(end - start, num_workers))
————————————————
版权声明:本文为CSDN博主「hxxjxw」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/hxxjxw/article/details/119531239