Compacting, Picking and Growing for Unforgetting Continual Learning: Paper and Code Walkthrough

Contents

  • Paper Translation
    • Abstract
    • 1 Introduction
      • Motivation for the Method Design
      • Method Overview
    • 2 Related Work
    • 3 The CPG approach for Continual Lifelong Learning
  • Reproducing Experiment 1
    • 1. baseline: VGG16
    • 2. CPG_cifar100_scratch_mul_1.5.sh
      • Task1
        • finetune mode
        • gradually pruning
        • Choose the checkpoint
      • Task2 (k>1)
        • finetune mode
        • gradually pruning
        • Retrain piggymask and weight
        • Retrain piggymask and weight
      • Growing

Paper Translation

Abstract

Our approach leverages deep model compression, critical weights selection, and progressive networks expansion.
By enforcing their integration in an iterative manner, we introduce an incremental learning method that is scalable to the number of sequential tasks in a continual learning process.
Our approach has the following advantages:

  1. It avoids forgetting (it learns new tasks while remembering all previous ones).
  2. It allows model expansion but keeps the model compact when handling sequential tasks.
  3. Through our compacting and picking/expansion mechanism, we show that the knowledge accumulated from previous tasks helps build a better model for a new task than training the model independently.

1 Introduction

Although a learned model can serve as a pre-trained model, fine-tuning it for a new task forces the parameters to fit the new data, which causes catastrophic forgetting of previous tasks.

  • To mitigate catastrophic forgetting, Kirkpatrick et al. and Zenke et al. studied techniques that regularize gradients or weights during training. These methods regularize the network weights in the hope of finding a solution that jointly works for the current and the previous tasks.
  • Schwarz et al. proposed a network-distillation approach for regularization, which constrains the neural weights adapted from a teacher network and applies elastic weight consolidation (EWC) [14] for incremental training.

However, because the training data of previous tasks is unavailable during learning and the network capacity is fixed (and finite), regularization approaches tend to gradually forget the learned skills.

To address the missing-data problem (i.e., the lack of training data for the original tasks), data-preserving and memory-replay techniques have been introduced. Data-preserving: directly store important data, or their latent codes, in a compact form.
Memory-replay: introduce an additional memory model, such as a GAN, that keeps the data information or distribution indirectly; the memory model can then replay the previous data. With this past information, a model can be trained whose performance largely satisfies the old tasks.
A common issue of memory replay, however, is that explicit retraining with the accumulated old information is required, which makes the working memory either large or a compromise between remembering and forgetting.

This paper introduces a method for learning a sustainable yet compact deep model that can handle an unlimited number of sequential tasks while avoiding forgetting. Since a fixed architecture cannot guarantee remembering the skills incrementally learned from unlimited tasks, our approach allows the architecture to grow to some extent. At the same time, we remove model redundancy during continual learning, so that multiple tasks can be continually compacted into the model under limited expansion.
Moreover, pre-training or incrementally fine-tuning a model from a starting task carries only the prior knowledge present at initialization, so the knowledge base diminishes as past tasks accumulate.
Since humans can continually acquire, adapt, and transfer knowledge and skills throughout their lives, in lifelong learning we likewise hope that the experience accumulated from previous tasks helps with learning new ones. With our method, the increasingly learned model serves as a compact, unforgetting base that usually yields better models for subsequent tasks than training each task independently. Experimental results show that our lifelong-learning approach can exploit previously accumulated knowledge to improve performance on new tasks.

Motivation for the Method Design

ProgressiveNet keeps the existing parameters fixed and trains new parameters for each new task, which creates severe parameter redundancy, so we continually prune the model.
In the growing step of our CPG method, two options are available.

  1. Use the previously released weights for the new task. If the performance goal is still not reached when all released weights have been used, we move on to the second option.
  2. Expand the architecture; both the released and the expanded weights are then used to train the new task.

Another distinction of our method is the picking step. The motivation is as follows. In ProgressiveNet, all old-task weights are kept and used when learning a new task. However, as the number of tasks grows, the old-task weights accumulate. When they are all used together with the weights newly added in the growing step, the (fixed) old weights act like inertia, since only a small number of new weights are allowed to adapt, which tends to degrade learning. To address this, instead of using all old-task weights, we pick a set of critical weights from them through a differentiable mask.
A main difficulty in compacting the task weights is the lack of prior knowledge for choosing the pruning ratio. To handle this, the compacting step of our CPG method uses a gradual pruning procedure: a small portion of the weights is removed and the remaining ones are retrained to recover performance, iteratively, and the procedure stops when a pre-defined accuracy goal is reached. Note that only the newly added weights (the released and/or expanded weights from the growing step) are allowed to be pruned, while the old-task weights stay fixed.
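
As a concrete illustration, below is a minimal, self-contained sketch of the magnitude-based prune-to-ratio idea on a toy tensor (this is not the repository's implementation; the actual `_pruning_mask` / `gradually_prune` routines appear later in the code walkthrough):

    import torch

    def prune_to_ratio(weight: torch.Tensor, trainable: torch.Tensor, ratio: float) -> torch.Tensor:
        """Zero-mask the `ratio` fraction of smallest-magnitude trainable weights.
        `trainable` marks the weights that belong to the new task; weights owned
        by old tasks (trainable == 0) are never touched."""
        keep = trainable.clone()
        vals = weight[trainable.bool()].abs()
        k = int(ratio * vals.numel())
        if k > 0:
            cutoff = vals.kthvalue(k).values
            keep[(weight.abs() <= cutoff) & trainable.bool()] = 0
        return keep

    # Gradual pruning: raise the sparsity target in small steps; a real run
    # retrains between steps and stops once the accuracy goal would be violated.
    w = torch.randn(8, 8)
    trainable = torch.ones_like(w)
    for target in (0.1, 0.2, 0.3, 0.4, 0.5):
        keep = prune_to_ratio(w, trainable, target)
        w = w * keep              # drop the pruned weights
        # retrain(w, keep) ...    # placeholder for the recovery step
    print(float((keep == 0).float().mean()))  # -> 0.5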

Method Overview

  1. Build a compact model via pruning. Given a new task, the weights of the old-task model are kept fixed.
  2. Pick and reuse some old-task weights that are critical to the new task through a differentiable mask, and learn them together with the previously released weights.
  3. If the accuracy goal is still not reached, the architecture can be expanded by adding filters or nodes to the model and the procedure is resumed. The process is then repeated.
    [Figure 1 of the original post]

The new-task weights consist of two parts: the first is picked from the old-task weights via a learnable mask, and the second is learned by gradually pruning/retraining the extra weights. Since the old-task weights are fixed, we can integrate the required function mappings into a compact model without affecting their inference accuracy. The main characteristics of our approach are summarized as follows.

  • Avoiding forgetting: our method ensures unforgetting. When new tasks are added incrementally, the previously built function mappings are maintained exactly as they were.
  • Expansion with shrinking: our method allows expansion but keeps the architecture compact, and can potentially handle unlimited sequential tasks. Experimental results show that multiple tasks can be compacted into a model with little or no architecture growth.
  • Compact knowledge base: experimental results show that the compact model recording previous tasks serves as a knowledge base of accumulated experience for weight picking, which improves the performance of learning new tasks.

2 Related Work

Continual lifelong learning can be divided into three main categories: network regularization, memory or data replay, and dynamic architectures.
In addition, task-free settings and continual learning as program synthesis have also been studied recently. Below we briefly review the main categories and refer readers to a recent survey [28] for more studies.
Network-regularization approaches: the core idea is to restrict the updates of the learned model weights. To preserve the learned task information, changes to the weights are penalized. EWC uses Fisher information to evaluate the importance of the old-task weights and updates the weights according to their importance.
(Linked in the original post: "What is the intuitive meaning of Fisher information?")
The method in [49] is based on a similar idea and computes the importance from the learning trajectory. Online EWC [40] and EWC++ improve the efficiency of EWC. [6] introduces an information-preserving penalty: it builds attention maps and encourages the attention regions of the previous and the current model to be consistent. These works alleviate forgetting to some extent, but cannot guarantee the accuracy of earlier tasks.
Memory replay: memory or data-replay methods [32,41,13,3,46,45,11,34,33,27] use additional models to remember data information. Generative replay [41] introduces GANs into lifelong learning: a generator samples fake data whose distribution resembles that of the previous data, and the generated data can be used when training new tasks. Memory Replay GANs (MeRGANs) [45] show that forgetting still occurs in the generator, i.e., the generated data degrade on future tasks, and use replayed data to improve the quality of the generator. Dynamic Generative Memory (DGM) [27] uses neural masking to learn connection plasticity in a conditional generative model and equips the generator with a dynamic expansion mechanism for sequential tasks. Although these methods can exploit data information, they still cannot guarantee exact performance on past tasks.
Dynamic architectures: dynamic-architecture methods [38,20,36,29,48] adapt the architecture to a sequence of tasks. ProgressiveNet [38] expands the architecture for new tasks and keeps the function mappings by preserving the previous weights. LwF [20] divides the model layers into shared and task-specific ones, where the former are used jointly by all tasks and the latter are extended with new branches for new tasks. DAN [36] extends the architecture for each new task, where each layer of the new-task model is a sparse linear combination of the original filters in the corresponding layer of a base model. The recent memory-replay approach on GANs [27] also adopts architecture expansion. Through structural expansion these methods can significantly reduce or avoid catastrophic forgetting, but the model grows monotonically and carries redundant structures.
As the architecture keeps growing, model redundancy remains; some methods perform model compression before expansion [48] so that a compact model can be built. The prior method most related to ours is Dynamically Expandable Networks (DEN) [48]. DEN reduces the weights of previous tasks via sparse regularization, and both the newly added weights and the old weights are adapted to the new task under sparsity constraints. However, DEN does not ensure unforgetting: when the old-task and new-task weights are trained jointly, some old-task weights are selected and modified, so a "split and duplication" step is introduced to restore some of the modified old weights and lessen the forgetting. Pack and Expand (PAE) [12] is our earlier approach leveraging PackNet [23] and ProgressiveNet [38]. It avoids forgetting, keeps the model compact, and allows dynamic expansion; however, because it shares all weights of the previous tasks, performance on new tasks becomes less favorable.
Our approach (CPG) proceeds through a compacting and picking (growing) loop that picks weights from the old tasks without modifying them, thereby avoiding forgetting. Moreover, unlike DEN, our method does not need to restore the old-task performance, since it is already preserved; the tedious "split and duplication" process, which takes extra time for model adjustment and also affects new-task performance, is thus avoided. Our method is therefore simple and easier to implement. Experimental results show that it also outperforms DEN and PAE.

3 The CPG approach for Continual Lifelong Learning

Without loss of generality, our work follows the task-based sequential learning setup, which is a common setting in continual learning. Below we present our method task by task.

Task 1
Given the first task (Task 1) and an initial model trained on its dataset, we perform gradual pruning [51] on the model to remove redundancy while keeping the performance. Instead of pruning the weights to the target ratio all at once, gradual pruning removes a portion of the weights and retrains the model to recover performance, iterating until the pruning criteria are met. We thus compact the current model so that redundant weights are removed (or released), and then fix the weights kept in the compact model so that they never change, which avoids forgetting. After gradual pruning, the model weights fall into two parts: one is preserved for Task 1; the other is released and can be used by subsequent tasks.

Task k → k+1
Assume that at task k a compact model able to handle tasks 1..k has been built. Denote the model weights preserved for tasks 1..k by $W^P_{1:k}$, and the released (redundant) weights associated with task k by $W^E_k$; the latter are extra weights that can be used by subsequent tasks. Given the dataset of task k+1, we apply a learnable mask $M$ to pick from the old weights $W^P_{1:k}$, with $M \in \{0,1\}^D$, where $D$ is the dimension of $W^P_{1:k}$; the picked weights are then $M \odot W^P_{1:k}$. Without loss of generality, we use the piggyback approach [22], which learns a real-valued mask and binarizes it with a threshold to construct $M$. Hence, given a new task, a set of critical weights is selected from the compact model through a learnable mask, and in addition we use $W^E_k$ for the new task. The mask $M$ and the extra weights $W^E_k$ are learned jointly on the training data of task k+1 by back-propagating the loss of task k+1. Since the binarized mask is not differentiable, when training the binary mask $M$ we update a real-valued mask $\hat{M}$ in the backward pass; $M$ is then obtained by thresholding $\hat{M}$ and used in the forward pass. If the performance is not yet satisfactory, the architecture can be grown with more trainable weights, i.e., $W^E_k$ is augmented with additional weights (such as new filters in the convolutional layers and new nodes in the fully connected layers), and the training of $M$ and $W^E_k$ is resumed. Note that during training the mask $M$ and the new weights $W^E_k$ can be adapted, while among the original weights $W^P_{1:k}$ only the picked ones take part in the computation and none of them is modified. In this way, the old tasks can be recalled exactly.
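
A minimal sketch of this picking step under assumed toy shapes (not the repository code; the real SharableConv2d/Binarizer implementation is shown in the walkthrough below): the real-valued mask is thresholded in the forward pass while the gradient flows straight through to it, and the frozen old weights receive no gradient.

    import torch
    import torch.nn.functional as F

    class Binarize(torch.autograd.Function):
        """Threshold a real-valued mask to {0,1}; pass gradients straight through."""
        @staticmethod
        def forward(ctx, m_hat, threshold):
            return (m_hat > threshold).float()

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out, None

    # Toy picking step: the old-task weights are frozen, only the real-valued mask learns.
    old_w = torch.randn(16, 8)                                 # W^P_{1:k}, kept fixed
    m_hat = torch.full_like(old_w, 0.01, requires_grad=True)   # real-valued mask \hat{M}
    x = torch.randn(4, 8)

    picked = Binarize.apply(m_hat, 0.005) * old_w              # M ⊙ W^P_{1:k}
    loss = F.linear(x, picked).pow(2).mean()
    loss.backward()                                            # gradient reaches m_hat only
    print(m_hat.grad.abs().sum() > 0, old_w.grad is None)      # tensor(True) True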

Compacting task k+1:
After $M$ and $W^E_k$ are learned, we obtain an initial model for task k+1. We then fix the mask $M$ and gradually prune $W^E_k$, which yields the compact model $W^P_{k+1}$ for task k+1 and the redundant (released) weights $W^E_{k+1}$. The compact model for the old tasks then becomes $W^P_{1:(k+1)} = W^P_{1:k} \cup W^P_{k+1}$. The compacting and picking/growing loop is repeated from task to task.
[Figure 2 of the original post]
The experiment section is not translated here.

Reproducing Experiment 1

After downloading the dataset and extracting it into the /data directory:

1. baseline: VGG16

There are two versions: baseline and finetune.
baseline: train the 20 tasks independently.
finetune: each time, randomly pick one of the already-trained task models and fine-tune from it.

    if [ "$task_id" != "1" ]
    then
        echo "Start training task " $task_id
        python TOOLS/random_generate_task_id.py --curr_task_id $task_id
        initial_from_task_id=$?  # exit status of the previous command
        echo "Initial for curr task is " $initial_from_task_id
        CUDA_VISIBLE_DEVICES=$GPU_ID python packnet_cifar100_main_normal.py \
            --arch $arch \
            --dataset ${dataset[task_id]} --num_classes 5 \
            --lr 5e-3 \
            --weight_decay 4e-5 \
            --save_folder checkpoints/finetune/experiment1/$arch/${dataset[task_id]} \
            --epochs $finetune_epochs \
            --mode finetune \
            --logfile logs/finetune_cifar100_acc_normal_5e-3.txt \
            --initial_from_task checkpoints/finetune/experiment1/$arch/${dataset[initial_from_task_id]}

The most interesting part here is the implementation of the MASK parameters, which I was curious about when reading the paper.
Seeing the code made me realize my programming fundamentals still need some work.
The idea is simple: for every layer whose parameters need to be managed, allocate a mask tensor of the same shape (here a ByteTensor recording which task each weight belongs to) and keep it for later use.

    if not masks:
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                if 'cuda' in module.weight.data.type():
                    mask = mask.cuda()
                masks[name] = mask

[Figure 3 of the original post]
The code below also runs in the baseline, where its effect should be the same as training a plain VGG16.
On this first pass we just take note of it; it shows up again later, where we will analyze it further.

Training code
It is not obvious when training the first task (or when training each task independently), but here the classifier parameters are also separated out for later use:

    for tuple_ in named_params.items():
        if 'classifiers' in tuple_[0]:
            if '.{}.'.format(model.module.datasets.index(args.dataset)) in tuple_[0]:
                params_to_optimize_via_SGD.append(tuple_[1])
                named_params_to_optimize_via_SGD.append(tuple_)                
            continue
        else:
            params_to_optimize_via_SGD.append(tuple_[1])
            named_params_to_optimize_via_SGD.append(tuple_)

The framework's built-in weight decay is not used, because we do not want every parameter to change:

    # here we must set weight decay to 0.0, 
    # because the weight decay strategy in build-in step() function will change every weight elem in the tensor,
    # which will hurt previous tasks' accuracy. (Instead, we do weight decay ourself in the `prune.py`)
    optimizer_network = optim.SGD(params_to_optimize_via_SGD, lr=lr,
                          weight_decay=0.0, momentum=0.9, nesterov=True)  
    if args.mode == 'prune':
        print()
        print('Sparsity ratio: {}'.format(args.one_shot_prune_perc))
        print('Before pruning: ')
        baseline_acc = manager.validate(start_epoch-1)
        print('Execute one shot pruning ...')
        manager.one_shot_prune(args.one_shot_prune_perc)
    elif args.mode == 'finetune':
        manager.pruner.make_finetuning_mask()

During finetuning, the previously pruned weights are made trainable for the current task (the details come later):

    def make_finetuning_mask(self):
        """Turns previously pruned weights into trainable weights for
           current dataset.
        """
        assert self.masks
        self.current_dataset_idx += 1

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                mask = self.masks[name]
                mask[mask.eq(0)] = self.current_dataset_idx

After back-propagation on each training batch:

				# Set fixed param grads to 0.
                self.pruner.do_weight_decay_and_make_grads_zero()
    def do_weight_decay_and_make_grads_zero(self):
        """Sets grads of fixed weights to 0."""
        assert self.masks
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                mask = self.masks[name]
                # Set grads of all weights not belonging to current dataset to 0.
                if module.weight.grad is not None:
                    # manual weight decay: grad += weight_decay * weight
                    module.weight.grad.data.add_(self.args.weight_decay, module.weight.data)
                    # zero the gradients of weights that do not belong to the current dataset
                    module.weight.grad.data[mask.ne(
                        self.current_dataset_idx)] = 0
module.weight.grad.data.add_(self.args.weight_decay, module.weight.data)

This line performs weight decay manually. The call `A.add_(B, C)` computes $A + B \cdot C$ and stores the result back into A (newer PyTorch versions spell this overload `A.add_(C, alpha=B)`).
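
A quick standalone check of that semantics with toy values (the variable names are placeholders, not taken from the repository):

    import torch

    grad = torch.tensor([1.0, 2.0])       # stands in for module.weight.grad.data
    weight = torch.tensor([10.0, 20.0])   # stands in for module.weight.data
    wd = 4e-5                             # args.weight_decay

    grad.add_(weight, alpha=wd)           # newer spelling of grad.add_(wd, weight)
    print(grad)                           # tensor([1.0004, 2.0008]) -> grad + wd * weight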

                # Set pruned weights to 0.
                self.pruner.make_pruned_zero()

Pruned weights are set to 0:

    def make_pruned_zero(self):
        """Makes pruned weights 0."""
        assert self.masks

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                layer_mask = self.masks[name]
                module.weight.data[layer_mask.eq(0)] = 0.0

Validation code

    def validate(self, epoch_idx, biases=None):
        """Performs evaluation."""
        self.pruner.apply_mask()
        ...
        ...

Weights whose mask is 0 (pruned/unassigned) and weights whose mask index is greater than the current task are zeroed out:

    def apply_mask(self):
        """To be done to retrieve weights just for a particular dataset."""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                if 'classifiers' in name:
                    continue
                weight = module.weight.data
                mask = self.masks[name].cuda()
                weight[mask.eq(0)] = 0.0
                weight[mask.gt(self.inference_dataset_idx)] = 0.0

After the run finishes, the baseline accuracies are written to the logs folder as shown below; they are used later.

{"aquatic_mammals": "0.6280", "fish": "0.7560", "flowers": "0.7700", "food_containers": "0.8220", "fruit_and_vegetables": "0.8520", "household_electrical_devices": "0.8360", "household_furniture": "0.7880", "insects": "0.8440", "large_carnivores": "0.8300", "large_man-made_outdoor_things": "0.8800", "large_natural_outdoor_scenes": "0.8900", "large_omnivores_and_herbivores": "0.8100", "medium_mammals": "0.8380", "non-insect_invertebrates": "0.7960", "people": "0.4880", "reptiles": "0.7100", "small_mammals": "0.6840", "trees": "0.7120", "vehicles_1": "0.8860", "vehicles_2": "0.9160"}

2.CPG_cifar100_scratch_mul_1.5.sh

Task1

finetune mode

This corresponds to the branch below; task 1 of course has no previous model to fine-tune from, whereas later tasks fine-tune from task k-1:

        else
            CUDA_VISIBLE_DEVICES=$GPU_ID python CPG_cifar100_main_normal.py \
                --arch $arch \
                --dataset ${dataset[task_id]} --num_classes $num_classes \
                --lr $lr \
                --lr_mask $lr_mask \
                --batch_size $batch_size \
                --weight_decay 4e-5 \
                --save_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/scratch \
                --epochs $finetune_epochs \
                --mode finetune \
                --network_width_multiplier $network_width_multiplier \
                --max_allowed_network_width_multiplier $max_allowed_network_width_multiplier \
                --baseline_acc_file $baseline_cifar100_acc \
                --pruning_ratio_to_acc_record_file checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune/record.txt \
                --log_path checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/train.log \
                --total_num_tasks $total_num_tasks
        fi

Why take the square root here?

    # Don't use this, neither set learning rate as a linear function
    # of the count of gpus, it will make accuracy lower
    # args.batch_size = args.batch_size * torch.cuda.device_count()
    # 1^(1/2)
    args.network_width_multiplier = math.sqrt(args.network_width_multiplier) 
    # 1.5^(1/2)
    args.max_allowed_network_width_multiplier = math.sqrt(args.max_allowed_network_width_multiplier) 
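
My reading of the square root (an interpretation, not stated in the repository): widening every layer by a multiplier m scales both its input and output channel counts, so its parameter count grows roughly by m²; taking the square root makes the configured multiplier track the growth in parameters rather than in channels. This is also consistent with calculate_curr_task_ratio later multiplying by network_width_multiplier ** 2. A rough check with assumed channel counts:

    import math

    def conv_params(c_in, c_out, k=3, m=1.0):
        # parameter count of a k x k conv whose channel counts are scaled by m
        return int(c_in * m) * int(c_out * m) * k * k

    base = conv_params(64, 64)                    # 36864
    grown = conv_params(64, 64, m=math.sqrt(1.5))
    print(grown / base)                           # ~1.49, i.e. roughly 1.5x the parameters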

When the model is defined:

class VGG(nn.Module):
    def __init__(self, features, dataset_history, dataset2num_classes, network_width_multiplier=1.0, shared_layer_info={}, init_weights=True, progressive_init=False):
        super(VGG, self).__init__()
        self.features = features
        self.network_width_multiplier = network_width_multiplier
        self.shared_layer_info = shared_layer_info
        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        
        self.datasets, self.classifiers = dataset_history, nn.ModuleList()
        self.dataset2num_classes = dataset2num_classes

        if self.datasets:
            self._reconstruct_classifiers()

        if init_weights:
            self._initialize_weights()

        if progressive_init:
            self._initialize_weights_2()

That is, the expansion multiplier is also applied when reconstructing the classifiers:

    def _reconstruct_classifiers(self):
        for dataset, num_classes in self.dataset2num_classes.items():
            self.classifiers.append(nn.Linear(int(self.shared_layer_info[dataset]['network_width_multiplier'] * 4096), num_classes))

Next let us look at how the network is defined. I was tempted to skip the layer-by-layer trace, but resisted.
0  CPG_cifar100_main_normal.py

        model = models.__dict__[args.arch](custom_cfg, dataset_history=dataset_history, dataset2num_classes=dataset2num_classes,
            network_width_multiplier=args.network_width_multiplier, shared_layer_info=shared_layer_info)

0 CPG_cifar100_main_normal.py → 1 custom_vgg_cifar100

def custom_vgg_cifar100(custom_cfg, dataset_history=[], dataset2num_classes={}, network_width_multiplier=1.0, groups=1, shared_layer_info={}, **kwargs):
    return VGG(make_layers_cifar100(custom_cfg, network_width_multiplier, batch_norm=True, groups=groups), dataset_history, 
        dataset2num_classes, network_width_multiplier, shared_layer_info, **kwargs)

1 custom_vgg_cifar100 → 2 make_layers_cifar100

def make_layers_cifar100(cfg, network_width_multiplier, batch_norm=False, groups=1):
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            if in_channels == 3:
                conv2d = nl.SharableConv2d(in_channels, int(v * network_width_multiplier), kernel_size=3, padding=1, bias=False)
            else:
                conv2d = nl.SharableConv2d(in_channels, int(v * network_width_multiplier), kernel_size=3, padding=1, bias=False, groups=groups)

            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(int(v * network_width_multiplier)), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = int(v * network_width_multiplier)

    layers += [
        View(-1, int(512*network_width_multiplier)),
        nl.SharableLinear(int(512*network_width_multiplier), int(4096*network_width_multiplier)),
        nn.ReLU(True),
        nl.SharableLinear(int(4096*network_width_multiplier), int(4096*network_width_multiplier)),
        nn.ReLU(True),
    ]

    return nn.Sequential(*layers)

2 make_layers_cifar100 → 3 SharableConv2d

class SharableConv2d(nn.Module):
    """Modified conv with masks for weights."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 mask_init='1s', mask_scale=1e-2,
                 threshold_fn='binarizer', threshold=None):
        super(SharableConv2d, self).__init__()
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.mask_scale = mask_scale
        self.mask_init = mask_init

        if threshold is None:
            threshold = DEFAULT_THRESHOLD  # DEFAULT_THRESHOLD is 0.005
        self.info = {
            'threshold_fn': threshold_fn,
            'threshold': threshold,
        }

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = False
        self.output_padding = _pair(0)
        self.groups = groups

        
        self.weight = Parameter(torch.Tensor(
            out_channels, in_channels // groups, *kernel_size), requires_grad=True)
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels), requires_grad=True)
        else:
            self.register_parameter('bias', None)

        # Give real-valued mask weights per task to manage the shared part from previous tasks.
        self.piggymask = None

        # Initialize the thresholder.
        # threshold_fn selects how the real-valued mask is binarized during the forward/backward pass; see Binarizer below
        if threshold_fn == 'binarizer':
            # print('Calling binarizer with threshold:', threshold)
            self.threshold_fn = Binarizer.apply
        elif threshold_fn == 'ternarizer':
            print('Calling ternarizer with threshold:', threshold)
            self.threshold_fn = Ternarizer(threshold=threshold)
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)

A quick explanation, using kernel_size as an example:

{int} 3  ->  {tuple:2} (3, 3)
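
A tiny standalone check of this helper (re-deriving `_pair` with the standard-library imports that the original snippet relies on):

    from collections.abc import Iterable
    from itertools import repeat

    def _ntuple(n):
        def parse(x):
            return x if isinstance(x, Iterable) else tuple(repeat(x, n))
        return parse

    _pair = _ntuple(2)
    print(_pair(3))       # (3, 3)  -> a scalar is repeated n times
    print(_pair((3, 5)))  # (3, 5)  -> an iterable passes through unchanged
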
class Binarizer(torch.autograd.Function):
    """Binarizes {0, 1} a real valued tensor."""

    @staticmethod
    def forward(ctx, inputs, threshold):
        outputs = inputs.clone()
        outputs[inputs.le(threshold)] = 0
        outputs[inputs.gt(threshold)] = 1
        return outputs

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None
class SharableLinear(nn.Module):
    """Modified linear layer."""

    def __init__(self, in_features, out_features, bias=True,
                 mask_init='1s', mask_scale=1e-2,
                 threshold_fn='binarizer', threshold=None):
        super(SharableLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.threshold_fn = threshold_fn
        self.mask_scale = mask_scale
        self.mask_init = mask_init

        if threshold is None:
            threshold = DEFAULT_THRESHOLD
        self.info = {
            'threshold_fn': threshold_fn,
            'threshold': threshold,
        }

        # weight and bias are no longer Parameters.
        self.weight = Parameter(torch.Tensor(
            out_features, in_features), requires_grad=True)
        if bias:
            self.bias = Parameter(torch.Tensor(
                out_features), requires_grad=True)
        else:
            self.register_parameter('bias', None)

        self.piggymask = None

        # Initialize the thresholder.
        if threshold_fn == 'binarizer':
            self.threshold_fn = Binarizer.apply
        elif threshold_fn == 'ternarizer':
            self.threshold_fn = Ternarizer(threshold=threshold)

0  CPG_cifar100_main_normal.py

    # update all layers
    named_params = dict(model.named_parameters())
    params_to_optimize_via_SGD = []
    named_of_params_to_optimize_via_SGD = []
    masks_to_optimize_via_Adam = []
    named_of_masks_to_optimize_via_Adam = []
    
    for name, param in named_params.items():
        if 'classifiers' in name:
            if '.{}.'.format(model.module.datasets.index(args.dataset)) in name:
                params_to_optimize_via_SGD.append(param)
                named_of_params_to_optimize_via_SGD.append(name)
            continue
        elif 'piggymask' in name:
            masks_to_optimize_via_Adam.append(param)
            named_of_masks_to_optimize_via_Adam.append(name)
        else:
            params_to_optimize_via_SGD.append(param)
            named_of_params_to_optimize_via_SGD.append(name)

Note that an optimizer also has to be set up for the masks; of course this part is absent the first time a task is trained.

    if masks_to_optimize_via_Adam:
        optimizer_mask = optim.Adam(masks_to_optimize_via_Adam, lr=lr_mask)
        optimizers.add(optimizer_mask, lr_mask)
    elif args.mode == 'finetune':
        if not args.finetune_again:
            manager.pruner.make_finetuning_mask()

The masks are initialized to all zeros, so here every entry is set to 1 (the current dataset index), i.e., all parameters are used for training the first task.

    def make_finetuning_mask(self):
        """Turns previously pruned weights into trainable weights for
           current dataset.
        """
        assert self.masks
        self.current_dataset_idx += 1

        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                mask = self.masks[name]
                mask[mask.eq(0)] = self.current_dataset_idx

Compute the proportion of the parameters occupied by the current task:

        if manager.pruner.calculate_curr_task_ratio() == 0.0:
            logging.info('There is no left space in convolutional layer for curr task'
                  ', we will try to use prior experience as long as possible')
            stop_lr_mask = False
    def calculate_curr_task_ratio(self):
        total_elem = 0
        curr_task_elem = 0
        is_first_conv = True

        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):

                # if is_first_conv:
                #     is_first_conv = False
                #     continue

                mask = self.masks[name]
                total_elem += mask.numel()
                curr_task_elem += torch.sum(mask.eq(self.inference_dataset_idx))

                # break  # because every layer has the same pruning ratio,
                #        # so we are able to see only one layer for getting the sparsity

        return float(curr_task_elem.cpu()) / total_elem * (self.args.network_width_multiplier ** 2)

During this training run:

                # Set fixed param grads to 0.
                self.pruner.do_weight_decay_and_make_grads_zero()

Keep an eye on the piggymask: it is not used this time, so for now the function merely restricts the updates to the parameters belonging to the current task.

    def do_weight_decay_and_make_grads_zero(self):
        """Sets grads of fixed weights to 0."""
        assert self.masks
        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                mask = self.masks[name]
                # Set grads of all weights not belonging to current dataset to 0.
                if module.weight.grad is not None:
                    module.weight.grad.data.add_(self.args.weight_decay, module.weight.data)
                    module.weight.grad.data[mask.ne(
                        self.current_dataset_idx)] = 0
                if module.piggymask is not None and module.piggymask.grad is not None:
                    if self.args.mode == 'finetune':
                        module.piggymask.grad.data[mask.eq(0) | mask.ge(self.current_dataset_idx)] = 0
                    elif self.args.mode == 'prune':
                        module.piggymask.grad.data.fill_(0)
        return

The sparsity is also computed, as zero_elem / total_elem:

    def calculate_sparsity(self):
        total_elem = 0
        zero_elem = 0
        is_first_conv = True

        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):

                # if is_first_conv:
                #     is_first_conv = False
                #     continue

                mask = self.masks[name]
                total_elem += torch.sum(mask.eq(self.inference_dataset_idx) | mask.eq(0))
                zero_elem += torch.sum(mask.eq(0))

                # total_elem += torch.sum(mask.ge(self.inference_dataset_idx) | mask.eq(0))
                # zero_elem += torch.sum(mask.eq(self.inference_dataset_idx))
                # break  # because every layer has the same pruning ratio,
                #        # so we are able to see only one layer for getting the sparsity

        if total_elem.cpu() != 0.0:
            return float(zero_elem.cpu()) / float(total_elem.cpu())
        else:
            return 0.0

Let us analyze the logic after training finishes:

  1. If the training accuracy is > 0.95, save the model.
  2. In finetune mode (without test_piggymask), record json_data["0.0"] = average validation accuracy, then:
    1. if the training accuracy is > 0.95 and the validation accuracy is >= baseline, fall through (no exit);
    2. if the network width has reached the allowed maximum and the validation accuracy is still below the baseline: exit with code 5 when the current task occupies 0% of the parameters, otherwise exit with code 0;
    3. otherwise exit with code 2, i.e., it is time to expand the network;
    4. finally, if the current task occupies 0% of the parameters, exit with code 5 (nothing left to prune for this task).
  3. In prune mode, record the accuracy if the training accuracy is > 0.95; otherwise exit with code 6.

If the exit code is 2: increase the width multiplier by 0.5.

        if [ $state -eq 2 ]
        then
            network_width_multiplier=$(bc <<< $network_width_multiplier+0.5)
            echo "New network_width_multiplier: $network_width_multiplier"
            continue

If the exit code is not 5: perform pruning (there is still room).

    # save a checkpoint if the training accuracy exceeds 0.95
    if avg_train_acc > 0.95:
        manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)

    logging.info('-' * 16)

    if args.pruning_ratio_to_acc_record_file:
        json_data = {}
        if os.path.isfile(args.pruning_ratio_to_acc_record_file):
            with open(args.pruning_ratio_to_acc_record_file, 'r') as json_file:
                json_data = json.load(json_file)

        if args.mode == 'finetune' and not args.test_piggymask:
            json_data[0.0] = round(avg_val_acc, 4)
            with open(args.pruning_ratio_to_acc_record_file, 'w') as json_file:
                json.dump(json_data, json_file)
            if avg_train_acc > 0.95 and avg_val_acc >= baseline_acc:
                pass
            elif args.network_width_multiplier == args.max_allowed_network_width_multiplier and avg_val_acc < baseline_acc:
                if manager.pruner.calculate_curr_task_ratio() == 0.0:
                    sys.exit(5)
                else:
                    sys.exit(0)
            else:
                logging.info("It's time to expand the Network")
                logging.info('Auto expand network')
                sys.exit(2)

            if manager.pruner.calculate_curr_task_ratio() == 0.0:
                logging.info('There is no left space in convolutional layer for curr task, so needless to prune')
                sys.exit(5)

        elif args.mode == 'prune':
            if avg_train_acc > 0.95:
                json_data[args.target_sparsity] = round(avg_val_acc, 4)
                with open(args.pruning_ratio_to_acc_record_file, 'w') as json_file:
                    json.dump(json_data, json_file)
            else:
                sys.exit(6)

            must_pruning_ratio_for_curr_task = 0.0

            if args.network_width_multiplier == args.max_allowed_network_width_multiplier and json_data['0.0'] < baseline_acc:
                # If we reach the upperbound and still do not get the accuracy over our target on curr task, we still do pruning
                logging.info('we reach the upperbound and still do not get the accuracy over our target on curr task')
                remain_num_tasks = args.total_num_tasks - len(dataset_history)
                logging.info('remain_num_tasks: {}'.format(remain_num_tasks))
                ratio_allow_for_curr_task = round(1.0 / (remain_num_tasks + 1), 1)
                logging.info('ratio_allow_for_curr_task: {:.4f}'.format(ratio_allow_for_curr_task))
                must_pruning_ratio_for_curr_task = 1.0 - ratio_allow_for_curr_task
                if args.target_sparsity >= must_pruning_ratio_for_curr_task:
                    sys.exit(6)

gradually pruning

    nrof_epoch=0
    nrof_epoch_for_each_prune=20
    start_sparsity=0.0
    end_sparsity=0.1
    nrof_epoch=$nrof_epoch_for_each_prune

    # Prune the model after training
    if [ $state -ne 5 ]
    then
        echo $state
        # gradually pruning
        CUDA_VISIBLE_DEVICES=$GPU_ID python CPG_cifar100_main_normal.py \
            --arch $arch \
            --dataset ${dataset[task_id]} --num_classes $num_classes \
            --lr $gradual_prune_lr \
            --lr_mask 0.0 \
            --batch_size $batch_size \
            --weight_decay 4e-5 \
            --save_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune \
            --load_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/scratch \
            --epochs $nrof_epoch \
            --mode prune \
            --initial_sparsity=$start_sparsity \
            --target_sparsity=$end_sparsity \
            --pruning_frequency=10 \
            --pruning_interval=4 \
            --baseline_acc_file $baseline_cifar100_acc \
            --network_width_multiplier $network_width_multiplier \
            --max_allowed_network_width_multiplier $max_allowed_network_width_multiplier \
            --pruning_ratio_to_acc_record_file checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune/record.txt \
            --log_path checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/train.log \
            --total_num_tasks $total_num_tasks

After the first finetune step of Task 1 (which in fact trains from scratch), the accuracy is below the baseline, just as in the paper.
That is, as long as none of the exit conditions above is triggered, pruning is carried out.

Changes in the parameter settings compared with the finetune step:

lr -> gradual_prune_lr
lr_mask -> 0.0
change save_folder
add load_folder
epochs -> nrof_epoch  (20)
mode -> prune
--initial_sparsity 0.0
--target_sparsity 0.1
--pruning_frequency 10 
--pruning_interval 4

In prune mode, a sub-folder named after the target sparsity is created (and, after the first round, the load folder points to the previous sparsity):

    if args.mode == 'prune':
        args.save_folder = os.path.join(args.save_folder, str(args.target_sparsity))
        if args.initial_sparsity != 0.0:
            args.load_folder = os.path.join(args.load_folder, str(args.initial_sparsity))

This time we resume from the epoch-100 checkpoint:

    if resume_from_epoch:
        filepath = args.checkpoint_format.format(save_folder=resume_folder, epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        checkpoint_keys = checkpoint.keys()
        dataset_history = checkpoint['dataset_history']
        dataset2num_classes = checkpoint['dataset2num_classes']
        masks = checkpoint['masks']
        shared_layer_info = checkpoint['shared_layer_info']
        #shared_layer_info[args.dataset]['network_width_multiplier'] = 1.0
        if 'num_for_construct' in checkpoint_keys:
            num_for_construct = checkpoint['num_for_construct']
        if args.mode == 'inference' and 'network_width_multiplier' in shared_layer_info[args.dataset]: # TODO, temporary solution
            args.network_width_multiplier = shared_layer_info[args.dataset]['network_width_multiplier']

If the model has been expanded, the masks need to be resized accordingly. That is not involved in this round.

        # when we expand network, we need to allocate new masks
        NEED_ADJUST_MASK = False
        for name, module in model.named_modules():
            if isinstance(module, nl.SharableConv2d):
                if masks[name].size(1) < module.weight.data.size(1):
                    assert args.mode == 'finetune'
                    NEED_ADJUST_MASK = True
                elif masks[name].size(1) > module.weight.data.size(1):
                    assert args.mode == 'inference'
                    NEED_ADJUST_MASK = True
	# if this task has not been trained before
	if args.dataset not in shared_layer_info: 
		...
		...
    elif args.finetune_again:
    	...
    	...    
    else:
        #try:
        piggymasks = shared_layer_info[args.dataset]['piggymask']
        #except:
        #    piggymasks = {}
        task_id = model.module.datasets.index(args.dataset) + 1
        # piggymasks is still empty at this point
        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                    module.piggymask = piggymasks[name]
    if args.mode == 'prune':
        if 'gradual_prune' in args.load_folder and args.save_folder == args.load_folder:
            args.epochs = 20 + resume_from_epoch
        logging.info('')
        logging.info('Before pruning: ')
        logging.info('Sparsity range: {} -> {}'.format(args.initial_sparsity, args.target_sparsity))

If the width has already been expanded to the upper bound and the accuracy still has not reached the baseline, a notice is logged;
if there is no pruning room left, exit with code 6, otherwise continue normally:

        if args.network_width_multiplier == args.max_allowed_network_width_multiplier and json_data['0.0'] < baseline_acc:
            # If we reach the upperbound and still do not get the accuracy over our target on curr task, we still do pruning
            logging.info('we reach the upperbound and still do not get the accuracy over our target on curr task')
            remain_num_tasks = args.total_num_tasks - len(dataset_history)
            logging.info('remain_num_tasks: {}'.format(remain_num_tasks))
            ratio_allow_for_curr_task = round(1.0 / (remain_num_tasks + 1), 1)
            logging.info('ratio_allow_for_curr_task: {:.4f}'.format(ratio_allow_for_curr_task))
            must_pruning_ratio_for_curr_task = 1.0 - ratio_allow_for_curr_task
            if args.initial_sparsity >= must_pruning_ratio_for_curr_task:
                sys.exit(6)

Then training reaches the familiar routine again:

                # Set fixed param grads to 0.
                self.pruner.do_weight_decay_and_make_grads_zero()
    def do_weight_decay_and_make_grads_zero(self):
        """Sets grads of fixed weights to 0."""
        assert self.masks
        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                mask = self.masks[name]
                # Set grads of all weights not belonging to current dataset to 0.
                if module.weight.grad is not None:
                    module.weight.grad.data.add_(self.args.weight_decay, module.weight.data)
                    module.weight.grad.data[mask.ne(
                        self.current_dataset_idx)] = 0
                if module.piggymask is not None and module.piggymask.grad is not None:
                    if self.args.mode == 'finetune':
                        module.piggymask.grad.data[mask.eq(0) | mask.ge(self.current_dataset_idx)] = 0
                    elif self.args.mode == 'prune':
                        module.piggymask.grad.data.fill_(0)
        return

The piggymask is still not used here either, but we can already analyze the logic.
If the piggymask exists and has a gradient:

  • finetune: piggymask entries whose weights are unassigned (mask 0) or belong to a task index >= the current one are not updated;
  • prune: the piggymask gradient is zeroed entirely, so it is not updated at all.
                # Set pruned weights to 0.
                if self.args.mode == 'prune':
                    self.pruner.gradually_prune(curr_prune_step)
                    curr_prune_step += 1
    def gradually_prune(self, curr_prune_step):

        if self._time_to_update_masks(curr_prune_step):
            self.last_prune_step = curr_prune_step
            curr_pruning_ratio = self._adjust_sparsity(curr_prune_step)

            for name, module in self.model.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                    mask = self._pruning_mask(module.weight.data, self.masks[name], name, pruning_ratio=curr_pruning_ratio)
                    self.masks[name] = mask
                    # module.weight.data[self.masks[name].eq(0)] = 0.0
        else:
            curr_pruning_ratio = self._adjust_sparsity(self.last_prune_step)

        return curr_pruning_ratio

Check whether the masks should be updated at this step:

    def _time_to_update_masks(self, curr_prune_step):
        is_step_within_pruning_range = \
            (curr_prune_step >= self.begin_prune_step) and \
            (curr_prune_step <= self.end_prune_step)

        is_pruning_step = (
            self.last_prune_step + self.args.pruning_frequency) <= curr_prune_step

        return is_step_within_pruning_range and is_pruning_step

Compute the target sparsity for the current step; it grows non-linearly (a polynomial schedule that prunes quickly at first and then levels off). Whether, and how much, this makes sense is unclear to me for now; a numeric check follows the code below.

    def _adjust_sparsity(self, curr_prune_step):

        p = min(1.0,
                max(0.0,
                    ((curr_prune_step - self.begin_prune_step)
                    / (self.end_prune_step - self.begin_prune_step))
                ))

        sparsity = self.args.target_sparsity + \
            (self.args.initial_sparsity - self.args.target_sparsity) * pow(1-p, self.sparsity_func_exponent)

        return sparsity
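
A standalone numeric check of this schedule under assumed values (initial sparsity 0.0, target 0.1, exponent 3, pruning steps 0..100; the actual sparsity_func_exponent in the repository may differ):

    def sparsity_at(step, begin=0, end=100, s_init=0.0, s_target=0.1, exponent=3):
        """Polynomial gradual-pruning schedule: moves fast early, flattens near the end."""
        p = min(1.0, max(0.0, (step - begin) / (end - begin)))
        return s_target + (s_init - s_target) * (1 - p) ** exponent

    for step in (0, 25, 50, 75, 100):
        print(step, round(sparsity_at(step), 4))
    # 0 0.0 | 25 0.0578 | 50 0.0875 | 75 0.0984 | 100 0.1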

Set the mask of the weights to be pruned to 0:

  1. Select the weights whose mask is 0 together with those of the current task (I used to think mask 0 marked permanently fixed parameters; evidently that reading needs revisiting).
  2. Take absolute values (only the magnitude of a weight matters here).
  3. Find the magnitude cutoff (weights below this value are removed, i.e., their mask is set to 0 and they no longer take part in the computation).
    def _pruning_mask(self, weights, mask, layer_name, pruning_ratio):
        """Ranks weights by magnitude. Sets all below kth to 0.
           Returns pruned mask.
        """
        # Select all prunable weights, ie. belonging to current dataset.
        tensor = weights[mask.eq(self.current_dataset_idx) | mask.eq(0)] # This will flatten weights
        abs_tensor = tensor.abs()
        cutoff_rank = round(pruning_ratio * tensor.numel())
        try:
            cutoff_value = abs_tensor.cpu().kthvalue(cutoff_rank)[0].cuda() # value at cutoff rank
        except:
            print("Not enough weights for pruning, that is to say, too little space for new task, need expand the network.")
            sys.exit(2)

        # Remove those weights which are below cutoff and belong to current
        # dataset that we are training for.
        remove_mask = weights.abs().le(cutoff_value) * mask.eq(self.current_dataset_idx)

        # mask = 1 - remove_mask
        mask[remove_mask.eq(1)] = 0
        # print('Layer {}, pruned {}/{} ({:.2f}%)'.format(
        #        layer_name, mask.eq(0).sum(), tensor.numel(),
        #        float(100 * mask.eq(0).sum()) / tensor.numel()))
        return mask

Validation code
Only the parameters whose mask value equals the current task ID are used.
At this point the vast majority (>90%) of the mask values are 1 (the current task ID) and the rest are 0 (pruned parameters).

    def validate(self, epoch_idx, biases=None):
        """Performs evaluation."""
        self.pruner.apply_mask()

Main code
If the training accuracy is > 0.95, the checkpoint is saved and the accuracy recorded; otherwise exit with code 6.

        elif args.mode == 'prune':
            if avg_train_acc > 0.95:
                json_data[args.target_sparsity] = round(avg_val_acc, 4)
                with open(args.pruning_ratio_to_acc_record_file, 'w') as json_file:
                    json.dump(json_data, json_file)
            else:
                sys.exit(6)

            must_pruning_ratio_for_curr_task = 0.0

            if args.network_width_multiplier == args.max_allowed_network_width_multiplier and json_data['0.0'] < baseline_acc:
                # If we reach the upperbound and still do not get the accuracy over our target on curr task, we still do pruning
                logging.info('we reach the upperbound and still do not get the accuracy over our target on curr task')
                remain_num_tasks = args.total_num_tasks - len(dataset_history)
                logging.info('remain_num_tasks: {}'.format(remain_num_tasks))
                ratio_allow_for_curr_task = round(1.0 / (remain_num_tasks + 1), 1)
                logging.info('ratio_allow_for_curr_task: {:.4f}'.format(ratio_allow_for_curr_task))
                must_pruning_ratio_for_curr_task = 1.0 - ratio_allow_for_curr_task
                if args.target_sparsity >= must_pruning_ratio_for_curr_task:
                    sys.exit(6)

After that the model keeps being pruned: each round increases the sparsity by 0.1, except the last one, which goes from 0.9 to 0.95.
If any run exits with code 6, the loop breaks:

        if [ $? -ne 6 ]
        then
            for RUN_ID in `seq 1 9`; do
                nrof_epoch=$nrof_epoch_for_each_prune
                start_sparsity=$end_sparsity
                if [ $RUN_ID -lt 9 ]
                then
                    end_sparsity=$(bc <<< $end_sparsity+$pruning_ratio_interval)
                else
                    end_sparsity=$(bc <<< $end_sparsity+0.05)
                fi

                CUDA_VISIBLE_DEVICES=$GPU_ID python CPG_cifar100_main_normal.py \
                    --arch $arch \
                    --dataset ${dataset[task_id]} --num_classes $num_classes \
                    --lr $gradual_prune_lr \
                    --lr_mask 0.0 \
                    --batch_size $batch_size \
                    --weight_decay 4e-5 \
                    --save_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune \
                    --load_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune \
                    --epochs $nrof_epoch \
                    --mode prune \
                    --initial_sparsity=$start_sparsity \
                    --target_sparsity=$end_sparsity \
                    --pruning_frequency=10 \
                    --pruning_interval=4 \
                    --baseline_acc_file $baseline_cifar100_acc \
                    --network_width_multiplier $network_width_multiplier \
                    --max_allowed_network_width_multiplier $max_allowed_network_width_multiplier \
                    --pruning_ratio_to_acc_record_file checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune/record.txt \
                    --log_path checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/train.log \
                    --total_num_tasks $total_num_tasks

                if [ $? -eq 6 ]
                then
                    break
                fi
            done
        fi

Choose the checkpoint

    # Choose the checkpoint that we want
    python tools/choose_appropriate_pruning_ratio_for_next_task.py \
        --pruning_ratio_to_acc_record_file checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune/record.txt \
        --baseline_acc_file $baseline_cifar100_acc \
        --allow_acc_loss 0.0 \
        --dataset ${dataset[task_id]} \
        --max_allowed_network_width_multiplier $max_allowed_network_width_multiplier \
        --network_width_multiplier $network_width_multiplier \
        --log_path checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/train.log

Among the recorded pruning ratios, pick the largest one whose accuracy still meets the requirement, copy that checkpoint into the folder, and delete all the other checkpoint folders.
If none qualifies, the un-pruned ("scratch") checkpoint is kept instead.

def main():
    args = parser.parse_args()
    if args.log_path:
        set_logger(args.log_path)

    save_folder = args.pruning_ratio_to_acc_record_file.rsplit('/', 1)[0]
    with open(args.baseline_acc_file, 'r') as jsonfile:
        json_data = json.load(jsonfile)
        criterion_acc = float(json_data[args.dataset])

    with open(args.pruning_ratio_to_acc_record_file, 'r') as json_file:
        json_data = json.load(json_file)
        acc_before_prune = json_data['0.0']
        json_data.pop('0.0')
        available_pruning_ratios = list(json_data.keys())
        available_pruning_ratios.reverse()
        flag_there_is_pruning_ratio_that_match_our_need = False

        chosen_pruning_ratio = 0.0
        for pruning_ratio in available_pruning_ratios:
            acc = json_data[pruning_ratio]
            #criterion_acc = min(criterion_acc, acc_before_prune)
            if (acc + args.allow_acc_loss >= criterion_acc) or (
                (args.network_width_multiplier == args.max_allowed_network_width_multiplier) and (acc_before_prune < criterion_acc)):
                chosen_pruning_ratio = pruning_ratio
                checkpoint_folder = os.path.join(save_folder, str(pruning_ratio))

                for filename in os.listdir(checkpoint_folder):
                    shutil.copyfile(os.path.join(checkpoint_folder, filename), os.path.join(save_folder, filename))
                flag_there_is_pruning_ratio_that_match_our_need = True
                break

        for pruning_ratio in available_pruning_ratios:
            checkpoint_folder = os.path.join(save_folder, str(pruning_ratio))
            shutil.rmtree(checkpoint_folder)

        if not flag_there_is_pruning_ratio_that_match_our_need:
            folder_that_contain_checkpoint_before_pruning = os.path.join(save_folder.rsplit('/', 1)[0], 'scratch')
            for filename in os.listdir(folder_that_contain_checkpoint_before_pruning):
                shutil.copyfile(os.path.join(folder_that_contain_checkpoint_before_pruning, filename), os.path.join(save_folder, filename))

        logging.info('We choose pruning_ratio {}'.format(chosen_pruning_ratio))

For non-first tasks, when the state is not 5 (i.e., expansion is applicable):

    if [ $task_id != 1 ] && [ $state -ne 5 ]
    ...
    ...
2020-06-13 09:34:33,604 - root - INFO - Before pruning: 
2020-06-13 09:34:33,604 - root - INFO - Sparsity range: 0.8 -> 0.9
2020-06-13 09:34:35,919 - root - INFO - In validate()-> Val Ep. #0 loss: 1.827, accuracy: 63.80, sparsity: 0.800, task1 ratio: 0.200, zero ratio: 0.800, mpl: 1.0
2020-06-13 09:34:35,919 - root - INFO - 
2020-06-13 09:34:46,044 - root - INFO - In train()-> Train Ep. #1 loss: 0.041, accuracy: 98.96, lr: 0.001, sparsity: 0.853, network_width_mpl: 1.0
2020-06-13 09:34:46,492 - root - INFO - In validate()-> Val Ep. #1 loss: 3.827, accuracy: 28.40, sparsity: 0.853, task1 ratio: 0.147, zero ratio: 0.853, mpl: 1.0
2020-06-13 09:34:54,027 - root - INFO - In train()-> Train Ep. #2 loss: 0.783, accuracy: 79.40, lr: 0.001, sparsity: 0.886, network_width_mpl: 1.0
2020-06-13 09:34:54,410 - root - INFO - In validate()-> Val Ep. #2 loss: 3.174, accuracy: 31.20, sparsity: 0.886, task1 ratio: 0.114, zero ratio: 0.886, mpl: 1.0
2020-06-13 09:35:01,881 - root - INFO - In train()-> Train Ep. #3 loss: 1.001, accuracy: 61.52, lr: 0.001, sparsity: 0.898, network_width_mpl: 1.0
2020-06-13 09:35:02,267 - root - INFO - In validate()-> Val Ep. #3 loss: 2.149, accuracy: 32.20, sparsity: 0.898, task1 ratio: 0.102, zero ratio: 0.898, mpl: 1.0
2020-06-13 09:35:09,895 - root - INFO - In train()-> Train Ep. #4 loss: 1.031, accuracy: 57.76, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:10,263 - root - INFO - In validate()-> Val Ep. #4 loss: 1.610, accuracy: 42.60, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:13,210 - root - INFO - In train()-> Train Ep. #5 loss: 0.874, accuracy: 64.96, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:13,659 - root - INFO - In validate()-> Val Ep. #5 loss: 0.939, accuracy: 61.40, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:16,912 - root - INFO - In train()-> Train Ep. #6 loss: 0.680, accuracy: 73.40, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:17,320 - root - INFO - In validate()-> Val Ep. #6 loss: 0.927, accuracy: 63.60, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:20,668 - root - INFO - In train()-> Train Ep. #7 loss: 0.622, accuracy: 75.32, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:21,131 - root - INFO - In validate()-> Val Ep. #7 loss: 0.951, accuracy: 62.20, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:24,232 - root - INFO - In train()-> Train Ep. #8 loss: 0.546, accuracy: 79.20, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:24,632 - root - INFO - In validate()-> Val Ep. #8 loss: 0.973, accuracy: 62.40, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:27,546 - root - INFO - In train()-> Train Ep. #9 loss: 0.522, accuracy: 79.76, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:27,979 - root - INFO - In validate()-> Val Ep. #9 loss: 0.978, accuracy: 64.00, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:30,952 - root - INFO - In train()-> Train Ep. #10 loss: 0.469, accuracy: 82.72, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:31,402 - root - INFO - In validate()-> Val Ep. #10 loss: 0.997, accuracy: 63.40, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:34,810 - root - INFO - In train()-> Train Ep. #11 loss: 0.447, accuracy: 83.20, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:35,258 - root - INFO - In validate()-> Val Ep. #11 loss: 1.039, accuracy: 62.60, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:38,566 - root - INFO - In train()-> Train Ep. #12 loss: 0.394, accuracy: 85.36, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:39,015 - root - INFO - In validate()-> Val Ep. #12 loss: 1.044, accuracy: 63.60, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:42,012 - root - INFO - In train()-> Train Ep. #13 loss: 0.381, accuracy: 84.96, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:42,406 - root - INFO - In validate()-> Val Ep. #13 loss: 1.093, accuracy: 63.40, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:45,560 - root - INFO - In train()-> Train Ep. #14 loss: 0.363, accuracy: 86.32, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:45,976 - root - INFO - In validate()-> Val Ep. #14 loss: 1.095, accuracy: 62.80, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:49,251 - root - INFO - In train()-> Train Ep. #15 loss: 0.349, accuracy: 87.52, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:49,661 - root - INFO - In validate()-> Val Ep. #15 loss: 1.120, accuracy: 64.00, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:53,390 - root - INFO - In train()-> Train Ep. #16 loss: 0.306, accuracy: 88.24, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:53,801 - root - INFO - In validate()-> Val Ep. #16 loss: 1.154, accuracy: 63.20, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:35:57,530 - root - INFO - In train()-> Train Ep. #17 loss: 0.270, accuracy: 90.08, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:35:57,934 - root - INFO - In validate()-> Val Ep. #17 loss: 1.204, accuracy: 63.80, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:36:01,274 - root - INFO - In train()-> Train Ep. #18 loss: 0.287, accuracy: 89.40, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:36:01,698 - root - INFO - In validate()-> Val Ep. #18 loss: 1.204, accuracy: 62.80, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:36:12,408 - root - INFO - In train()-> Train Ep. #19 loss: 0.241, accuracy: 91.12, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:36:12,831 - root - INFO - In validate()-> Val Ep. #19 loss: 1.256, accuracy: 63.00, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:36:15,761 - root - INFO - In train()-> Train Ep. #20 loss: 0.251, accuracy: 90.36, lr: 0.001, sparsity: 0.900, network_width_mpl: 1.0
2020-06-13 09:36:16,260 - root - INFO - In validate()-> Val Ep. #20 loss: 1.278, accuracy: 63.80, sparsity: 0.900, task1 ratio: 0.100, zero ratio: 0.900, mpl: 1.0
2020-06-13 09:36:16,261 - root - INFO - ----------------
2020-06-13 09:36:17,662 - root - INFO - We choose pruning_ratio 0.8

From the log we can see that the training accuracy ends below 95%, so pruning stops early (the exit-code-6 path) and the 0.9 checkpoint is not recorded; hence pruning_ratio 0.8 is chosen.

Task2 (k>1)

finetune mode

The only difference from the training at the start of Task 1 is the load_folder:

    # Training the network on current tasks
    state=2
    while [ $state -eq 2 ]; do
        if [ "$task_id" != "1" ]
        then
            CUDA_VISIBLE_DEVICES=$GPU_ID python ../CPG_cifar100_main_normal.py \
                --arch $arch \
                --dataset ${dataset[task_id]} --num_classes $num_classes \
                --lr $lr \
                --lr_mask $lr_mask \
                --batch_size $batch_size \
                --weight_decay 4e-5 \
                --save_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/scratch \
                --load_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id-1]}/gradual_prune \
                --epochs $finetune_epochs \
                --mode finetune \
                --network_width_multiplier $network_width_multiplier \
                --max_allowed_network_width_multiplier $max_allowed_network_width_multiplier \
                --baseline_acc_file $baseline_cifar100_acc \
                --pruning_ratio_to_acc_record_file checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune/record.txt \
                --log_path checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/train.log \
                --total_num_tasks $total_num_tasks

Here we finally meet the piggymasks, initialized to 0.01 everywhere with the same shape as the model parameters:

        task_id = model.module.datasets.index(args.dataset) + 1
        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                    piggymasks[name] = torch.zeros_like(masks['module.' + name], dtype=torch.float32)
                    piggymasks[name].fill_(0.01)
                    piggymasks[name] = Parameter(piggymasks[name])
                    module.piggymask = piggymasks[name]

With this in place, the masks become trainable:

        elif 'piggymask' in name:
            masks_to_optimize_via_Adam.append(param)
            named_of_masks_to_optimize_via_Adam.append(name)
    if masks_to_optimize_via_Adam:
        optimizer_mask = optim.Adam(masks_to_optimize_via_Adam, lr=lr_mask)
        optimizers.add(optimizer_mask, lr_mask)

Loading the model is also a bit of work:
The parameters of the final classifier layers and the piggymasks are skipped; the former because the last layers are simply trained per task, the latter because each task learns its own selection and does not need to inherit the previous one.

    def load_checkpoint(self, optimizers, resume_from_epoch, save_folder):

        if resume_from_epoch > 0:
            filepath = self.args.checkpoint_format.format(save_folder=save_folder, epoch=resume_from_epoch)
            checkpoint = torch.load(filepath)
            checkpoint_keys = checkpoint.keys()
            state_dict = checkpoint['model_state_dict']
            curr_model_state_dict = self.model.module.state_dict()

            for name, param in state_dict.items():
                if ('piggymask' in name or name == 'classifier.weight' or name == 'classifier.bias' or
                    (name == 'classifier.0.weight' or name == 'classifier.0.bias' or name == 'classifier.1.weight')):
                    # I DONT WANT TO DO THIS! QQ That last 3 exprs are for anglelinear and embeddings
                    continue
                elif len(curr_model_state_dict[name].size()) == 4:
                    # Conv layer
                    curr_model_state_dict[name][:param.size(0), :param.size(1), :, :].copy_(param)
                elif len(curr_model_state_dict[name].size()) == 2 and 'features' in name:
                    # FC conv (feature layer)
                    curr_model_state_dict[name][:param.size(0), :param.size(1)].copy_(param)
                elif len(curr_model_state_dict[name].size()) == 1:
                    # bn and prelu layer
                    curr_model_state_dict[name][:param.size(0)].copy_(param)
                elif 'classifiers' in name:
                    curr_model_state_dict[name][:param.size(0), :param.size(1)].copy_(param)
                else:
                    try:
                        curr_model_state_dict[name].copy_(param)
                    except:
                        pdb.set_trace()
                        print("There is some corner case that we haven't tackled")
        return
    elif args.mode == 'finetune':
        if not args.finetune_again:
            manager.pruner.make_finetuning_mask()
            logging.info('Finetune stage...')

Looking at this function again, its intent is clear: it sets the mask entries of previously pruned weights to the current task ID.
These weights are exactly the released (redundant) weights $W^E_k$ of task $k$, i.e. the extra weights that can be reused by subsequent tasks.

    def make_finetuning_mask(self):
        """Turns previously pruned weights into trainable weights for
           current dataset.
        """
        assert self.masks
        self.current_dataset_idx += 1

        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                mask = self.masks[name]
                mask[mask.eq(0)] = self.current_dataset_idx

At this point, the branch taken in the model's forward function also changes:

    def forward(self, input, layer_info=None, name=None):
        if self.piggymask is not None:
            # Get binarized/ternarized mask from real-valued mask.
            mask_thresholded = self.threshold_fn(self.piggymask, self.info['threshold'])
            # Mask weights with above mask.
            weight = mask_thresholded * self.weight
        else:
            weight = self.weight

Here threshold_fn is the Binarizer:

        if threshold_fn == 'binarizer':
            # print('Calling binarizer with threshold:', threshold)
            self.threshold_fn = Binarizer.apply
class Binarizer(torch.autograd.Function):
    """Binarizes {0, 1} a real valued tensor."""

    @staticmethod
    def forward(ctx, inputs, threshold):
        outputs = inputs.clone()
        outputs[inputs.le(threshold)] = 0
        outputs[inputs.gt(threshold)] = 1
        return outputs

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

Since the binarized mask is not differentiable, when training the binary mask $M$ we update the real-valued mask $\hat{M}$ in the backward pass; $M$ is then obtained by thresholding $\hat{M}$ and used in the forward computation.
The result is a binary tensor: entries less than or equal to the threshold become 0, entries greater than it become 1.
The threshold here is 0.005, which explains the 0.01 initialization: at the start, all shared weights are selected.

    def forward(self, input, layer_info=None, name=None):
        if self.piggymask is not None:
            # Get binarized/ternarized mask from real-valued mask.
            mask_thresholded = self.threshold_fn(self.piggymask, self.info['threshold'])
            # Mask weights with above mask.
            weight = mask_thresholded * self.weight
        else:
            weight = self.weight

        # Perform conv using modified weight.
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

In effect, this introduces a trainable mask that picks which of the old weights to use.
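
As a quick sanity check, here is a tiny toy example (hypothetical tensor sizes, using the Binarizer class from the snippet above): a piggymask initialized at 0.01 with threshold 0.005 selects every shared weight at the start, and the gradient flows straight through to the real-valued mask.

    import torch
    from torch.nn import Parameter

    piggymask = Parameter(torch.full((4,), 0.01))   # initialized exactly as in the code above
    binary = Binarizer.apply(piggymask, 0.005)      # -> tensor([1., 1., 1., 1.])

    weight = torch.tensor([0.3, -0.2, 0.5, 0.1])
    loss = (binary * weight).sum()
    loss.backward()
    print(binary)          # all ones: every old-task weight is picked initially
    print(piggymask.grad)  # equals weight: the straight-through gradient from backward()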

Next we arrive again at a familiar function:

                # Set fixed param grads to 0.
                self.pruner.do_weight_decay_and_make_grads_zero()
    def do_weight_decay_and_make_grads_zero(self):
        """Sets grads of fixed weights to 0."""
        assert self.masks
        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                mask = self.masks[name]
                # Set grads of all weights not belonging to current dataset to 0.
                if module.weight.grad is not None:
                    module.weight.grad.data.add_(self.args.weight_decay, module.weight.data)
                    module.weight.grad.data[mask.ne(
                        self.current_dataset_idx)] = 0
                if module.piggymask is not None and module.piggymask.grad is not None:
                    if self.args.mode == 'finetune':
                        module.piggymask.grad.data[mask.eq(0) | mask.ge(self.current_dataset_idx)] = 0
                    elif self.args.mode == 'prune':
                        module.piggymask.grad.data.fill_(0)
        return

Recall our earlier analysis of this function: when piggymask is not None and has a gradient,

  • finetune: piggymask gradients are zeroed where the mask is 0 (weights pruned/released, not yet owned by any task) or >= the current task ID, so the picking mask is only trained over weights owned by earlier tasks;
  • prune: piggymask gradients are zeroed entirely, i.e. the picking mask is frozen during pruning.

We can now push this a bit further: during task k's finetune stage, the entries with mask == k are exactly the weights that were pruned (released) in task k-1's compression and have just been re-assigned by make_finetuning_mask, so those weights themselves can be trained; mask-0 entries only reappear later, once task k's own gradual pruning zeroes some of them out again.
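
A toy illustration (hypothetical tensors, not from the repo) of this rule for task k = 3: weight gradients survive only where the mask equals the current task ID, while piggymask gradients survive only over weights owned by earlier tasks.

    import torch

    mask = torch.tensor([1, 1, 2, 3, 3])      # owner task-id of each weight
    current_dataset_idx = 3

    weight_grad = torch.ones(5)
    weight_grad[mask.ne(current_dataset_idx)] = 0
    print(weight_grad)      # tensor([0., 0., 0., 1., 1.]) -> only task-3's released weights train

    piggymask_grad = torch.ones(5)
    piggymask_grad[mask.eq(0) | mask.ge(current_dataset_idx)] = 0
    print(piggymask_grad)   # tensor([1., 1., 1., 0., 0.]) -> picking is learned only over
                            # weights already owned by tasks 1 and 2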

OK, once again we come to this function in the validation stage (by now you have probably seen it enough times):

    def apply_mask(self):
        """To be done to retrieve weights just for a particular dataset."""
        for name, module in self.model.named_modules():
            if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                weight = module.weight.data
                mask = self.masks[name].cuda()
                weight[mask.eq(0)] = 0.0
                weight[mask.gt(self.inference_dataset_idx)] = 0.0
        return

Why bring it up again? Because we are in the finetune stage the sparsity is 0, i.e. there are no mask entries equal to 0, and careful readers might otherwise be puzzled by this.
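
For completeness, here is a toy example (hypothetical tensors) of what apply_mask does when a mask entry is 0 or belongs to a later task; in the current finetune stage neither case occurs, so nothing is actually zeroed.

    import torch

    mask = torch.tensor([1, 2, 3, 3])                 # owner task-id of each weight
    weight = torch.tensor([0.5, -0.3, 0.8, 0.1])
    inference_dataset_idx = 2                         # validating task 2

    weight[mask.eq(0)] = 0.0                          # unassigned weights: none in this toy mask
    weight[mask.gt(inference_dataset_idx)] = 0.0      # hide weights owned by later tasks
    print(weight)   # tensor([ 0.5000, -0.3000,  0.0000,  0.0000])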

When saving the model, the piggymasks now have to be saved as well (a sketch of what goes into the checkpoint follows the snippet below):

    if avg_train_acc > 0.95:
        manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
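
A minimal sketch (not the repo's actual save_checkpoint, which lives in the Manager class; field names are inferred from what the resume code later reads back) of what the checkpoint has to include from Task 2 on: the piggymasks are stashed per dataset inside shared_layer_info, alongside the task-ID masks.

    # Hypothetical sketch of the checkpoint contents
    shared_layer_info[args.dataset]['piggymask'] = piggymasks
    checkpoint = {
        'model_state_dict': model.module.state_dict(),
        'dataset_history': model.module.datasets,
        'dataset2num_classes': model.module.dataset2num_classes,
        'masks': manager.pruner.masks,
        'shared_layer_info': shared_layer_info,
    }
    torch.save(checkpoint,
               args.checkpoint_format.format(save_folder=args.save_folder, epoch=epoch_idx))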

gradually pruning

The bash command is the same as for Task 1.

Loading the checkpoint. By now two things should be clear:

  • masks store task-ID values that mark which task each weight belongs to;
  • piggymasks are stored in shared_layer_info, and Task 1 has no piggymask at all.

Let's also recap the goal of this stage.

Compression of Task $k+1$: after learning $M$ and $W^E_k$ we obtain the initial model for Task $k+1$. We then fix the mask $M$ and gradually prune $W^E_k$, yielding the compact model $W^P_{k+1}$ of Task $k+1$ and the redundant (released) weights $W^E_{k+1}$. The compact model of the old tasks then becomes $W^P_{1:(k+1)} = W^P_{1:k} \cup W^P_{k+1}$. This compacting and picking/expanding cycle repeats from task to task.

So the piggymasks are not updated during the prune stage.

    if resume_from_epoch:
        filepath = args.checkpoint_format.format(save_folder=resume_folder, epoch=resume_from_epoch)
        checkpoint = torch.load(filepath)
        checkpoint_keys = checkpoint.keys()
        dataset_history = checkpoint['dataset_history']
        dataset2num_classes = checkpoint['dataset2num_classes']
        masks = checkpoint['masks']
        shared_layer_info = checkpoint['shared_layer_info']
    if args.dataset not in shared_layer_info:
        ...
        ...
    elif args.finetune_again:
        ...
        ...
    else:
        #try:
        piggymasks = shared_layer_info[args.dataset]['piggymask']
        #except:
        #    piggymasks = {}
        task_id = model.module.datasets.index(args.dataset) + 1
        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                    module.piggymask = piggymasks[name]
    shared_layer_info[args.dataset]['network_width_multiplier'] = args.network_width_multiplier

From Task 2 onward, the piggymasks need to be loaded.

When we get here again it should feel very familiar: the gradients of all parameters not belonging to the current task are zeroed out.

                # Set fixed param grads to 0.
                self.pruner.do_weight_decay_and_make_grads_zero()

Overall, the pruning stage of Task 2 is almost identical to Task 1, apart from loading the piggymask.

Retrain piggymask and weight

After pruning we finetune again, and at this point the piggymask is also re-initialized.

    elif args.finetune_again:
       # reinitialize piggymask
       piggymasks = {}
       for name, module in model.module.named_modules():
           if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
               piggymasks[name] = torch.zeros_like(masks['module.' + name], dtype=torch.float32)
               piggymasks[name].fill_(0.01)
               piggymasks[name] = Parameter(piggymasks[name])
               module.piggymask = piggymasks[name]

Validate once more before training:

        else:
            logging.info('Piggymask Retrain...')
            history_best_avg_val_acc_when_retraining = manager.validate(start_epoch-1)
            num_epochs_that_criterion_does_not_get_better = 0

If performance improves, the new model is saved; if there is no improvement for 5 consecutive epochs, retraining is stopped.

        if args.finetune_again:
            if avg_val_acc > history_best_avg_val_acc_when_retraining:
                history_best_avg_val_acc_when_retraining = avg_val_acc

                num_epochs_that_criterion_does_not_get_better = 0
                if args.save_folder is not None:
                    for path in os.listdir(args.save_folder):
                        if '.pth.tar' in path:
                            os.remove(os.path.join(args.save_folder, path))
                else:
                    print('Something is wrong! Block the program with pdb')
                    pdb.set_trace()

                history_best_avg_val_acc = avg_val_acc
                manager.save_checkpoint(optimizers, epoch_idx, args.save_folder)
            else:
                num_epochs_that_criterion_does_not_get_better += 1

            if args.finetune_again and num_epochs_that_criterion_does_not_get_better == 5:
                logging.info("stop retraining")
                sys.exit(0)

Nothing else changes.

Retrain piggymask and weight

# If there is any improve from retraining, use that checkpoint
        python tools/choose_retrain_or_not.py \
            --save_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/gradual_prune \
            --load_folder checkpoints/CPG/experiment1/$setting/$arch/${dataset[task_id]}/retrain

If retraining produced a new checkpoint, just use it:

    args = parser.parse_args()
    src_filenames = os.listdir(args.load_folder)
    for src_filename in src_filenames:
        if '.pth.tar' in src_filename:
            out_paths = os.listdir(args.save_folder)
            for checkpoint_file in out_paths:
                if '.pth.tar' in checkpoint_file:
                    os.remove(os.path.join(args.save_folder, checkpoint_file))
            
            shutil.copyfile(os.path.join(args.load_folder, src_filename), os.path.join(args.save_folder, src_filename))
            break

Tasks 3 through n are then just a repetition of this process, except that they use the retrained piggymask:

        if task_id > 1:
            for name, module in model.module.named_modules():
                if isinstance(module, nl.SharableConv2d) or isinstance(module, nl.SharableLinear):
                    module.piggymask = piggymasks[name]

Growing

Now only one part remains unclear: how the expansion is implemented.
The model is fixed at the beginning of each training run, so the width expansion itself only happens in the .sh script:

        if [ $state -eq 2 ]
        then
            network_width_multiplier=$(bc <<< $network_width_multiplier+0.5)
            echo "New network_width_multiplier: $network_width_multiplier"
            continue

Let's revisit this state 2:

[Figure from the original post illustrating the state-2 branch]

In other words, when the accuracy target cannot be reached on the training or validation set and the network width has not yet hit its cap, the network is expanded.
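
A rough, runnable sketch of this decision (all names and values are hypothetical stand-ins reconstructed from the log messages and the .sh loop, not the repo's exact control flow):

    import logging

    # hypothetical values standing in for the script's state after finetuning
    avg_val_acc = 0.766                                   # best val accuracy reached (cf. the log below)
    baseline_acc = 0.77                                   # e.g. "flowers": "0.7700" in the baseline file
    network_width_multiplier = 1.0
    max_allowed_network_width_multiplier = 1.5            # hypothetical cap

    if avg_val_acc < baseline_acc and network_width_multiplier < max_allowed_network_width_multiplier:
        logging.info("It's time to expand the Network")
        # the shell loop then stays in state 2 and runs:
        #   network_width_multiplier=$(bc <<< $network_width_multiplier+0.5)
    else:
        logging.info('Accuracy goal met or width cap reached; continue to gradual pruning')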

Next let's trace this process.
As usual, we start with the log; this is where the first expansion happens:

2020-06-13 10:33:43,676 - root - INFO - In train()-> Train Ep. #99 loss: 0.019, accuracy: 99.48, lr: 0.0001, sparsity: 0.000, network_width_mpl: 1.0
2020-06-13 10:33:44,244 - root - INFO - In validate()-> Val Ep. #99 loss: 1.096, accuracy: 76.60, sparsity: 0.000, task3 ratio: 0.760, zero ratio: 0.000, mpl: 1.0, shared_ratio: 0.758
2020-06-13 10:33:51,253 - root - INFO - In train()-> Train Ep. #100 loss: 0.017, accuracy: 99.48, lr: 0.0001, sparsity: 0.000, network_width_mpl: 1.0
2020-06-13 10:33:51,796 - root - INFO - In validate()-> Val Ep. #100 loss: 1.113, accuracy: 76.60, sparsity: 0.000, task3 ratio: 0.760, zero ratio: 0.000, mpl: 1.0, shared_ratio: 0.758
2020-06-13 10:33:52,680 - root - INFO - ----------------
2020-06-13 10:33:52,680 - root - INFO - It's time to expand the Network
2020-06-13 10:33:52,681 - root - INFO - Auto expand network
2020-06-13 10:34:03,947 - root - INFO - Finetune stage...
2020-06-13 10:34:16,762 - root - INFO - In train()-> Train Ep. #1 loss: 1.240, accuracy: 48.84, lr: 0.01, sparsity: 0.000, network_width_mpl: 1.224744871391589
2020-06-13 10:34:17,606 - root - INFO - In validate()-> Val Ep. #1 loss: 1.143, accuracy: 52.60, sparsity: 0.000, task3 ratio: 1.260, zero ratio: 0.000, mpl: 1.224744871391589, shared_ratio: 0.928
2020-06-13 10:34:28,052 - root - INFO - In train()-> Train Ep. #2 loss: 1.086, accuracy: 55.56, lr: 0.01, sparsity: 0.000, network_width_mpl: 1.224744871391589
2020-06-13 10:34:28,711 - root - INFO - In validate()-> Val Ep. #2 loss: 1.079, accuracy: 59.40, sparsity: 0.000, task3 ratio: 1.260, zero ratio: 0.000, mpl: 1.224744871391589, shared_ratio: 0.907
2020-06-13 10:34:38,802 - root - INFO - In train()-> Train Ep. #3 loss: 1.043, accuracy: 57.84, lr: 0.01, sparsity: 0.000, network_width_mpl: 1.224744871391589

So why does the expansion happen? Because the baseline validation accuracy for this dataset is 0.77, while the model above only reaches 76.60:

"flowers": "0.7700"

Now let's review how the model is initialized in the expanded case (this time without annotating each layer; refer to the walkthrough above if needed).
The only change is network_width_multiplier: 1.0 -> sqrt(1.5) ≈ 1.2247, matching the mpl value in the log. The shell variable itself becomes 1.5; since the multiplier scales both input and output channels, taking the square root presumably keeps the per-layer parameter growth at roughly 1.5x rather than 1.5².

def custom_vgg_cifar100(custom_cfg, dataset_history=[], dataset2num_classes={}, network_width_multiplier=1.0, groups=1, shared_layer_info={}, **kwargs):
    return VGG(make_layers_cifar100(custom_cfg, network_width_multiplier, batch_norm=True, groups=groups), dataset_history, 
        dataset2num_classes, network_width_multiplier, shared_layer_info, **kwargs)

Expansion at definition time is straightforward:

def make_layers_cifar100(cfg, network_width_multiplier, batch_norm=False, groups=1):
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            if in_channels == 3:
                conv2d = nl.SharableConv2d(in_channels, int(v * network_width_multiplier), kernel_size=3, padding=1, bias=False)
            else:
                conv2d = nl.SharableConv2d(in_channels, int(v * network_width_multiplier), kernel_size=3, padding=1, bias=False, groups=groups)

            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(int(v * network_width_multiplier)), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = int(v * network_width_multiplier)

    layers += [
        View(-1, int(512*network_width_multiplier)),
        nl.SharableLinear(int(512*network_width_multiplier), int(4096*network_width_multiplier)),
        nn.ReLU(True),
        nl.SharableLinear(int(4096*network_width_multiplier), int(4096*network_width_multiplier)),
        nn.ReLU(True),
    ]

    return nn.Sequential(*layers)
class VGG(nn.Module):
    def __init__(self, features, dataset_history, dataset2num_classes, network_width_multiplier=1.0, shared_layer_info={}, init_weights=True, progressive_init=False):
        super(VGG, self).__init__()
        self.features = features
        self.network_width_multiplier = network_width_multiplier
        self.shared_layer_info = shared_layer_info
        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        
        self.datasets, self.classifiers = dataset_history, nn.ModuleList()
        self.dataset2num_classes = dataset2num_classes

        if self.datasets:
            self._reconstruct_classifiers()

        if init_weights:
            self._initialize_weights()

        if progressive_init:
            self._initialize_weights_2()

About _reconstruct_classifiers(): it rebuilds the classifiers of the previous tasks, each according to that task's own network width multiplier.

    def _reconstruct_classifiers(self):
        for dataset, num_classes in self.dataset2num_classes.items():
            self.classifiers.append(nn.Linear(int(self.shared_layer_info[dataset]['network_width_multiplier'] * 4096), num_classes))

Next, the masks are transferred onto the enlarged parameters in a straightforward way: they are copied in order into the leading slice. For example, if the new SharableConv2d weight has shape (78, 3, 3, 3), the old mask is simply copied into the first (64, 3, 3, 3) positions (a toy example follows after the snippet below).

        if NEED_ADJUST_MASK:
            if args.mode == 'finetune':
                for name, module in model.named_modules():
                    if isinstance(module, nl.SharableConv2d):
                        mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                        if 'cuda' in module.weight.data.type():
                            mask = mask.cuda()
                        mask[:masks[name].size(0), :masks[name].size(1), :, :].copy_(masks[name])
                        masks[name] = mask
                    elif isinstance(module, nl.SharableLinear):
                        mask = torch.ByteTensor(module.weight.data.size()).fill_(0)
                        if 'cuda' in module.weight.data.type():
                            mask = mask.cuda()
                        mask[:masks[name].size(0), :masks[name].size(1)].copy_(masks[name])
                        masks[name] = mask
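
The "fill into the top-left corner" idea can be seen in isolation with a toy example (hypothetical shapes matching the (64, 3, 3, 3) -> (78, 3, 3, 3) case mentioned above): the old weight/mask occupies the leading slice of the wider tensor, and the newly added filters start out unassigned.

    import torch

    old_weight = torch.randn(64, 3, 3, 3)
    old_mask = torch.ones(64, 3, 3, 3, dtype=torch.uint8)      # owned by earlier tasks

    new_weight = torch.zeros(78, 3, 3, 3)
    new_mask = torch.zeros(78, 3, 3, 3, dtype=torch.uint8)     # new filters start free (mask 0)

    new_weight[:old_weight.size(0), :old_weight.size(1), :, :].copy_(old_weight)
    new_mask[:old_mask.size(0), :old_mask.size(1), :, :].copy_(old_mask)
    print(int(new_mask[:, :, 0, 0].sum()))   # 192 = 64*3 inherited entries, the rest remain 0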

Loading the model parameters works the same way.

    def load_checkpoint(self, optimizers, resume_from_epoch, save_folder):

        if resume_from_epoch > 0:
            filepath = self.args.checkpoint_format.format(save_folder=save_folder, epoch=resume_from_epoch)
            checkpoint = torch.load(filepath)
            checkpoint_keys = checkpoint.keys()
            state_dict = checkpoint['model_state_dict']
            curr_model_state_dict = self.model.module.state_dict()

            for name, param in state_dict.items():
                if ('piggymask' in name or name == 'classifier.weight' or name == 'classifier.bias' or
                    (name == 'classifier.0.weight' or name == 'classifier.0.bias' or name == 'classifier.1.weight')):
                    # I DONT WANT TO DO THIS! QQ That last 3 exprs are for anglelinear and embeddings
                    continue
                elif len(curr_model_state_dict[name].size()) == 4:
                    # Conv layer
                    curr_model_state_dict[name][:param.size(0), :param.size(1), :, :].copy_(param)
                elif len(curr_model_state_dict[name].size()) == 2 and 'features' in name:
                    # FC conv (feature layer)
                    curr_model_state_dict[name][:param.size(0), :param.size(1)].copy_(param)
                elif len(curr_model_state_dict[name].size()) == 1:
                    # bn and prelu layer
                    curr_model_state_dict[name][:param.size(0)].copy_(param)
                elif 'classifiers' in name:
                    curr_model_state_dict[name][:param.size(0), :param.size(1)].copy_(param)
                else:
                    try:
                        curr_model_state_dict[name].copy_(param)
                    except:
                        pdb.set_trace()
                        print("There is some corner case that we haven't tackled")
        return

With that, the overall picture is fairly clear.

Given my own limitations, this walkthrough may still contain errors and omissions; if you spot a problem, please let me know in the comments!
