PyTorch: Splitting a Dataset

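from torch.utils.data import Subset

# data_set: your existing torch.utils.data.Dataset (must implement __len__)
dataSet_length = len(data_set)

# Sequential (unshuffled) 80/20 split: the first 80% of indices go to training,
# the remaining 20% to testing. If the samples are ordered (e.g. grouped by
# class), shuffle the indices first or use random_split instead.
train_size = int(0.8 * dataSet_length)
train_set = Subset(data_set, range(train_size))
test_set = Subset(data_set, range(train_size, dataSet_length))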
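
The split above is sequential, so the test set always consists of the last 20% of samples. For a shuffled split, torch.utils.data.random_split can be used instead. The following is a minimal sketch, assuming a hypothetical toy TensorDataset stands in for data_set and an arbitrary seed of 42 for reproducibility.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset (assumption): 100 samples with 3 features each, plus labels
features = torch.randn(100, 3)
labels = torch.arange(100)
data_set = TensorDataset(features, labels)

dataSet_length = len(data_set)
train_size = int(0.8 * dataSet_length)
# Use the remainder so the two lengths always sum to the dataset size
test_size = dataSet_length - train_size

# A seeded generator makes the random permutation reproducible across runs
generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(data_set, [train_size, test_size], generator=generator)

print(len(train_set), len(test_set))  # 80 20
```

Both returned objects are Subset instances, so they can be passed to a DataLoader exactly like the sequential version. Computing test_size as the remainder (rather than int(0.2 * n)) avoids a length mismatch when n is not a multiple of 10; recent PyTorch versions also accept fractional lengths such as [0.8, 0.2].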
