【pytorch】使用torch.utils.data.random_split()划分数据集

写在前面

不用自己写划分数据集的函数,pytorch已经给我们封装好了,那就是torch.utils.data.random_split()

用法详解

torch.utils.data.random_split(dataset, lengths, generator=)

描述

随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。

参数

  • dataset (Dataset) – 要划分的数据集。
  • lengths (sequence) – 要划分的长度。
  • generator (Generator) – 用于随机排列的生成器。

示例

代码:

import torch
from torch.utils.data import random_split
dataset = range(10)
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[7, 3],
    generator=torch.Generator().manual_seed(0)
)
print(list(train_dataset))
print(list(test_dataset))

输出:

[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]

torch.Generator().manual_seed(0)torch.manual_seed(0)的效果相同,我们验证一下。

代码:

import torch
from torch.utils.data import random_split
dataset = range(10)
torch.manual_seed(0)
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[7, 3]
)
print(list(train_dataset))
print(list(test_dataset))

输出:

[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]

引用参考

https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split

你可能感兴趣的:(pytorch,pytorch,划分数据集)