PyTorch中的torch.utils.data.sampler
模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler
类和函数的介绍:
Sampler
基类: Sampler
是一个抽象类,它定义了一个__iter__
方法,返回一个迭代器,用于生成数据集中的样本索引。RandomSampler
: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。SequentialSampler
: 顺序采样器,它会按照数据集中的顺序,依次选择样本。SubsetRandomSampler
: 子集随机采样器,它会从数据集的指定子集中随机选择样本。可以用于将数据集分成训练集和验证集等子集。WeightedRandomSampler
: 加权随机采样器,它会根据指定的样本权重,进行随机采样。可以用于处理类别不平衡的问题。BatchSampler
: 批次采样器,它会将样本索引分成多个批次,每个批次包含指定数量的样本索引。这些Sampler
类可以通过在DataLoader
的构造函数中指定来使用。例如,可以使用RandomSampler
来实现随机采样,使用SubsetRandomSampler
来实现将数据集分成训练集和验证集。此外,还可以使用函数如WeightedRandomSampler
来实现加权随机采样。
下面是使用上述Sampler
类和函数的示例代码:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler
# 创建一个数据集
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
# 创建一个使用RandomSampler的DataLoader
random_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))
# 创建一个使用SequentialSampler的DataLoader
seq_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))
# 创建一个使用SubsetRandomSampler的DataLoader
train_indices = [0, 1, 2, 3, 4]
val_indices = [5, 6, 7, 8, 9]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)
# 创建一个使用WeightedRandomSampler的DataLoader
weights = [0.1, 0.9]
weighted_sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
weighted_loader = DataLoader(dataset, batch_size=2, sampler=weighted_sampler)
# 使用BatchSampler将样本索引分成多个批次
batch_sampler = torch.utils.data.sampler.BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
batch_loader = DataLoader(dataset, batch_sampler=batch_sampler)
# 遍历DataLoader,输出每个批次的数据
for data, label in random_loader:
print(data, label)
for data, label in seq_loader:
print(data, label)
for data, label in train_loader:
print(data, label)
for data, label in val_loader:
print(data, label)
for data, label in weighted_loader:
print(data, label)
for batch_indices in batch_sampler:
batch_data = [dataset[idx] for idx in batch_indices]
print(batch_data)
在这个示例中,我们首先创建了一个包含10个样本的TensorDataset
。然后,我们创建了5个不同的DataLoader
,每个DataLoader
使用不同的采样器(RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler)来从数据集中选择样本。最后,我们遍历这些DataLoader
,输出每个批次的数据。
可以通过继承Sampler
基类来自定义采样函数。自定义采样函数需要实现__iter__
方法和__len__
方法。
__iter__
方法需要返回一个迭代器,迭代器的每个元素都是数据集中的一个样本的索引。在这个方法中,可以自定义样本索引的选取方式,例如根据某种规则筛选样本或者将数据集分成多个子集。
__len__
方法需要返回采样器的样本数量。如果采样器使用的是数据集的全部样本,则返回数据集的长度。
下面是一个自定义采样器的示例代码:
import torch
from torch.utils.data.sampler import Sampler
class CustomSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
# 在初始化方法中,可以根据需要对数据集进行处理
def __iter__(self):
# 在这个方法中,可以自定义样本索引的选取方式
# 这里的示例是随机选取样本
indices = torch.randperm(len(self.data_source)).tolist()
return iter(indices)
def __len__(self):
# 在这个方法中,需要返回采样器的样本数量
# 这里的示例是采样器的样本数量等于数据集的长度
return len(self.data_source)
在这个示例中,我们定义了一个名为CustomSampler
的采样器类,它继承自Sampler
基类。在初始化方法中,我们保存了数据集,并可以根据需要对数据集进行处理。在__iter__
方法中,我们自定义了样本索引的选取方式,这里的示例是随机选取样本。在__len__
方法中,我们返回了采样器的样本数量,这里的示例是采样器的样本数量等于数据集的长度。
使用自定义采样器时,只需要将它传入DataLoader
的构造函数即可:
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
custom_sampler = CustomSampler(dataset)
loader = DataLoader(dataset, batch_size=2, sampler=custom_sampler)
在这个示例中,我们首先创建了一个包含10个样本的TensorDataset
。然后,我们使用CustomSampler
创建了一个采样器,并将它传入DataLoader
的构造函数。最后,我们遍历这个DataLoader
,输出每个批次的数据。