LONG-TAILED RECOGNITION 精读

BackGround

解决类别不平衡问题一般的思路:

  1. re-sample the data 重采样
  2. design specific loss functions that better facilitate learning with imbalanced data 设计针对不平衡数据的损失函数
  3. enhance recognition performance of the tail classes by transferring knowledge from the head classes 知识迁移
    • 解释一下:通过从头部类别中转移知识来提高尾部类别的识别性能。在机器学习和模式识别领域,通常存在一些类别的样本数量较少,被称为尾部类别;而另一些类别的样本数量较多,被称为头部类别。由于尾部类别样本数量的不足,其识别性能往往较低。因此,通过从头部类别中学习到的知识,可以帮助提高尾部类别的识别性能。这种方法通常被称为知识迁移。

下面直奔主题:decouple the learning procedure into representation learning and classification
即,解耦成两个部分:表征学习和分类器。

  • 表征学习部分:the model is exposed to the training instances and trained through different sampling strategies or losses.
  • 分类器部分:upon the learned representations, the model recognizes the long-tailed classes through various classifiers

采样策略

  1. standard instance-based sampling
    标准基于实例的采样(standard instance-based sampling)是一种常用的采样策略,用于解决类别不平衡问题。在类别不平衡的数据集中,某些类别的样本数量远远多于其他类别,这可能导致模型对于少数类别的识别性能较差。

    标准基于实例的采样通过对数据集中的样本进行重采样,以平衡不同类别的样本数量。具体而言,该采样策略会从多数类别中随机选择一定数量的样本,使其数量与少数类别相当。这样可以避免模型过于偏向多数类别,提高对少数类别的识别能力。

    标准基于实例的采样通常有两种方式:欠采样(undersampling)和过采样(oversampling)。

    • 欠采样:从多数类别中删除一些样本,使其数量与少数类别相当。这种方法可能会导致信息丢失,因为删除了一些多数类别的样本。但它可以减少多数类别的影响,使得模型更关注少数类别。
    • 过采样:通过复制或生成新的样本,使得多数类别的样本数量与少数类别相当。这种方法可能会导致过拟合,因为复制或生成的样本可能过于接近多数类别的分布。但它可以增加少数类别的样本数量,提高模型对少数类别的学习能力。
  2. class-balanced sampling
    类别平衡采样的目标是通过调整样本的采样权重,使得每个类别在采样过程中都有相似的机会被选中。这样可以避免模型过于偏向多数类别,提高对少数类别的学习能力。

  3. 以上两种方法混合

Class-balanced Losses

  1. Focal loss
  2. Meta-Weight-Net
  3. re-weighted training
  4. based on Bayesian uncertainty
  5. to balance the classification regions of head and tail classes using an affinity measure to enforce cluster centers of classes to be uniformly spaced and equidistant

Transfer learning from head- to tail classes

transferring features learned from head classes with abundant
training instances to under-represented tail classes

  1. transferring the intra-class variance
  2. transferring semantic deep features
  • However it is usually a non-trivial task to design specific modules (e.g. external memory) for feature transfer
  1. low-shot recognition,the setup for long-tail recognition assumes access to both head and tail classes and a more continuous decrease in in class labels
  2. adopt re-balancing schedules that learn representation and classifier jointly within a two-stage training scheme
  3. OLTR:uses instance-balanced sampling to first learn
    representations that are fine-tuned in a second stage with class-balanced sampling together with a
    memory module
  4. LDAM:a label-distribution-aware margin loss that
    expands the decision boundaries of few-shot classes.

分类器

  1. re-training the parametric linear classifier(预训练的线性分类器) in a class-balancing manner (i.e., re-sampling)
  2. non-parametric nearest class mean classifier(KNN), which classifies the data based on their closest class-specific mean representations from the training set
  3. normalizing the classifier weights(分类器权重进行归一化,加入温度参数), which adjusts the weight magnitude directly to be more balanced, adding a temperature to modulate the normalization procedure

作者的实验结论

  1. instance-balanced sampling learns the best and most generalizable representations.
  2. 重新调整由联合学习分类器指定的决策边界是有优势的,这可以通过表示学习期间重新训练分类器使用类别平衡采样,或者通过简单而有效的分类器权重归一化来实现。我们的实验结果表明,这两种方法都可以实现决策边界的调整,并且它们都只有一个超参数来控制“温度”,而且不需要额外的训练。
  3. By applying the decoupled learning scheme to standard networks (e.g., ResNeXt), we
    achieve significantly higher accuracy than well established state-of-the-art methods (different sampling strategies, new loss designs and other complex modules) on multiple longtailed recognition benchmark datasets, including ImageNet-LT, Places-LT, and iNaturalist

Start

采样策略上的更新:

  • 均方根采样
    在这里插入图片描述

  • 调节采样
    LONG-TAILED RECOGNITION 精读_第1张图片

  • Loss re-weighting strategies


  • Classifier Re-training (cRT)

  • Nearest Class Mean classifier (NCM)

  • τ -normalized classifier (τ -normalized)
    LONG-TAILED RECOGNITION 精读_第2张图片

  • Learnable weight scaling (LWS)
    LONG-TAILED RECOGNITION 精读_第3张图片


Extra Feature

  1. ResNet/ + modulatedatt -> x , featuremap
  • ResNet 残差网络就不提了
  • ResNext 则使用了类似的残差块,但在残差块内部引入了分组卷积(grouped convolution)的概念,通过增加卷积层的宽度来增加模型的表示能力。
  • ModulatedAttLayer

这个layer是一个模块化的注意力层,用于提取输入特征图的重要信息。它通过计算输入特征图的不同部分之间的关联性,以及对特征图进行空间注意力调整,来增强特征图中的有用信息。

  • 具体来说,该层包括三个卷积操作(g、theta、phi)和一个卷积操作(conv_mask)。它通过计算输入特征图的三个不同表示(g、theta、phi),并使用这些表示计算特征图之间的关联性(map_t_p),然后将关联性与输入特征图中的空间注意力调整(mask)相乘,得到最终的特征图输出(final)。
  • 该层还包括一个全连接层(fc_spatial),用于计算特征图的空间注意力权重(spatial_att)。最后,该层返回特征图输出和中间结果(x、spatial_att、mask)作为输出。
  • 在这个模块化的注意力层中,卷积层的kernel_size被设置为1是为了保持特征图的空间维度不变。这是因为在这个层中,主要关注的是特征图之间的关联性和空间注意力调整,而不是特征图的空间维度变化。

  • 因此,通过将kernel_size设置为1,可以确保卷积操作只对特征图的通道维度进行操作,而不改变其空间维度。这样可以保持特征图的空间结构不变,使得后续的操作能够更好地利用特征图中的信息。

import torch
from torch import nn
from torch.nn import functional as F
import pdb

class ModulatedAttLayer(nn.Module):

    def __init__(self, in_channels, reduction = 2, mode='embedded_gaussian'):
        super(ModulatedAttLayer, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.inter_channels = in_channels // reduction
        self.mode = mode
        assert mode in ['embedded_gaussian']

        self.g = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.theta = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.phi = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size = 1)
        self.conv_mask = nn.Conv2d(self.inter_channels, self.in_channels, kernel_size = 1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc_spatial = nn.Linear(7 * 7 * self.in_channels, 7 * 7)

        self.init_weights()

    def init_weights(self):
        msra_list = [self.g, self.theta, self.phi]
        for m in msra_list:
            nn.init.kaiming_normal_(m.weight.data)
            m.bias.data.zero_()
        self.conv_mask.weight.data.zero_()

    def embedded_gaussian(self, x):
        # embedded_gaussian cal self-attention, which may not strong enough
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        map_t_p = torch.matmul(theta_x, phi_x)
        mask_t_p = F.softmax(map_t_p, dim=-1)

        map_ = torch.matmul(mask_t_p, g_x)
        map_ = map_.permute(0, 2, 1).contiguous()
        map_ = map_.view(batch_size, self.inter_channels, x.size(2), x.size(3))
        mask = self.conv_mask(map_)
        
        x_flatten = x.view(-1, 7 * 7 * self.in_channels)

        spatial_att = self.fc_spatial(x_flatten)
        spatial_att = spatial_att.softmax(dim=1)
        
        spatial_att = spatial_att.view(-1, 7, 7).unsqueeze(1)
        spatial_att = spatial_att.expand(-1, self.in_channels, -1, -1)

        final = spatial_att * mask + x

        return final, [x, spatial_att, mask]

    def forward(self, x):
        if self.mode == 'embedded_gaussian':
            output, feature_maps = self.embedded_gaussian(x)
        else:
            raise NotImplemented("The code has not been implemented.")
        return output, feature_maps

Classifier

KNNClassifier

"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""


import torch
import torch.nn as nn
import numpy as np
import pickle
from os import path

class KNNClassifier(nn.Module):
    def __init__(self, feat_dim=512, num_classes=1000, feat_type='cl2n', dist_type='l2'):
        super(KNNClassifier, self).__init__()
        assert feat_type in ['un', 'l2n', 'cl2n'], "feat_type is wrong!!!"
        assert dist_type in ['l2', 'cos'], "dist_type is wrong!!!"
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.centroids = torch.randn(num_classes, feat_dim)
        self.feat_mean = torch.randn(feat_dim)
        self.feat_type = feat_type
        self.dist_type = dist_type
        self.initialized = False
    
    def update(self, cfeats):
        mean = cfeats['mean']
        centroids = cfeats['{}cs'.format(self.feat_type)]

        mean = torch.from_numpy(mean)
        centroids = torch.from_numpy(centroids)
        self.feat_mean.copy_(mean)
        self.centroids.copy_(centroids)
        if torch.cuda.is_available():
            self.feat_mean = self.feat_mean.cuda()
            self.centroids = self.centroids.cuda()
        self.initialized = True

    def forward(self, inputs, *args):
        centroids = self.centroids
        feat_mean = self.feat_mean

        # Feature transforms
        if self.feat_type == 'cl2n':
            inputs = inputs - feat_mean
            #centroids = centroids - self.feat_mean

        if self.feat_type in ['l2n', 'cl2n']:
            norm_x = torch.norm(inputs, 2, 1, keepdim=True)
            inputs = inputs / norm_x

            #norm_c = torch.norm(centroids, 2, 1, keepdim=True)
            #centroids = centroids / norm_c
        
        # Logit calculation
        if self.dist_type == 'l2':
            logit = self.l2_similarity(inputs, centroids)
        elif self.dist_type == 'cos':
            logit = self.cos_similarity(inputs, centroids)
        
        return logit, None

    def l2_similarity(self, A, B):
        # input A: [bs, fd] (batch_size x feat_dim)
        # input B: [nC, fd] (num_classes x feat_dim)
        feat_dim = A.size(1)

        AB = torch.mm(A, B.t())
        AA = (A**2).sum(dim=1, keepdim=True)
        BB = (B**2).sum(dim=1, keepdim=True)
        dist = AA + BB.t() - 2*AB

        return -dist
    
    def cos_similarity(self, A, B):
        feat_dim = A.size(1)
        AB = torch.mm(A, B.t())
        AB = AB / feat_dim
        return AB


def create_model(feat_dim, num_classes=1000, feat_type='cl2n', dist_type='l2',
                 log_dir=None, test=False, *args):
    print('Loading KNN Classifier')
    print(feat_dim, num_classes, feat_type, dist_type, log_dir, test)
    clf = KNNClassifier(feat_dim, num_classes, feat_type, dist_type)

    if log_dir is not None:
        fname = path.join(log_dir, 'cfeats.pkl')
        if path.exists(fname):
            print('===> Loading features from %s' % fname)
            with open(fname, 'rb') as f:
                data = pickle.load(f)
            clf.update(data)
    else:
        print('Random initialized classifier weights.')
    
    return clf


if __name__ == "__main__":
    cens = np.eye(4)
    mean = np.ones(4)
    xs = np.array([
        [0.9, 0.1, 0.0, 0.0],
        [0.2, 0.1, 0.1, 0.6],
        [0.3, 0.3, 0.4, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.25, 0.25, 0.25, 0.25]
    ])
    xs = torch.Tensor(xs)

    classifier = KNNClassifier(feat_dim=4, num_classes=4, 
                               feat_type='un')
    classifier.update(mean, cens)
    import pdb; pdb.set_trace()
    logits, _ = classifier(xs)

Dot_Classifier

"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""


import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from utils import *
from os import path

class DotProduct_Classifier(nn.Module):
    
    def __init__(self, num_classes=1000, feat_dim=2048, *args):
        super(DotProduct_Classifier, self).__init__()
        # print('<DotProductClassifier> contains bias: {}'.format(bias))
        self.fc = nn.Linear(feat_dim, num_classes)
        self.scales = Parameter(torch.ones(num_classes))
        for param_name, param in self.fc.named_parameters():
            param.requires_grad = False
        
    def forward(self, x, *args):
        x = self.fc(x)
        x *= self.scales
        return x, None
    
def create_model(feat_dim, num_classes=1000, stage1_weights=False, dataset=None, log_dir=None, test=False, *args):
    print('Loading Dot Product Classifier.')
    clf = DotProduct_Classifier(num_classes, feat_dim)

    if not test:
        if stage1_weights:
            assert(dataset)
            print('Loading %s Stage 1 Classifier Weights.' % dataset)
            if log_dir is not None:
                subdir = log_dir.strip('/').split('/')[-1]
                subdir = subdir.replace('stage2', 'stage1')
                weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir)
                # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1')
            else:
                weight_dir = './logs/%s/stage1' % dataset
            print('==> Loading classifier weights from %s' % weight_dir)
            clf.fc = init_weights(model=clf.fc,
                                  weights_path=path.join(weight_dir, 'final_model_checkpoint.pth'),
                                  classifier=True)
        else:
            print('Random initialized classifier weights.')

    return clf

Loss

DiscCentroidsLoss

class DiscCentroidsLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(DiscCentroidsLoss, self).__init__()
        self.num_classes = num_classes
        self.centroids = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.disccentroidslossfunc = DiscCentroidsLossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, feat, label):
        batch_size = feat.size(0)
        
        # calculate attracting loss

        feat = feat.view(batch_size, -1)
        # To check the dim of centroids and features
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's \
                            dim: {1}".format(self.feat_dim,feat.size(1)))
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        loss_attract = self.disccentroidslossfunc(feat, label, self.centroids, batch_size_tensor).squeeze()
        
        # calculate repelling loss

        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centroids, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, feat, self.centroids.t())

        classes = torch.arange(self.num_classes).long().cuda()
        labels_expand = label.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels_expand.eq(classes.expand(batch_size, self.num_classes))

        distmat_neg = distmat
        distmat_neg[mask] = 0.0
        # margin = 50.0
        margin = 10.0
        loss_repel = torch.clamp(margin - distmat_neg.sum() / (batch_size * self.num_classes), 0.0, 1e6)

        # loss = loss_attract + 0.05 * loss_repel
        loss = loss_attract + 0.01 * loss_repel

        return loss
 
 
class DiscCentroidsLossFunc(Function):
    @staticmethod
    def forward(ctx, feature, label, centroids, batch_size):
        ctx.save_for_backward(feature, label, centroids, batch_size)
        centroids_batch = centroids.index_select(0, label.long())
        return (feature - centroids_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centroids, batch_size = ctx.saved_tensors
        centroids_batch = centroids.index_select(0, label.long())
        diff = centroids_batch - feature
        # init every iteration
        counts = centroids.new_ones(centroids.size(0))
        ones = centroids.new_ones(label.size(0))
        grad_centroids = centroids.new_zeros(centroids.size())

        counts = counts.scatter_add_(0, label.long(), ones)
        grad_centroids.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centroids = grad_centroids/counts.view(-1, 1)
        return - grad_output * diff / batch_size, None, grad_centroids / batch_size, None


    
def create_loss (feat_dim=512, num_classes=1000):
    print('Loading Discriminative Centroids Loss.')
    return DiscCentroidsLoss(num_classes, feat_dim)
import torch.nn as nn

def create_loss ():
    print('Loading Softmax Loss.')
    return nn.CrossEntropyLoss()

采样

  • 循环采样

循环遍历不同类别的样本,并按照指定数量从每个类别中取样本。它使用了两个迭代器,一个用于循环遍历不同类别的样本,另一个用于循环遍历每个类别中的样本。在每次迭代中,它从当前类别的样本迭代器中取出指定数量的样本,并返回给调用者。当遍历完所有类别的样本后,它会重新随机打乱类别顺序,以便下一轮迭代时可以重新遍历样本。

  • 样本重要性优先采样
class PriorityTree(object):
    def __init__(self, capacity, init_weights, fixed_weights=None, fixed_scale=1.0,
                 alpha=1.0):
        """
        fixed_weights: weights that wont be updated by self.update()
        """
        assert fixed_weights is None or len(fixed_weights) == capacity
        assert len(init_weights) == capacity
        self.alpha = alpha
        self._capacity = capacity
        self._tree_size = 2 * capacity - 1
        self.fixed_scale = fixed_scale
        self.fixed_weights = np.zeros(self._capacity) if fixed_weights is None \
                             else fixed_weights
        self.tree = np.zeros(self._tree_size)
        self._initialized = False
        self.initialize(init_weights)

    def initialize(self, init_weights):
        """Initialize the tree."""

        # Rescale the fixed_weights if it is not zero
        self.fixed_scale_init = self.fixed_scale
        if self.fixed_weights.sum() > 0 and init_weights.sum() > 0:
            self.fixed_scale_init *= init_weights.sum() / self.fixed_weights.sum()
            self.fixed_weights *= self.fixed_scale * init_weights.sum() \
                                / self.fixed_weights.sum()
        print('FixedWeights: {}'.format(self.fixed_weights.sum()))

        self.update_whole(init_weights + self.fixed_weights)
        self._initialized = True
    
    def reset_adaptive_weights(self, adaptive_weights):
        self.update_whole(self.fixed_weights + adaptive_weights)
    
    def reset_fixed_weights(self, fixed_weights, rescale=False):
        """ Reset the manually designed weights and 
            update the whole tree accordingly.

            @rescale: rescale the fixed_weights such that 
            fixed_weights.sum() = self.fixed_scale * adaptive_weights.sum()
        """

        adaptive_weights = self.get_adaptive_weights()
        fixed_sum = fixed_weights.sum()
        if rescale and fixed_sum > 0:
            # Rescale fixedweight based on adaptive weights
            scale = self.fixed_scale * adaptive_weights.sum() / fixed_sum
        else:
            # Rescale fixedweight based on previous fixedweight
            scale = self.fixed_weights.sum() / fixed_sum
        self.fixed_weights = fixed_weights * scale
        self.update_whole(self.fixed_weights + adaptive_weights)
    
    def update_whole(self, total_weights):
        """ Update the whole tree based on per-example sampling weights """
        if self.alpha != 1:
            total_weights = np.power(total_weights, self.alpha)
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        self.tree[lefti:righti+1] = total_weights

        # Iteratively find a parent layer
        while lefti != 0 and righti != 0:
            lefti = (lefti - 1) // 2 if lefti != 0 else 0
            righti = (righti - 1) // 2 if righti != 0 else 0
            
            # Assign paraent weights from right to left
            for i in range(righti, lefti-1, -1):
                self.tree[i] = self.tree[2*i+1] + self.tree[2*i+2]
    
    def get_adaptive_weights(self):
        """ Get the instance-aware weights, that are not mannually designed"""
        if self.alpha == 1:
            return self.get_total_weights() - self.fixed_weights
        else:
            return self.get_raw_total_weights() - self.fixed_weights
    
    def get_total_weights(self):
        """ Get the per-example sampling weights
            return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        return self.tree[lefti:righti+1]

    def get_raw_total_weights(self):
        """ Get the per-example sampling weights
            return shape: [capacity]
        """
        lefti = self.pointer_to_treeidx(0)
        righti = self.pointer_to_treeidx(self.capacity-1)
        return np.power(self.tree[lefti:righti+1], 1/self.alpha)

    @property
    def size(self):
        return self._tree_size

    @property
    def capacity(self):
        return self._capacity

    def __len__(self):
        return self.capacity

    def pointer_to_treeidx(self, pointer):
        assert pointer < self.capacity
        return int(pointer + self.capacity - 1)

    def update(self, pointer, priority):
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        priority += self.fixed_weights[pointer]
        if self.alpha != 1:
            priority = np.power(priority, self.alpha)
        delta = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta
    
    def update_delta(self, pointer, delta):
        assert pointer < self.capacity
        tree_idx = self.pointer_to_treeidx(pointer)
        ratio = 1- self.fixed_weights[pointer] / self.tree[tree_idx]
        # delta *= ratio
        if self.alpha != 1:
            # Update delta
            if self.tree[tree_idx] < 0 or \
                np.power(self.tree[tree_idx], 1/self.alpha) + delta < 0:
                import pdb; pdb.set_trace()
            delta = np.power(np.power(self.tree[tree_idx], 1/self.alpha) + delta,
                             self.alpha) \
                  - self.tree[tree_idx]
        self.tree[tree_idx] += delta
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += delta

    def get_leaf(self, value):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        assert self.total > 0, 'No priority weights setted!!'
        parent = 0
        while True:
            left_child = 2 * parent + 1
            right_child = 2 * parent + 2
            if left_child >= len(self.tree):
                tgt_leaf = parent
                break
            if value < self.tree[left_child]:
                parent = left_child
            else:
                value -= self.tree[left_child]
                parent = right_child
        data_idx = tgt_leaf - self.capacity + 1
        return data_idx, self.tree[tgt_leaf]        # data idx, priority

    @property
    def total(self):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return self.tree[0]

    @property
    def max(self):
        return np.max(self.tree[-self.capacity:])

    @property
    def min(self):
        assert self._initialized, 'PriorityTree not initialized!!!!'
        return np.min(self.tree[-self.capacity:])
    
    def get_weights(self):
        wdict = {'fixed_weights': self.fixed_weights, 
                 'total_weights': self.get_total_weights()}
        if self.alpha != 1:
            wdict.update({'raw_total_weights': self.get_raw_total_weights(),
                          'alpha': self.alpha})

        return wdict

你可能感兴趣的:(机器学习,人工智能)