不用自己写划分数据集的函数,pytorch
已经给我们封装好了,那就是torch.utils.data.random_split()
。
torch.utils.data.random_split(dataset, lengths, generator=
)
随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。
dataset
(Dataset) – 要划分的数据集。lengths
(sequence) – 要划分的长度。generator
(Generator) – 用于随机排列的生成器。代码:
import torch
from torch.utils.data import random_split
dataset = range(10)
train_dataset, test_dataset = random_split(
dataset=dataset,
lengths=[7, 3],
generator=torch.Generator().manual_seed(0)
)
print(list(train_dataset))
print(list(test_dataset))
输出:
[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]
torch.Generator().manual_seed(0)
和torch.manual_seed(0)
的效果相同,我们验证一下。
代码:
import torch
from torch.utils.data import random_split
dataset = range(10)
torch.manual_seed(0)
train_dataset, test_dataset = random_split(
dataset=dataset,
lengths=[7, 3]
)
print(list(train_dataset))
print(list(test_dataset))
输出:
[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]
https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split