The general line of attack on the class-imbalance problem:
Getting straight to the point: decouple the learning procedure into representation learning and classification.
That is, split training into two parts: the representation (feature) learner and the classifier.
standard instance-based sampling
Standard instance-based sampling (instance-balanced sampling) is the default strategy: every training example is drawn with equal probability, so each class is seen in proportion to how many examples it has. On a long-tailed dataset this leaves the imbalance untouched, and the model's performance on the minority (tail) classes tends to suffer.
The classical remedy is to re-sample the dataset so that the class counts are evened out, either by randomly dropping examples from the majority classes or by repeating examples from the minority classes, which keeps the model from being dominated by the head classes and improves its ability to recognize the tail classes.
These two re-sampling variants are known as undersampling and oversampling; a sketch of the oversampling route follows.
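As a concrete illustration (not taken from the repository), oversampling can be implemented with PyTorch's WeightedRandomSampler by weighting each example inversely to its class frequency; the label list below is a toy placeholder.

from collections import Counter
from torch.utils.data import WeightedRandomSampler

# Toy long-tailed label list (hypothetical); one integer label per training example.
labels = [0, 0, 0, 0, 0, 0, 1, 1, 2]

# Weight each example by 1 / n_{its class}, so every class contributes the same
# expected number of samples per epoch.
counts = Counter(labels)
weights = [1.0 / counts[y] for y in labels]

# replacement=True effectively oversamples the tail classes.
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
# The sampler would then be handed to a DataLoader:
# loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)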
class-balanced sampling
The goal of class-balanced sampling is to adjust the sampling weights so that every class has the same chance of being picked, regardless of how many examples it contains. This keeps the model from being dominated by the head classes and improves its ability to learn the tail classes.
A third option mixes the two strategies above, interpolating between instance-balanced and class-balanced sampling; see the sketch below.
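These options can be written as a single family: in the decoupling paper's formulation, class j is sampled with probability p_j = n_j^q / Σ_i n_i^q, where q = 1 recovers instance-balanced sampling, q = 0 class-balanced sampling, and q = 1/2 square-root sampling. A minimal numpy sketch of that formula:

import numpy as np

def sampling_probs(class_counts, q):
    """p_j = n_j**q / sum_i n_i**q; q=1: instance-balanced, q=0: class-balanced, q=0.5: square-root."""
    n = np.asarray(class_counts, dtype=np.float64)
    w = np.power(n, q)
    return w / w.sum()

# Hypothetical long-tailed class counts.
counts = [1000, 100, 10]
for q in (1.0, 0.5, 0.0):
    print(q, np.round(sampling_probs(counts, q), 3))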
Transferring features learned from head classes with abundant training instances to under-represented tail classes.
Building on these sampling strategies, the classifier itself can then be re-balanced:
Classifier Re-training (cRT)
Nearest Class Mean classifier (NCM)
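Briefly: cRT freezes the stage-1 backbone and re-trains only the linear classifier with class-balanced sampling, while NCM replaces the classifier with nearest-class-mean matching in feature space (the KNNClassifier code further down is of that flavor). Below is a minimal cRT-style sketch; backbone, balanced_loader, and the hyperparameters are placeholders, not the repository's actual training script.

import torch
import torch.nn as nn

def classifier_retraining(backbone, feat_dim, num_classes, balanced_loader,
                          epochs=10, device='cuda'):
    """cRT sketch: keep the learned representation fixed and re-learn only the
    classifier on class-balanced batches (balanced_loader is assumed to use
    such a sampler)."""
    backbone = backbone.to(device).eval()
    for p in backbone.parameters():
        p.requires_grad = False          # freeze the stage-1 representation

    classifier = nn.Linear(feat_dim, num_classes).to(device)
    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for images, labels in balanced_loader:
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                feats = backbone(images)  # fixed features
            loss = criterion(classifier(feats), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return classifier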
This layer is a modulated attention layer that picks out the important information in an input feature map. It measures the pairwise affinity between positions of the feature map (self-attention) and additionally applies a spatial-attention adjustment, so that the useful parts of the feature map are enhanced.
- Concretely, the layer contains three 1x1 convolutions (g, theta, phi) plus one more convolution (conv_mask). It computes three embeddings of the input (g, theta, phi), uses theta and phi to build the position-to-position affinity map (map_t_p), aggregates the g embedding with the softmax-normalized affinities, and projects the result back to the input channel count to obtain the attention term (mask); this term is then multiplied by the spatial attention and added back to the input to give the final output (final).
- The layer also contains a fully connected layer (fc_spatial) that predicts the spatial attention weights (spatial_att) from the flattened input. It returns the final feature map together with the intermediate results (x, spatial_att, mask).
The kernel_size of the convolutions is set to 1 so that the spatial dimensions of the feature map stay unchanged: the layer is concerned with the relationships between positions and with the spatial-attention adjustment, not with changing the spatial resolution.
With kernel_size=1, each convolution mixes only the channel dimension at every position, leaving the 7x7 spatial structure intact so the subsequent operations can make full use of the information in the feature map.
import torch
from torch import nn
from torch.nn import functional as F


class ModulatedAttLayer(nn.Module):
    def __init__(self, in_channels, reduction=2, mode='embedded_gaussian'):
        super(ModulatedAttLayer, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.inter_channels = in_channels // reduction
        self.mode = mode
        assert mode in ['embedded_gaussian']
        # 1x1 convolutions: channel projections that keep the spatial size unchanged.
        self.g = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1)
        self.theta = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1)
        self.conv_mask = nn.Conv2d(self.inter_channels, self.in_channels, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # Predicts a 7x7 spatial attention map from the flattened input feature map.
        self.fc_spatial = nn.Linear(7 * 7 * self.in_channels, 7 * 7)
        self.init_weights()

    def init_weights(self):
        msra_list = [self.g, self.theta, self.phi]
        for m in msra_list:
            nn.init.kaiming_normal_(m.weight.data)
            m.bias.data.zero_()
        # conv_mask starts at zero so the layer initially behaves like an identity (final = x).
        self.conv_mask.weight.data.zero_()

    def embedded_gaussian(self, x):
        # Embedded-Gaussian self-attention over the spatial positions.
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        map_t_p = torch.matmul(theta_x, phi_x)      # position-to-position affinities
        mask_t_p = F.softmax(map_t_p, dim=-1)
        map_ = torch.matmul(mask_t_p, g_x)
        map_ = map_.permute(0, 2, 1).contiguous()
        map_ = map_.view(batch_size, self.inter_channels, x.size(2), x.size(3))
        mask = self.conv_mask(map_)                 # project back to in_channels
        # Spatial attention predicted from the flattened input.
        x_flatten = x.view(-1, 7 * 7 * self.in_channels)
        spatial_att = self.fc_spatial(x_flatten)
        spatial_att = spatial_att.softmax(dim=1)
        spatial_att = spatial_att.view(-1, 7, 7).unsqueeze(1)
        spatial_att = spatial_att.expand(-1, self.in_channels, -1, -1)
        final = spatial_att * mask + x              # modulated residual connection
        return final, [x, spatial_att, mask]

    def forward(self, x):
        if self.mode == 'embedded_gaussian':
            output, feature_maps = self.embedded_gaussian(x)
        else:
            raise NotImplementedError('mode {} is not implemented.'.format(self.mode))
        return output, feature_maps
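A quick smoke test of the layer (a sketch; the 7x7 spatial size is imposed by fc_spatial above, and the 512-channel input is an assumption matching a typical ResNet feature map):

att = ModulatedAttLayer(in_channels=512)
feat = torch.randn(2, 512, 7, 7)            # [batch, channels, 7, 7]
out, (x_in, spatial_att, mask) = att(feat)
print(out.shape)                            # torch.Size([2, 512, 7, 7])
print(spatial_att.shape)                    # torch.Size([2, 512, 7, 7])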
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn as nn
import numpy as np
import pickle
from os import path
class KNNClassifier(nn.Module):
def __init__(self, feat_dim=512, num_classes=1000, feat_type='cl2n', dist_type='l2'):
super(KNNClassifier, self).__init__()
assert feat_type in ['un', 'l2n', 'cl2n'], "feat_type is wrong!!!"
assert dist_type in ['l2', 'cos'], "dist_type is wrong!!!"
self.feat_dim = feat_dim
self.num_classes = num_classes
self.centroids = torch.randn(num_classes, feat_dim)
self.feat_mean = torch.randn(feat_dim)
self.feat_type = feat_type
self.dist_type = dist_type
self.initialized = False
def update(self, cfeats):
mean = cfeats['mean']
centroids = cfeats['{}cs'.format(self.feat_type)]
mean = torch.from_numpy(mean)
centroids = torch.from_numpy(centroids)
self.feat_mean.copy_(mean)
self.centroids.copy_(centroids)
if torch.cuda.is_available():
self.feat_mean = self.feat_mean.cuda()
self.centroids = self.centroids.cuda()
self.initialized = True
def forward(self, inputs, *args):
centroids = self.centroids
feat_mean = self.feat_mean
# Feature transforms
if self.feat_type == 'cl2n':
inputs = inputs - feat_mean
#centroids = centroids - self.feat_mean
if self.feat_type in ['l2n', 'cl2n']:
norm_x = torch.norm(inputs, 2, 1, keepdim=True)
inputs = inputs / norm_x
#norm_c = torch.norm(centroids, 2, 1, keepdim=True)
#centroids = centroids / norm_c
# Logit calculation
if self.dist_type == 'l2':
logit = self.l2_similarity(inputs, centroids)
elif self.dist_type == 'cos':
logit = self.cos_similarity(inputs, centroids)
return logit, None
def l2_similarity(self, A, B):
# input A: [bs, fd] (batch_size x feat_dim)
# input B: [nC, fd] (num_classes x feat_dim)
feat_dim = A.size(1)
AB = torch.mm(A, B.t())
AA = (A**2).sum(dim=1, keepdim=True)
BB = (B**2).sum(dim=1, keepdim=True)
dist = AA + BB.t() - 2*AB
return -dist
def cos_similarity(self, A, B):
feat_dim = A.size(1)
AB = torch.mm(A, B.t())
AB = AB / feat_dim
return AB
def create_model(feat_dim, num_classes=1000, feat_type='cl2n', dist_type='l2',
log_dir=None, test=False, *args):
print('Loading KNN Classifier')
print(feat_dim, num_classes, feat_type, dist_type, log_dir, test)
clf = KNNClassifier(feat_dim, num_classes, feat_type, dist_type)
if log_dir is not None:
fname = path.join(log_dir, 'cfeats.pkl')
if path.exists(fname):
print('===> Loading features from %s' % fname)
with open(fname, 'rb') as f:
data = pickle.load(f)
clf.update(data)
else:
print('Random initialized classifier weights.')
return clf
if __name__ == "__main__":
cens = np.eye(4)
mean = np.ones(4)
xs = np.array([
[0.9, 0.1, 0.0, 0.0],
[0.2, 0.1, 0.1, 0.6],
[0.3, 0.3, 0.4, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.25, 0.25, 0.25, 0.25]
])
xs = torch.Tensor(xs)
classifier = KNNClassifier(feat_dim=4, num_classes=4,
feat_type='un')
classifier.update(mean, cens)
import pdb; pdb.set_trace()
logits, _ = classifier(xs)
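For reference, a rough sketch of how the cfeats dictionary consumed by update() (and loaded from cfeats.pkl in create_model above) could be assembled from a bank of training features. feats and labels are hypothetical arrays, and the exact convention for the centered/normalized centroids is an assumption, not the repository's documented format.

def build_cfeats(feats, labels, num_classes):
    """Assemble the statistics KNNClassifier.update() expects.
    feats: [N, feat_dim] array of training features, labels: [N] integer labels."""
    mean = feats.mean(axis=0)

    def class_centroids(f):
        return np.stack([f[labels == c].mean(axis=0) for c in range(num_classes)])

    def l2n(f):
        return f / np.linalg.norm(f, axis=1, keepdims=True)

    return {
        'mean': mean,
        'uncs': class_centroids(feats),                 # unnormalized centroids
        'l2ncs': class_centroids(l2n(feats)),           # L2-normalized (assumed convention)
        'cl2ncs': class_centroids(l2n(feats - mean)),   # centered, then L2-normalized (assumed)
    }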
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from utils import *
from os import path
class DotProduct_Classifier(nn.Module):
def __init__(self, num_classes=1000, feat_dim=2048, *args):
super(DotProduct_Classifier, self).__init__()
# print('<DotProductClassifier> contains bias: {}'.format(bias))
self.fc = nn.Linear(feat_dim, num_classes)
self.scales = Parameter(torch.ones(num_classes))
for param_name, param in self.fc.named_parameters():
param.requires_grad = False
def forward(self, x, *args):
x = self.fc(x)
x *= self.scales
return x, None
def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, *args):
print('Loading Dot Product Classifier.')
clf = DotProduct_Classifier(num_classes, feat_dim)
if not test:
if stage1_weights:
assert(dataset)
print('Loading %s Stage 1 Classifier Weights.' % dataset)
if log_dir is not None:
subdir = log_dir.strip('/').split('/')[-1]
subdir = subdir.replace('stage2', 'stage1')
weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir)
# weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1')
else:
weight_dir = './logs/%s/stage1' % dataset
print('==> Loading classifier weights from %s' % weight_dir)
clf.fc = init_weights(model=clf.fc,
weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'),
classifier=True)
else:
print('Random initialized classifier weights.')
return clf
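Only scales is trainable in this classifier; the fc weights stay frozen. This is in the spirit of learnable weight scaling (LWS), where the stage-1 classifier weights are kept and only a per-class magnitude is re-learned. A quick sanity check (sketch):

# Sketch: the fc weights are frozen; only `scales` would be optimized.
clf = DotProduct_Classifier(num_classes=10, feat_dim=32)
print([name for name, p in clf.named_parameters() if p.requires_grad])   # ['scales']

optimizer = torch.optim.SGD([p for p in clf.parameters() if p.requires_grad], lr=0.1)
logits, _ = clf(torch.randn(4, 32))
print(logits.shape)                                                      # torch.Size([4, 10])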
from torch.autograd.function import Function   # needed by DiscCentroidsLossFunc below


class DiscCentroidsLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(DiscCentroidsLoss, self).__init__()
        self.num_classes = num_classes
        # Learnable per-class centroids in feature space.
        self.centroids = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.disccentroidslossfunc = DiscCentroidsLossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, feat, label):
        batch_size = feat.size(0)

        # calculate attracting loss
        feat = feat.view(batch_size, -1)
        # Check that the feature dimension matches the centroids.
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(
                self.feat_dim, feat.size(1)))
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        loss_attract = self.disccentroidslossfunc(feat, label, self.centroids, batch_size_tensor).squeeze()

        # calculate repelling loss
        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centroids, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(feat, self.centroids.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes).long().to(feat.device)   # same device as the features
        labels_expand = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels_expand.eq(classes.expand(batch_size, self.num_classes))
        distmat_neg = distmat
        distmat_neg[mask] = 0.0

        # margin = 50.0
        margin = 10.0
        loss_repel = torch.clamp(margin - distmat_neg.sum() / (batch_size * self.num_classes), 0.0, 1e6)

        # loss = loss_attract + 0.05 * loss_repel
        loss = loss_attract + 0.01 * loss_repel
        return loss


class DiscCentroidsLossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centroids, batch_size):
        ctx.save_for_backward(feature, label, centroids, batch_size)
        centroids_batch = centroids.index_select(0, label.long())
        return (feature - centroids_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centroids, batch_size = ctx.saved_tensors
        centroids_batch = centroids.index_select(0, label.long())
        diff = centroids_batch - feature
        # Re-initialize the per-class counts every iteration.
        counts = centroids.new_ones(centroids.size(0))
        ones = centroids.new_ones(label.size(0))
        grad_centroids = centroids.new_zeros(centroids.size())
        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centroids.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centroids = grad_centroids / counts.view(-1, 1)
        return -grad_output * diff / batch_size, None, grad_centroids / batch_size, None


def create_loss(feat_dim=512, num_classes=1000):
    print('Loading Discriminative Centroids Loss.')
    return DiscCentroidsLoss(num_classes, feat_dim)
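A quick CPU shape check of this loss, with hypothetical sizes (the feature batch and labels below are random placeholders):

loss_fn = create_loss(feat_dim=8, num_classes=5)
feats = torch.randn(16, 8)
labels = torch.randint(0, 5, (16,))
loss = loss_fn(feats, labels)
loss.backward()            # gradients flow into the learnable centroids
print(float(loss))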
import torch.nn as nn


def create_loss():
    print('Loading Softmax Loss.')
    return nn.CrossEntropyLoss()
This sampler loops over the classes and draws a specified number of samples from each class in turn. It uses two kinds of iterators: one that cycles over the classes, and one per class that cycles over that class's examples. On each step it pulls the requested number of samples from the current class's iterator and yields them to the caller. Once all classes have been visited, the class order is reshuffled so that the next pass runs through the samples in a new random order. A minimal sketch of the idea is given below.
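A minimal sketch of such a class-aware sampler (illustrative only; the class name and the num_samples_cls default are my own, not the repository's):

import random
import numpy as np
from torch.utils.data import Sampler

class SimpleClassAwareSampler(Sampler):
    """Cycle over the classes, drawing num_samples_cls examples from each class in turn."""

    def __init__(self, labels, num_samples_cls=4):
        self.labels = np.asarray(labels)
        self.classes = np.unique(self.labels)
        self.num_samples_cls = num_samples_cls
        # Per-class index pools.
        self.cls_indices = {c: np.where(self.labels == c)[0].tolist() for c in self.classes}
        self.num_samples = len(labels)

    def __iter__(self):
        count = 0
        while count < self.num_samples:
            # Reshuffle the class order for every pass over the classes.
            class_order = list(self.classes)
            random.shuffle(class_order)
            for c in class_order:
                for idx in random.choices(self.cls_indices[c], k=self.num_samples_cls):
                    yield idx
                    count += 1
                    if count >= self.num_samples:
                        return

    def __len__(self):
        return self.num_samples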
import numpy as np


class PriorityTree(object):
    def __init__(self, capacity, init_weights, fixed_weights=None, fixed_scale=1.0,
                 alpha=1.0):
        """
        fixed_weights: weights that won't be updated by self.update()
        """
        assert fixed_weights is None or len(fixed_weights) == capacity
        assert len(init_weights) == capacity
        self.alpha = alpha
        self._capacity = capacity
        self._tree_size = 2 * capacity - 1
        self.fixed_scale = fixed_scale
        self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \
            else fixed_weights
        self.tree = np.zeros(self._tree_size)
        self._initialized = False
        self.initialize(init_weights)

    def initialize(self, init_weights):
        """Initialize the tree."""
        # Rescale the fixed_weights if they are not all zero
        self.fixed_scale_init = self.fixed_scale
        if self.fixed_weights.sum() > 0 and init_weights.sum() > 0:
            self.fixed_scale_init *= init_weights.sum() / self.fixed_weights.sum()
            self.fixed_weights *= self.fixed_scale * init_weights.sum() \
                / self.fixed_weights.sum()
            print('FixedWeights: {}'.format(self.fixed_weights.sum()))
        self.update_whole(init_weights + self.fixed_weights)
        self._initialized = True

    def reset_adaptive_weights(self, adaptive_weights):
        self.update_whole(self.fixed_weights + adaptive_weights)

    def reset_fixed_weights(self, fixed_weights, rescale=False):
        """Reset the manually designed weights and
        update the whole tree accordingly.
        @rescale: rescale the fixed_weights such that
        fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum()
        """
        adaptive_weights = self.get_adaptive_weights()
        fixed_sum = fixed_weights.sum()
        if rescale and fixed_sum > 0:
            # Rescale the fixed weights based on the adaptive weights
            scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum
        else:
            # Rescale the fixed weights based on the previous fixed weights
            scale = self.fixed_weights.sum() / fixed_sum
        self.fixed_weights = fixed_weights * scale
        self.update_whole(self.fixed_weights + adaptive_weights)

    def update_whole(self, total_weights):
        """Update the whole tree based on per-example sampling weights."""
        if self.alpha != 1:
            total_weights = np.power(total_weights, self.alpha)
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity - 1)
        self.tree[lefti:righti + 1] = total_weights

        # Iteratively propagate the sums to the parent layers
        while lefti != 0 and righti != 0:
            lefti = (lefti - 1) // 2 if lefti != 0 else 0
            righti = (righti - 1) // 2 if righti != 0 else 0

            # Assign parent weights from right to left
            for i in range(righti, lefti - 1, -1):
                self.tree[i] = self.tree[2 * i + 1] + self.tree[2 * i + 2]

    def get_adaptive_weights(self):
        """Get the instance-aware weights, i.e. those that are not manually designed."""
        if self.alpha == 1:
            return self.get_total_weights() - self.fixed_weights
        else:
            return self.get_raw_total_weights() - self.fixed_weights

    def get_total_weights(self):
        """Get the per-example sampling weights.
        return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity - 1)
        return self.tree[lefti:righti + 1]

    def get_raw_total_weights(self):
        """Get the per-example sampling weights before the alpha power is applied.
        return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity - 1)
        return np.power(self.tree[lefti:righti + 1], 1 / self.alpha)

    @property
    def size(self):
        return self._tree_size

    @property
    def capacity(self):
        return self._capacity

    def __len__(self):
        return self.capacity

    def pointer_to_treeidx(self, pointer):
        assert pointer < self.capacity
        return int(pointer + self.capacity - 1)

    def update(self, pointer, priority):
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        priority += self.fixed_weights[pointer]
        if self.alpha != 1:
            priority = np.power(priority, self.alpha)
        delta = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        # Propagate the change up to the root
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta

    def update_delta(self, pointer, delta):
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        ratio = 1 - self.fixed_weights[pointer] / self.tree[tree_idx]
        # delta *= ratio
        if self.alpha != 1:
            # Convert the delta into the alpha-powered space
            if self.tree[tree_idx] < 0 or \
                    np.power(self.tree[tree_idx], 1 / self.alpha) + delta < 0:
                import pdb; pdb.set_trace()
            delta = np.power(np.power(self.tree[tree_idx], 1 / self.alpha) + delta,
                             self.alpha) \
                - self.tree[tree_idx]
        self.tree[tree_idx] += delta
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta

    def get_leaf(self, value):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        assert self.total > 0, 'No priority weights set!!'
        parent = 0
        while True:
            left_child = 2 * parent + 1
            right_child = 2 * parent + 2
            if left_child >= len(self.tree):
                tgt_leaf = parent
                break
            if value < self.tree[left_child]:
                parent = left_child
            else:
                value -= self.tree[left_child]
                parent = right_child
        data_idx = tgt_leaf - self.capacity + 1
        return data_idx, self.tree[tgt_leaf]  # data idx, priority

    @property
    def total(self):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return self.tree[0]

    @property
    def max(self):
        return np.max(self.tree[-self.capacity:])

    @property
    def min(self):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return np.min(self.tree[-self.capacity:])

    def get_weights(self):
        wdict = {'fixed_weights': self.fixed_weights,
                 'total_weights': self.get_total_weights()}
        if self.alpha != 1:
            wdict.update({'raw_total_weights': self.get_raw_total_weights(),
                          'alpha': self.alpha})
        return wdict
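A small usage sketch with hypothetical weights: values drawn uniformly from [0, total) are mapped by get_leaf() to leaf indices in proportion to their priorities.

# Sample indices in proportion to their priorities via get_leaf().
weights = np.array([1.0, 2.0, 3.0, 4.0])
tree = PriorityTree(capacity=4, init_weights=weights)

draws = [tree.get_leaf(np.random.uniform(0, tree.total))[0] for _ in range(10000)]
freq = np.bincount(draws, minlength=4) / 10000.0
print(freq)   # roughly [0.1, 0.2, 0.3, 0.4]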