Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样

做一个分类任务,样本比例不均匀,最大类与最小类差距有上百倍,因此要么用分层采样,要么用pytorch的torch.utils.data下提供的方法:
WeightedRandomSampler(weights: Sequence[float], num_samples: int, replacement: bool = True, generator=None)

对不同类的样本赋予权重,然后进行权重采样:

class_counts = torch.tensor([104, 642, 784])

# Create dummy data with class imbalance 99 to 1
class_counts = torch.tensor([104, 642, 784])
numDataPoints = class_counts.sum()
data_dim = 5
bs = 170
data = torch.randn(numDataPoints, data_dim)

target = torch.cat((torch.zeros(class_counts[0], dtype=torch.long),
                    torch.ones(class_counts[1], dtype=torch.long),
                    torch.ones(class_counts[2], dtype=torch.long) * 2))

print('target train 0/1/2: {}/{}/{}'.format(
    (target == 0).sum(), (target == 1).sum(), (target == 2).sum()))

# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor(
    [(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target])

# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataset = torch.utils.data.TensorDataset(data, target)
#train_dataset = triaxial_dataset(data, target)
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler)

# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1/2: {}/{}/{}".format(
        i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))

> target train 0/1/2: 104/642/784
batch index 0, 0/1/2: 52/60/58
batch index 1, 0/1/2: 63/60/47
batch index 2, 0/1/2: 62/58/50
batch index 3, 0/1/2: 59/60/51
batch index 4, 0/1/2: 45/65/60
batch index 5, 0/1/2: 59/60/51
batch index 6, 0/1/2: 54/56/60
batch index 7, 0/1/2: 59/60/51
batch index 8, 0/1/2: 57/64/49

几个要点:首先,创建sampler对象时传进去的sampler_weight参数的长度是len(samples)即样本的个数,每一个样本都对应一个权重,传进去的是一个权重序列长度为样本数。
另外,注意权重是不同类别样本的倒数(!!!这点坑死我了,一直把样本个数赋给每个样本,导致精度特别低!!!)

参考:

  1. Some problems with WeightedRandomSampler
  2. pytorch doc - WeightedRandomSampler
  3. How to Prevent Overfitting
  4. Pytorch学习(二十七)-------- 针对不均衡数据集的重采样Resample

一个更简单的例子:

batch_size = 20
class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc.
weights = 1 / torch.Tensor(class_sample_count)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size) # 注意这里的weights应为所有样本的权重序列,其长度为所有样本长度。
trainloader = data_utils.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, sampler = sampler)

你可能感兴趣的:(pytorch)