PyTorch: splitting data into training, validation, and test sets (train_idx, val_idx, test_idx)

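Suppose we have 10 samples in total and want to split them into training, validation, and test sets in a 5:3:2 ratio using `torch.utils.data.random_split`.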
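The idea behind such a split can be sketched in plain Python first. The helpers below (`ratio_to_lengths` and `split_indices` are hypothetical names, not part of PyTorch) turn a ratio such as 5:3:2 into integer lengths that add up to the number of samples, then shuffle the indices and cut them into consecutive chunks. PyTorch's `random_split` follows the same shuffle-then-slice approach, but it draws its permutation with `torch.randperm`, so this sketch uses Python's `random` module instead and will not reproduce PyTorch's indices for a given seed.

```python
import random
from itertools import accumulate


def ratio_to_lengths(n, ratio):
    """Convert a ratio such as (5, 3, 2) into integer lengths summing to n.

    Hypothetical helper: each part gets floor(n * r / total); the samples
    lost to rounding down are added to the first part.
    """
    total = sum(ratio)
    lengths = [n * r // total for r in ratio]
    lengths[0] += n - sum(lengths)
    return lengths


def split_indices(n, lengths, seed=None):
    """Shuffle range(n) and cut it into consecutive chunks of the given lengths."""
    if sum(lengths) != n:
        raise ValueError("lengths must add up to n")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG, leaves the global state alone
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [indices[s:e] for s, e in zip(starts, ends)]


lengths = ratio_to_lengths(10, (5, 3, 2))
print(lengths)  # [5, 3, 2]
train, val, test = split_indices(10, lengths, seed=0)
print(train, val, test)
```

The integer lengths are what `random_split` expects in the example below. Recent PyTorch releases also accept fractions that sum to 1 (for example `[0.5, 0.3, 0.2]`) and compute the lengths themselves.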

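```python
import torch
import torch.utils.data as D

# 10 samples with the values 110, 109, ..., 101
x = torch.Tensor([10 - i + 100 for i in range(10)])

# Randomly split the 10 samples into groups of 5, 3, and 2
train_idx, val_idx, test_idx = D.random_split(x, [5, 3, 2])

# random_split returns D.dataset.Subset objects, each with two attributes: dataset and indices
tmp = D.dataset.Subset  # Ctrl+click this name in an IDE to view the source of the Subset class
print(test_idx)
print(type(test_idx))

# indices, the second attribute of Subset, is a list holding the selected indices
print(train_idx.indices)
print(val_idx.indices)
print(test_idx.indices)
print(type(train_idx.indices))

# dataset, the first attribute, holds the original, unsplit data (a Tensor in this example)
print(train_idx.dataset)
print(type(train_idx.dataset))

# The split data we actually want to use: index the original tensor with each index list
print(x[train_idx.indices])
print(x[val_idx.indices])
print(x[test_idx.indices])
```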
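Indexing `x` by hand is often unnecessary: a `Subset` is itself a `Dataset`, so `len()` and indexing work on it directly (`subset[i]` returns `dataset[indices[i]]`), and it can be passed straight to a `DataLoader`. A minimal sketch, assuming the same 10-sample tensor; the names `train_set`, `train_values`, `train_loader` and the batch size of 2 are illustrative choices.

```python
import torch
import torch.utils.data as D

x = torch.Tensor([10 - i + 100 for i in range(10)])
train_set, val_set, test_set = D.random_split(x, [5, 3, 2])

# A Subset behaves like a Dataset: len() and [] go through .indices
print(len(train_set))                           # 5
print(train_set[0] == x[train_set.indices[0]])  # tensor(True)

# Gather every element of the subset into one tensor (same values as x[train_set.indices])
train_values = torch.stack([train_set[i] for i in range(len(train_set))])

# Subsets can be fed to a DataLoader directly; here only the training set is shuffled
train_loader = D.DataLoader(train_set, batch_size=2, shuffle=True)
for batch in train_loader:
    print(batch)  # 1-D tensors holding up to 2 values each
```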



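Part of the output is shown below: the index lists of the three subsets, the original tensor stored in `dataset`, and the values selected for each subset. Because the split is random, your indices (and therefore the values) will differ from run to run.

```
[4, 9, 3, 6, 0]
[7, 1, 8]
[5, 2]
tensor([110., 109., 108., 107., 106., 105., 104., 103., 102., 101.])
tensor([106., 101., 107., 104., 110.])
tensor([103., 109., 102.])
tensor([105., 108.])
```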
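To make the split reproducible, `random_split` accepts a `generator` argument; passing a seeded `torch.Generator` yields the same indices on every call. A minimal sketch (the helper name `split` and the seed 42 are arbitrary choices). Calling `torch.manual_seed` before the split also works, but a dedicated generator avoids changing the global random state used elsewhere.

```python
import torch
import torch.utils.data as D

x = torch.Tensor([10 - i + 100 for i in range(10)])


def split(seed):
    # A dedicated, seeded generator makes the split repeatable
    g = torch.Generator().manual_seed(seed)
    return D.random_split(x, [5, 3, 2], generator=g)


first = split(42)
second = split(42)

# Same seed -> same indices in every subset
for a, b in zip(first, second):
    assert a.indices == b.indices

# The three index lists are disjoint and together cover all 10 samples
all_indices = sorted(i for subset in first for i in subset.indices)
print(all_indices == list(range(10)))  # True
```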

