下面直奔主题:decouple the learning procedure into representation learning and classification
standard instance-based sampling
标准的基于实例的采样(standard instance-based sampling,也称 instance-balanced sampling)是最常用的默认采样策略:每个训练样本被选中的概率相同。在类别不平衡的数据集中,某些类别的样本数量远远多于其他类别,因此这种采样方式下头部类别被采到的次数也远多于尾部类别,这可能导致模型对于少数类别的识别性能较差。
class-balanced sampling
transferring features learned from head classes with abundant
training instances to under-represented tail classes
Classifier Re-training (cRT)
Nearest Class Mean classifier (NCM)
- 具体来说,该层包括三个 1×1 卷积(g、theta、phi)和一个输出卷积(conv_mask)。它先由输入特征图计算三个不同的表示(g、theta、phi),用 theta 与 phi 的矩阵乘积得到各空间位置之间的关联性(map_t_p),经 softmax 归一化后对 g 加权求和,再通过 conv_mask 得到调制特征(mask);最后将 mask 与空间注意力(spatial_att)逐元素相乘并加上输入 x,得到最终的特征图输出(final)。
- 该层还包括一个全连接层(fc_spatial),它把展平后的输入特征图映射为 7×7 的权重,并经 softmax 归一化得到空间注意力权重(spatial_att)。最后,该层返回特征图输出(final)以及中间结果列表([x, spatial_att, mask])。
import torch
from torch import nn
from torch.nn import functional as F
import pdb
class ModulatedAttLayer(nn.Module):
    """Self-attention block whose output is gated by a learned spatial attention map.

    The layer computes an embedded-Gaussian self-attention over all spatial
    positions of the input, projects the result back to ``in_channels`` with
    ``conv_mask``, multiplies it by a softmax-normalised 7x7 spatial attention
    map produced by ``fc_spatial`` and adds the input as a residual.

    Args:
        in_channels: number of channels of the input feature map.
        reduction: channel reduction factor for the g/theta/phi embeddings.
        mode: attention variant; only ``'embedded_gaussian'`` is supported.
    """

    def __init__(self, in_channels, reduction = 2, mode='embedded_gaussian'):
        super(ModulatedAttLayer, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.inter_channels = in_channels // reduction
        self.mode = mode
        assert mode in ['embedded_gaussian']
        self.g = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.theta = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.phi = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.conv_mask = nn.Conv2d(self.inter_channels, self.in_channels, kernel_size = 1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # The spatial attention head is sized for 7x7 feature maps only.
        self.fc_spatial = nn.Linear(7 * 7 * self.in_channels, 7 * 7)

    def init_weights(self):
        """Re-initialise the g/theta/phi embedding convolutions.

        NOTE(review): the loop body was missing in the source this was
        recovered from. MSRA (Kaiming) initialisation with zero bias is
        inferred from the name ``msra_list`` -- confirm against the upstream
        implementation, and whether ``__init__`` is expected to call this.
        """
        msra_list = [self.g, self.theta, self.phi]
        for m in msra_list:
            nn.init.kaiming_normal_(m.weight)
            nn.init.zeros_(m.bias)

    def embedded_gaussian(self, x):
        """Apply embedded-Gaussian self-attention modulated by spatial attention.

        Args:
            x: tensor of shape ``(batch, in_channels, 7, 7)``.

        Returns:
            ``(final, [x, spatial_att, mask])`` where ``final`` has the shape
            of ``x``.

        Raises:
            ValueError: if ``x`` is not a ``(batch, in_channels, 7, 7)`` tensor.
        """
        if x.dim() != 4 or tuple(x.shape[1:]) != (self.in_channels, 7, 7):
            # Without this check the flatten below can silently mix samples.
            raise ValueError(
                "ModulatedAttLayer expects input of shape (N, {}, 7, 7), got {}".format(
                    self.in_channels, tuple(x.shape)))

        # embedded_gaussian cal self-attention, which may not strong enough
        batch_size = x.size(0)

        # (batch, positions, inter_channels)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        # (batch, inter_channels, positions)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # Pairwise affinity between positions, normalised over the last axis.
        map_t_p = torch.matmul(theta_x, phi_x)
        mask_t_p = F.softmax(map_t_p, dim=-1)

        # Attention-weighted sum of g, reshaped back to a feature map.
        map_ = torch.matmul(mask_t_p, g_x)
        map_ = map_.permute(0, 2, 1).contiguous()
        map_ = map_.view(batch_size, self.inter_channels, x.size(2), x.size(3))
        mask = self.conv_mask(map_)

        # One softmax-normalised 7x7 map per sample, shared across channels.
        x_flatten = x.reshape(batch_size, 7 * 7 * self.in_channels)
        spatial_att = self.fc_spatial(x_flatten)
        spatial_att = spatial_att.softmax(dim=1)
        spatial_att = spatial_att.view(-1, 7, 7).unsqueeze(1)
        spatial_att = spatial_att.expand(-1, self.in_channels, -1, -1)

        final = spatial_att * mask + x
        return final, [x, spatial_att, mask]

    def forward(self, x):
        """Run the attention variant selected by ``self.mode``.

        Raises:
            NotImplementedError: if ``self.mode`` is not a supported variant.
        """
        if self.mode == 'embedded_gaussian':
            output, feature_maps = self.embedded_gaussian(x)
        else:
            raise NotImplementedError("The code has not been implemented.")
        return output, feature_maps
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import numpy as np
import pickle
from os import path
class KNNClassifier(nn.Module):
    """Nearest-class-mean classifier over fixed per-class centroids.

    Logits are similarities between (optionally centred and L2-normalised)
    input features and the stored class centroids.

    Args:
        feat_dim: feature dimensionality.
        num_classes: number of class centroids.
        feat_type: ``'un'`` (raw), ``'l2n'`` (L2-normalised) or ``'cl2n'``
            (mean-centred, then L2-normalised).
        dist_type: ``'l2'`` (negative squared Euclidean distance) or ``'cos'``
            (dot product divided by ``feat_dim``).
    """

    def __init__(self, feat_dim=512, num_classes=1000, feat_type='cl2n', dist_type='l2'):
        super(KNNClassifier, self).__init__()
        assert feat_type in ['un', 'l2n', 'cl2n'], "feat_type is wrong!!!"
        assert dist_type in ['l2', 'cos'], "dist_type is wrong!!!"
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        # Random placeholders until update() installs the real statistics.
        self.centroids = torch.randn(num_classes, feat_dim)
        self.feat_mean = torch.randn(feat_dim)
        self.feat_type = feat_type
        self.dist_type = dist_type
        self.initialized = False

    def update(self, cfeats):
        """Install the feature mean and class centroids.

        Args:
            cfeats: mapping with a ``'mean'`` array and a ``'<feat_type>cs'``
                centroid array (e.g. ``'cl2ncs'``), both NumPy arrays.
        """
        mean = cfeats['mean']
        centroids = cfeats['{}cs'.format(self.feat_type)]
        mean = torch.from_numpy(mean)
        centroids = torch.from_numpy(centroids)
        # Store the loaded values; copy_ casts to the existing dtype and
        # rejects arrays whose shape does not match the placeholders.
        self.feat_mean.copy_(mean)
        self.centroids.copy_(centroids)
        if torch.cuda.is_available():
            self.feat_mean = self.feat_mean.cuda()
            self.centroids = self.centroids.cuda()
        self.initialized = True

    def forward(self, inputs, *args):
        """Return ``(logits, None)`` for a ``[batch, feat_dim]`` input."""
        # Follow the input's device so CPU inputs still work after update()
        # has moved the statistics to the GPU.
        centroids = self.centroids.to(inputs.device)
        feat_mean = self.feat_mean.to(inputs.device)

        # Feature transforms
        if self.feat_type == 'cl2n':
            inputs = inputs - feat_mean
        if self.feat_type in ['l2n', 'cl2n']:
            norm_x = torch.norm(inputs, 2, 1, keepdim=True)
            inputs = inputs / norm_x

        # Logit calculation
        if self.dist_type == 'l2':
            logit = self.l2_similarity(inputs, centroids)
        elif self.dist_type == 'cos':
            logit = self.cos_similarity(inputs, centroids)
        return logit, None

    def l2_similarity(self, A, B):
        """Negative squared Euclidean distance between rows of A and B.

        input A: [bs, fd] (batch_size x feat_dim)
        input B: [nC, fd] (num_classes x feat_dim)
        returns: [bs, nC]
        """
        AB = torch.mm(A, B.t())
        AA = (A**2).sum(dim=1, keepdim=True)
        BB = (B**2).sum(dim=1, keepdim=True)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        dist = AA + BB.t() - 2*AB
        return -dist

    def cos_similarity(self, A, B):
        """Dot product between rows of A and B, divided by the feature dim."""
        feat_dim = A.size(1)
        AB = torch.mm(A, B.t())
        AB = AB / feat_dim
        return AB
def create_model(feat_dim, num_classes=1000, feat_type='cl2n', dist_type='l2',
                 log_dir=None, test=False, *args):
    """Build a KNNClassifier, loading centroids from ``log_dir`` when available.

    Args:
        feat_dim: feature dimensionality.
        num_classes: number of classes.
        feat_type: feature transform passed to the classifier.
        dist_type: similarity type passed to the classifier.
        log_dir: directory searched for ``cfeats.pkl``; ``None`` skips loading.
        test: only echoed in the log line.

    Returns:
        The constructed classifier.
    """
    print('Loading KNN Classifier')
    print(feat_dim, num_classes, feat_type, dist_type, log_dir, test)
    clf = KNNClassifier(feat_dim, num_classes, feat_type, dist_type)

    loaded = False
    if log_dir is not None:
        fname = path.join(log_dir, 'cfeats.pkl')
        if path.exists(fname):
            print('===> Loading features from %s' % fname)
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only point log_dir at trusted experiment output.
            with open(fname, 'rb') as f:
                data = pickle.load(f)
            # Apply the loaded statistics; without this the classifier keeps
            # its random placeholder centroids.
            clf.update(data)
            loaded = True
    if not loaded:
        print('Random initialized classifier weights.')
    return clf
if __name__ == "__main__":
    # Small smoke test: identity centroids, all-ones feature mean.
    cens = np.eye(4)
    mean = np.ones(4)
    xs = np.array([
        [0.9, 0.1, 0.0, 0.0],
        [0.2, 0.1, 0.1, 0.6],
        [0.3, 0.3, 0.4, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.25, 0.25, 0.25, 0.25]
    ])
    xs = torch.Tensor(xs)
    # NOTE(review): the constructor call was truncated in the source; any
    # further arguments (e.g. feat_type) are unknown, so defaults are used.
    classifier = KNNClassifier(feat_dim=4, num_classes=4)
    # update() takes one mapping keyed by 'mean' and '<feat_type>cs'.
    classifier.update({'mean': mean,
                       '{}cs'.format(classifier.feat_type): cens})
    logits, _ = classifier(xs)
    print(logits)
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from utils import *
from os import path
class DotProduct_Classifier(nn.Module):
    """Linear classifier with a frozen ``fc`` layer and learnable per-class scales.

    The linear layer's parameters are excluded from gradient updates; only
    the ``scales`` vector (one multiplier per class, initialised to 1) is
    trainable.
    """

    def __init__(self, num_classes=1000, feat_dim=2048, *args):
        super(DotProduct_Classifier, self).__init__()
        self.fc = nn.Linear(feat_dim, num_classes)
        self.scales = Parameter(torch.ones(num_classes))
        # Freeze weight and bias of the linear layer.
        for frozen in self.fc.parameters():
            frozen.requires_grad_(False)

    def forward(self, x, *args):
        """Return ``(scaled_logits, None)``; extra positional args are ignored."""
        logits = self.fc(x)
        logits *= self.scales
        return logits, None
def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, *args):
    """Build a DotProduct_Classifier, optionally seeded with stage-1 weights.

    Args:
        feat_dim: input feature dimensionality.
        num_classes: number of output classes.
        stage1_weights: when true (and not ``test``), load the stage-1
            classifier checkpoint into ``clf.fc``.
        dataset: dataset name, used for logging and the default weight dir.
        log_dir: current log directory; its last component with ``'stage2'``
            replaced by ``'stage1'`` is used as the weight directory. When
            ``None``, ``./logs/<dataset>/stage1`` is used instead.
        test: skip all weight loading.

    Returns:
        The constructed classifier.
    """
    print('Loading Dot Product Classifier.')
    clf = DotProduct_Classifier(num_classes, feat_dim)

    if not test:
        if stage1_weights:
            print('Loading %s Stage 1 Classifier Weights.' % dataset)
            if log_dir is not None:
                # Sibling directory of log_dir; rstrip keeps a trailing '/'
                # from yielding an empty last component.
                parent_dir, subdir = path.split(log_dir.rstrip('/'))
                subdir = subdir.replace('stage2', 'stage1')
                weight_dir = path.join(parent_dir, subdir)
            else:
                # Fall back to the conventional location only when no
                # log_dir was given (previously this always won).
                weight_dir = './logs/%s/stage1' % dataset
            print('==> Loading classifier weights from %s' % weight_dir)
            # NOTE(review): the original call was truncated after
            # weights_path; init_weights comes from utils and may need
            # further keyword arguments -- confirm against its signature.
            clf.fc = init_weights(model=clf.fc,
                                  weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'))
        else:
            print('Random initialized classifier weights.')
    return clf
class DiscCentroidsLoss(nn.Module):
    """Discriminative centroids loss: attract to own centroid, repel from others.

    Args:
        num_classes: number of learnable centroids.
        feat_dim: feature dimensionality; inputs are flattened to this size.
        size_average: divide the attracting term by the batch size when true.
    """

    def __init__(self, num_classes, feat_dim, size_average=True):
        super(DiscCentroidsLoss, self).__init__()
        self.num_classes = num_classes
        self.centroids = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.disccentroidslossfunc = DiscCentroidsLossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, feat, label):
        """Compute the loss for a batch.

        Args:
            feat: features; flattened to ``[batch, feat_dim]``.
            label: integer class index per sample, shape ``[batch]``.

        Returns:
            Scalar tensor ``loss_attract + 0.01 * loss_repel``.

        Raises:
            ValueError: if the flattened feature size is not ``feat_dim``.
        """
        batch_size = feat.size(0)

        # calculate attracting loss
        feat = feat.view(batch_size, -1)
        # To check the dim of centroids and features
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's "
                             "dim: {1}".format(self.feat_dim, feat.size(1)))
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        loss_attract = self.disccentroidslossfunc(feat, label, self.centroids, batch_size_tensor).squeeze()

        # calculate repelling loss
        # Squared distances via ||f||^2 + ||c||^2 - 2 f.c  -> [batch, num_classes]
        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
            torch.pow(self.centroids, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # Keyword beta/alpha replaces the removed positional addmm_ overload.
        distmat = torch.addmm(distmat, feat, self.centroids.t(), beta=1, alpha=-2)

        # Build the class index on the input's device instead of forcing CUDA.
        classes = torch.arange(self.num_classes, device=feat.device).long()
        labels_expand = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels_expand.eq(classes.expand(batch_size, self.num_classes))

        # Zero the distance to each sample's own centroid so only the
        # distances to the other centroids contribute.
        distmat_neg = distmat.masked_fill(mask, 0.0)
        # margin = 50.0
        margin = 10.0
        loss_repel = torch.clamp(margin - distmat_neg.sum() / (batch_size * self.num_classes), 0.0, 1e6)

        # loss = loss_attract + 0.05 * loss_repel
        loss = loss_attract + 0.01 * loss_repel
        return loss
class DiscCentroidsLossFunc(torch.autograd.Function):
    """Autograd function for the attracting term with a hand-written centroid gradient."""

    @staticmethod
    def forward(ctx, feature, label, centroids, batch_size):
        """Return half the summed squared distance to own centroids, over ``batch_size``.

        Args:
            feature: ``[batch, feat_dim]`` features.
            label: ``[batch]`` class indices.
            centroids: ``[num_classes, feat_dim]`` centroids.
            batch_size: tensor divisor (the batch size, or 1).
        """
        ctx.save_for_backward(feature, label, centroids, batch_size)
        centroids_batch = centroids.index_select(0, label.long())
        return (feature - centroids_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(feature, label, centroids, batch_size)``.

        The feature gradient is scaled by ``grad_output``. The centroid
        gradient is the per-class sum of ``centroid - feature`` divided by
        ``1 + class count`` and by ``batch_size``; it is not scaled by
        ``grad_output``.
        """
        feature, label, centroids, batch_size = ctx.saved_tensors
        centroids_batch = centroids.index_select(0, label.long())
        diff = centroids_batch - feature
        # init every iteration
        # Counts start at one, so classes absent from the batch divide by 1.
        counts = centroids.new_ones(centroids.size(0))
        ones = centroids.new_ones(label.size(0))
        grad_centroids = centroids.new_zeros(centroids.size())
        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centroids.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centroids = grad_centroids/counts.view(-1, 1)
        return - grad_output * diff / batch_size, None, grad_centroids / batch_size, None
def create_loss (feat_dim=512, num_classes=1000):
    """Announce and build a DiscCentroidsLoss for the given feature size and class count."""
    print('Loading Discriminative Centroids Loss.')
    criterion = DiscCentroidsLoss(num_classes=num_classes, feat_dim=feat_dim)
    return criterion
import torch.nn as nn
def create_loss ():
    """Announce and build a standard cross-entropy (softmax) loss module."""
    print('Loading Softmax Loss.')
    criterion = nn.CrossEntropyLoss()
    return criterion
class PriorityTree(object):
    """Array-backed sum tree over per-example sampling weights.

    The tree has ``2 * capacity - 1`` nodes; the last ``capacity`` entries are
    the leaves and every internal node stores the sum of its two children, so
    ``tree[0]`` is the total weight. Each leaf stores
    ``(adaptive_weight + fixed_weight) ** alpha``.
    """

    def __init__(self, capacity, init_weights, fixed_weights=None, fixed_scale=1.0,
                 alpha=1.0):
        """
        capacity: number of leaves (examples).
        init_weights: initial adaptive weights, length ``capacity``.
        fixed_weights: weights that wont be updated by self.update()
        fixed_scale: target ratio of the fixed-weight sum to the adaptive sum.
        alpha: exponent applied to the total weight before it is stored.

        NOTE(review): the end of this signature was truncated in the source;
        ``alpha=1.0`` and the ``initialize`` call below are reconstructed
        from how the body uses them -- confirm against the upstream class.
        """
        assert fixed_weights is None or len(fixed_weights) == capacity
        assert len(init_weights) == capacity
        self.alpha = alpha
        self._capacity = capacity
        self._tree_size = 2 * capacity - 1
        self.fixed_scale = fixed_scale
        # Private float copy: initialize() rescales this array in place and
        # must not modify (or fail on an integer) caller-owned array.
        self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \
            else np.array(fixed_weights, dtype=float)
        self.tree = np.zeros(self._tree_size)
        self._initialized = False
        self.initialize(init_weights)

    def initialize(self, init_weights):
        """Initialize the tree."""
        init_weights = np.asarray(init_weights, dtype=float)
        # Rescale the fixed_weights if it is not zero
        self.fixed_scale_init = self.fixed_scale
        if self.fixed_weights.sum() > 0 and init_weights.sum() > 0:
            self.fixed_scale_init *= init_weights.sum() / self.fixed_weights.sum()
            self.fixed_weights *= self.fixed_scale * init_weights.sum() \
                / self.fixed_weights.sum()
        print('FixedWeights: {}'.format(self.fixed_weights.sum()))
        self.update_whole(init_weights + self.fixed_weights)
        self._initialized = True

    def reset_adaptive_weights(self, adaptive_weights):
        """Replace the adaptive weights, keeping the current fixed weights."""
        self.update_whole(self.fixed_weights + adaptive_weights)

    def reset_fixed_weights(self, fixed_weights, rescale=False):
        """ Reset the manually designed weights and
            update the whole tree accordingly.

            @rescale: rescale the fixed_weights such that
            fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum()
        """
        fixed_weights = np.asarray(fixed_weights, dtype=float)
        adaptive_weights = self.get_adaptive_weights()
        fixed_sum = fixed_weights.sum()
        if rescale and fixed_sum > 0:
            # Rescale fixedweight based on adaptive weights
            scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum
        elif fixed_sum != 0:
            # Rescale fixedweight based on previous fixedweight
            scale = self.fixed_weights.sum() / fixed_sum
        else:
            # A zero sum cannot be rescaled; avoid dividing by zero.
            scale = 1.0
        self.fixed_weights = fixed_weights * scale
        self.update_whole(self.fixed_weights + adaptive_weights)

    def update_whole(self, total_weights):
        """ Update the whole tree based on per-example sampling weights """
        if self.alpha != 1:
            total_weights = np.power(total_weights, self.alpha)
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        self.tree[lefti:righti+1] = total_weights

        # Children always have larger indices than their parent, so a single
        # right-to-left pass over the internal nodes sees final child values.
        for i in range(lefti - 1, -1, -1):
            self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2]

    def get_adaptive_weights(self):
        """ Get the instance-aware weights, that are not mannually designed"""
        if self.alpha == 1:
            return self.get_total_weights() - self.fixed_weights
        return self.get_raw_total_weights() - self.fixed_weights

    def get_total_weights(self):
        """ Get the per-example sampling weights (as stored, after alpha)
            return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        return self.tree[lefti:righti+1]

    def get_raw_total_weights(self):
        """ Get the per-example sampling weights with the alpha exponent undone
            return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        return np.power(self.tree[lefti:righti+1], 1/self.alpha)

    @property
    def size(self):
        """Total number of tree nodes."""
        return self._tree_size

    @property
    def capacity(self):
        """Number of leaves."""
        return self._capacity

    def __len__(self):
        return self.capacity

    def pointer_to_treeidx(self, pointer):
        """Map a data index to the index of its leaf in ``self.tree``."""
        assert pointer < self.capacity
        return int(pointer + self.capacity - 1)

    def update(self, pointer, priority):
        """Set the adaptive weight of leaf ``pointer`` and refresh ancestor sums."""
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        priority += self.fixed_weights[pointer]
        if self.alpha != 1:
            priority = np.power(priority, self.alpha)
        delta = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta

    def update_delta(self, pointer, delta):
        """Add ``delta`` to the raw weight of leaf ``pointer`` and refresh ancestors.

        Raises:
            ValueError: if ``alpha != 1`` and the stored or resulting raw
                weight is negative (the power would be undefined).
        """
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        if self.alpha != 1:
            # Convert the raw-space delta into a delta of the stored value.
            if self.tree[tree_idx] < 0:
                raise ValueError(
                    'negative stored weight {} at pointer {}'.format(
                        self.tree[tree_idx], pointer))
            raw = np.power(self.tree[tree_idx], 1/self.alpha)
            if raw + delta < 0:
                raise ValueError(
                    'delta {} makes the weight at pointer {} negative'.format(
                        delta, pointer))
            delta = np.power(raw + delta, self.alpha) - self.tree[tree_idx]
        self.tree[tree_idx] += delta
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta

    def get_leaf(self, value):
        """Descend from the root to the leaf whose cumulative range contains ``value``.

        Returns:
            ``(data_idx, priority)`` of the selected leaf.
        """
        assert self._initialized, 'PriorityTree not initialized!!!!'
        assert self.total > 0, 'No priority weights setted!!'
        parent = 0
        while True:
            left_child = 2 * parent + 1
            right_child = 2 * parent + 2
            if left_child >= len(self.tree):
                # No children: parent is a leaf.
                tgt_leaf = parent
                break
            if value < self.tree[left_child]:
                parent = left_child
            else:
                # Skip the left subtree's mass and continue on the right.
                value -= self.tree[left_child]
                parent = right_child
        data_idx = tgt_leaf - self.capacity + 1
        return data_idx, self.tree[tgt_leaf]        # data idx, priority

    @property
    def total(self):
        """Sum of all stored leaf weights (the root value)."""
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return self.tree[0]

    @property
    def max(self):
        """Largest stored leaf weight."""
        return np.max(self.tree[-self.capacity:])

    @property
    def min(self):
        """Smallest stored leaf weight."""
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return np.min(self.tree[-self.capacity:])

    def get_weights(self):
        """Return a dict of the fixed and total weights (plus raw weights/alpha if alpha != 1)."""
        wdict = {'fixed_weights': self.fixed_weights,
                 'total_weights': self.get_total_weights()}
        if self.alpha != 1:
            wdict.update({'raw_total_weights': self.get_raw_total_weights(),
                          'alpha': self.alpha})
        return wdict