详细介绍torch中的from torch.utils.data.sampler相关知识

PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍:

  1. Sampler基类: Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。
  2. RandomSampler: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。
  3. SequentialSampler: 顺序采样器,它会按照数据集中的顺序,依次选择样本。
  4. SubsetRandomSampler: 子集随机采样器,它会从数据集的指定子集中随机选择样本。可以用于将数据集分成训练集和验证集等子集。
  5. WeightedRandomSampler: 加权随机采样器,它会根据指定的样本权重,进行随机采样。可以用于处理类别不平衡的问题。
  6. BatchSampler: 批次采样器,它会将样本索引分成多个批次,每个批次包含指定数量的样本索引。

这些Sampler类可以通过在DataLoader的构造函数中指定来使用。例如,可以使用RandomSampler来实现随机采样,使用SubsetRandomSampler来实现将数据集分成训练集和验证集。此外,还可以使用函数如WeightedRandomSampler来实现加权随机采样。

下面是使用上述Sampler类和函数的示例代码:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler

# 创建一个数据集
dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))

# 创建一个使用RandomSampler的DataLoader
random_loader = DataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))

# 创建一个使用SequentialSampler的DataLoader
seq_loader = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))

# 创建一个使用SubsetRandomSampler的DataLoader
train_indices = [0, 1, 2, 3, 4]
val_indices = [5, 6, 7, 8, 9]
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)

# 创建一个使用WeightedRandomSampler的DataLoader
weights = [0.1, 0.9]
weighted_sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)
weighted_loader = DataLoader(dataset, batch_size=2, sampler=weighted_sampler)

# 使用BatchSampler将样本索引分成多个批次
batch_sampler = torch.utils.data.sampler.BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
batch_loader = DataLoader(dataset, batch_sampler=batch_sampler)

# 遍历DataLoader,输出每个批次的数据
for data, label in random_loader:
    print(data, label)
    
for data, label in seq_loader:
    print(data, label)
    
for data, label in train_loader:
    print(data, label)
    
for data, label in val_loader:
    print(data, label)
    
for data, label in weighted_loader:
    print(data, label)
    
for batch_indices in batch_sampler:
    batch_data = [dataset[idx] for idx in batch_indices]
    print(batch_data)

在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们创建了5个不同的DataLoader,每个DataLoader使用不同的采样器(RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler)来从数据集中选择样本。最后,我们遍历这些DataLoader,输出每个批次的数据。

可以通过继承Sampler基类来自定义采样函数。自定义采样函数需要实现__iter__方法和__len__方法。

__iter__方法需要返回一个迭代器,迭代器的每个元素都是数据集中的一个样本的索引。在这个方法中,可以自定义样本索引的选取方式,例如根据某种规则筛选样本或者将数据集分成多个子集。

__len__方法需要返回采样器的样本数量。如果采样器使用的是数据集的全部样本,则返回数据集的长度。

下面是一个自定义采样器的示例代码:

import torch
from torch.utils.data.sampler import Sampler

class CustomSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        # 在初始化方法中,可以根据需要对数据集进行处理
    
    def __iter__(self):
        # 在这个方法中,可以自定义样本索引的选取方式
        # 这里的示例是随机选取样本
        indices = torch.randperm(len(self.data_source)).tolist()
        return iter(indices)
    
    def __len__(self):
        # 在这个方法中,需要返回采样器的样本数量
        # 这里的示例是采样器的样本数量等于数据集的长度
        return len(self.data_source)

在这个示例中,我们定义了一个名为CustomSampler的采样器类,它继承自Sampler基类。在初始化方法中,我们保存了数据集,并可以根据需要对数据集进行处理。在__iter__方法中,我们自定义了样本索引的选取方式,这里的示例是随机选取样本。在__len__方法中,我们返回了采样器的样本数量,这里的示例是采样器的样本数量等于数据集的长度。

使用自定义采样器时,只需要将它传入DataLoader的构造函数即可:

dataset = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
custom_sampler = CustomSampler(dataset)
loader = DataLoader(dataset, batch_size=2, sampler=custom_sampler)

在这个示例中,我们首先创建了一个包含10个样本的TensorDataset。然后,我们使用CustomSampler创建了一个采样器,并将它传入DataLoader的构造函数。最后,我们遍历这个DataLoader,输出每个批次的数据。

你可能感兴趣的:(pytorch,深度学习,pytorch,人工智能)