weight pruning(权重剪枝)--学习笔记

  1. 权重剪枝是一种模型优化技术。在权重剪枝中,它在训练过程中逐渐将模型权重归零,以实现模型稀疏。
  2. 深度模型通常会有更好的预测精度,但是它面临计算开销过大的问题。模型压缩(model compress)是提高深度模型推理效率的一种解决方案,它期望在不损失精度或者精度损失可控的范围内,加速推理效率,减低内存开销。
    1. 模型压缩算法主要包括权重量化(quantization)、剪枝(pruning)、低秩分解等
    2. 量化需要硬件或者推理引擎的对低精度8-bit计算支持,目前tensorflow在x86和gpu环境下还没有很好的支持,因此量化只帮助实现了模型大小下降,没有实现推理的加速。model pruning学习的材料是tensorflow repo中的tensorflow/contrib/model_pruning,实际了解后发现它属于pruning中no-structural pruning,其加速效果依赖具体的硬件实现,加速效果一般,tensorflow 中对稀疏矩阵运算没有特别好的优化(依赖于底层的 SparseBLAS 实现,目前还没有特别好的)。
    3. 有些深度学习网络模型over-parameterized,为了使其在资源受限的环境下高效的进行推理预测,要么减少网络的隐藏单元(hidden unit)同时保持模型密集连接结构,要么采用针对大模型进行模型剪枝(model pruning)。
    4. 大而稀疏的模型优于小而密集的模型,
  3. tensorflow中的模型剪枝是一种训练时剪枝。对于需要被剪枝的网络模型,对于网络中每个需要被剪枝的层(layer)添加一个二进制掩码变量(binary mask variable ),该变量的大小和形状和改层的权重张量(weight tensor)相同。在训练图中加入一些ops,它负责对该层的权重值(weights)的绝对值进行排序,通过mask将最小的权重值屏蔽为0。在前向传播时该掩模的对应位与选中权重进行相与输出feature map,如果该掩模对应位为0则对应的权重相与后则为0,在反向传播时掩模对应位为0的权重参数则不参与更新。除此之外,文章提出了一种新的自动逐步修剪算法(automated gradual pruning),它实际上是定义了一种稀疏度变化的规则,初始时刻,稀疏度提升较快,而越到后面,稀疏度提升速度会逐渐放缓,这个主要是基于冗余度的考虑。因为初始时有大量冗余的权值,而越到后面保留的权值数量越少,不能再“大刀阔斧”地修剪,而需要更谨慎些,避免“误伤无辜”。
  4. 剪枝(pruning)丢弃了不严重影响模型表现的权重。
  5. 在优化过程中,有的权重比其他权重用更大的幅度量级(有正有负)进行更新,这些权重可以看作”更重要“的权重。
  6. 训练结束之后,我们检查网络每一层的权重幅值,并找出”重要“的权重。寻找方法如下(heuristics):
    1.  降序排列权重幅值
    2. 找到在队列中更早出现的那些幅值(对应weight maginitudes更大)”那些“具体有多少,取决于有百分之多少的权重需要被剪枝。(percentage of weights to be pruned)
    3. 设定一个阈值,权重幅值在阈值之上的权重会被视为是重要的权重。这个阈值的设定也有以下几种方法:
      1. 这个阈值可以是整个网络中最小的权重梯度
      2. 这个阈值可以是该网络中某一层的最小权重阈值。在这种情况下,不同层的“重要”权重之间是有偏差的。
  7. 对训练好的神经网络剪枝
    1. 先取出权重,然后进行从小到大排列。基于稀疏百分比(sparsity_percentage=0.7),把权重中的从小到大排列的前百分之七十的权重设置为0。
    2. 实现代码:
      # 复制内核权重并获得列式L2规范的排名索引
      kernel_weights = np.copy(k_weights)
      ind = np.argsort(np.linalg.norm(kernel_weights, axis=0))
      
      # 要设置为0的索引数
      sparsity_percentage = 0.7
      cutoff = int(len(ind)*sparsity_percentage)
      
      # 将2D内核权重矩阵中的索引设置为0
      sparse_cutoff_inds = ind[0:cutoff]
      kernel_weights[:,sparse_cutoff_inds] = 0.

你可能感兴趣的:(TensorFlow,tensorflow,深度学习,python)