Let's start with what TensorDataset and DataLoader literally mean: a TensorDataset is a dataset that holds only tensors, while a DataLoader is a data loader; whenever you reach for a DataLoader, it usually means you need to iterate over and batch the data. TensorDataset(tensor1, tensor2) pairs tensor1 with tensor2 sample by sample: tensor1 holds the data and tensor2 holds the corresponding labels. Here is a small example:
from torch.utils.data import TensorDataset, DataLoader
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)
# Slice the dataset: rows 0, 1, 2, 3
print(train_ids[0:4])
# Iterate over the samples one by one
for x_train, y_label in train_ids:
    print(x_train, y_label)
The corresponding output:
(tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]]), tensor([44, 55, 66, 44]))
===============================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
From this output it is easy to see how each tensor sample is paired with its tensor label; that is the basic use of TensorDataset. Next, let's wrap the TensorDataset we built in a DataLoader and work with the data through it:
# dataset=train_ids is the dataset to wrap; batch_size is how many samples to draw per batch
# shuffle controls ordering: False draws samples in order, True draws them in random order
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=False)
# Note that enumerate yields two values: the batch index and the data (training data plus labels)
for i, data in enumerate(train_loader, 1):
    train_data, label = data
    print('batch:{0} train_data:{1} label: {2}'.format(i, train_data, label))
The corresponding output:
batch:1 train_data:tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [1, 2, 3]]) label: tensor([44, 55, 66, 44])
batch:2 train_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [1, 2, 3],
        [4, 5, 6]]) label: tensor([55, 66, 44, 55])
batch:3 train_data:tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]) label: tensor([66, 44, 55, 66])
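Since shuffle=False above just walks the dataset in order, it may help to see shuffle=True as well. Here is a minimal sketch of my own (not part of the original example), continuing with the train_ids built earlier; the generator argument and the seed value 0 are my additions, used only to make the shuffled order reproducible:

# Sketch: shuffled loading with a seeded generator so the order is reproducible
g = torch.Generator()
g.manual_seed(0)  # arbitrary seed, chosen only for reproducibility
shuffled_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True, generator=g)
for i, (train_data, label) in enumerate(shuffled_loader, 1):
    print('batch:{0} train_data:{1} label: {2}'.format(i, train_data, label))
# With shuffle=True, each fresh pass over the loader draws the samples in a new random order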
That covers the combined use of TensorDataset and DataLoader. Let's now look at the source of these two classes:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
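From the *tensors signature and the assert in __init__ we can see that TensorDataset is not limited to a data/label pair: it accepts any number of tensors, as long as they all have the same size along the first dimension, and __getitem__ returns one slice of each. As a quick sketch (my own illustration; the per-sample weight tensor w is made up for demonstration):

# Three tensors: data a, labels b, and a hypothetical per-sample weight w
w = torch.tensor([0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.2, 0.3])
triple_ids = TensorDataset(a, b, w)
print(triple_ids[0])  # (tensor([1, 2, 3]), tensor(44), tensor(0.1000))
# If any tensor's first dimension were not 12, the assert in __init__ would fail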
# Since this class is long, only the parameters relevant to this article are listed; see the source for the rest
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False):
        self.dataset = dataset
        self.batch_size = batch_size
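One detail worth knowing from the docstring above: batch_size drives automatic batching, and when the dataset size is not divisible by batch_size, the last batch simply comes out smaller. A quick sketch of my own, continuing with the 12-sample train_ids from earlier:

# 12 samples with batch_size=5 produce batches of 5, 5, and 2
tail_loader = DataLoader(dataset=train_ids, batch_size=5, shuffle=False)
for i, (train_data, label) in enumerate(tail_loader, 1):
    print('batch:{0} size:{1}'.format(i, train_data.size(0)))
# Prints batch:1 size:5, then batch:2 size:5, then batch:3 size:2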
Thanks for reading!