Pytorch中TensorDataset的快速使用


作者学习记录方便查询


Pytorch中,TensorDataset()可以快速构建训练所用的数据,不用使用自建的Mydataset(),如果没有熟悉适用的dataset可以使用TensorDataset()作为暂时替代。
只需要把data和label作为参数输入,就可以快速构建,之后便可以用Dataloader处理。

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) 

你可能感兴趣的:(pytorch,深度学习,python)