Paper Reading: Variational Convolutional Neural Network Pruning

A pruning method based on the Batch Normalization layer, from the paper "Variational Convolutional Neural Network Pruning".

Content

  • Brief introduction
  • Implementation
    • What is channel saliency
    • How to make good use of the channel saliency parameter
    • Measuring the distribution of the channel saliency parameter
      • In an easy-to-understand way
      • In a detailed way
        • Introduction to variational inference
        • Applying variational inference to estimate channel saliency
  • Experiment results
    • CIFAR10 dataset
    • ImageNet2012 dataset
  • My conclusion

Brief introduction

To prune a filter in a convolutional layer, we need a criterion that measures how important that filter is. In this paper, the authors propose a new criterion called channel saliency, which comes from the batch normalization layer.

Implementation

What is channel saliency

Below is the computation performed by a batch normalization layer:
$$BN(x) = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
where $\mu_B$ and $\sigma_B^2$ are the mean and variance of the batch data, and $\epsilon$ is a small constant (1e-5 or similar) that prevents a zero denominator.
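To make the formula concrete, here is a minimal sketch of the normalization step (my own illustration, not code from the paper), computing per-channel statistics for an NCHW tensor the way torch.nn.BatchNorm2d does internally:

```python
import torch

def bn_normalize(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: (N, C, H, W); statistics are taken per channel over N, H, W
    mu = x.mean(dim=(0, 2, 3), keepdim=True)                   # mu_B
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)   # sigma_B^2
    return (x - mu) / torch.sqrt(var + eps)                    # BN(x)
```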

The full BN output $x_{out} = \gamma BN(x) + \beta$ can then be rewritten as
$$x_{out} = \gamma BN(x) + \beta = \gamma\left[BN(x) + \tilde{\beta}\right]$$
where $\tilde{\beta} = \beta / \gamma$. In this form the whole channel output is scaled by $\gamma$, so the parameter $\gamma$ can be viewed directly as a criterion that measures the importance of the corresponding filter. The paper therefore calls $\gamma$ the channel saliency.
Below, I simply use $\gamma$ to denote the channel saliency.
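Since $\gamma$ is just the learnable scale of the BN layer, it can be read straight out of a trained model. A small sketch, assuming a PyTorch model (the variable names are mine):

```python
# Collect the per-channel gamma (channel saliency) from every BatchNorm2d layer.
# In torch.nn.BatchNorm2d, gamma is stored in .weight and beta in .bias.
import torch.nn as nn
import torchvision

model = torchvision.models.vgg16_bn()   # any CNN with BN layers would do

saliency = {}
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        saliency[name] = module.weight.detach().abs()   # one gamma per channel
```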

How to make good use of the channel saliency parameter

It is clear that we should prune filters whose following BN layer has a small $\gamma$. The problem is that $\gamma$ changes drastically over training iterations, so the authors instead analyze the distribution of $\gamma$: when the distribution is concentrated around zero, the filter can be pruned safely.
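The naive version of this idea would be a single global threshold on $|\gamma|$; a sketch is below, assuming the `saliency` dict from the previous snippet. The point of the paper is that this is unreliable while $\gamma$ is still moving during training, which is why the distribution-based criterion of the next section is used instead.

```python
# Naive channel selection by thresholding |gamma|; False in the mask means "prune".
def select_channels(saliency: dict, threshold: float = 1e-2) -> dict:
    return {name: gamma > threshold for name, gamma in saliency.items()}

keep_masks = select_channels(saliency)   # boolean keep-mask per BN layer
```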

Measuring the distribution of the channel saliency parameter

In an easy-to-understand way

  1. Use a new distribution $q_{\phi}(\gamma)$ to estimate the distribution of $\gamma$, denoted $p(\gamma)$.
  2. Create a new loss function $\mathcal{L} = \mathcal{L}_d - KL(q_{\phi}(\gamma)\,||\,p(\gamma))$, where $\mathcal{L}_d$ is the original loss function and KL is short for KL-divergence.
  3. The KL-divergence term added to the loss function reduces the distance between $q_{\phi}(\gamma)$ and $p(\gamma)$.
  4. $p(\gamma)$ can be read from the statistics of a mini-batch, giving $p(\gamma_i) \sim \mathcal{N}(\mu_i, \sigma_i^2)$, where $\mu_i$ and $\sigma_i$ are what I need.
  5. $q_{\phi}(\gamma)$ is the distribution we introduce, with $q_{\phi}(\gamma_i) \sim \mathcal{N}(\mu_i^*, (\sigma_i^*)^2)$, where $\mu_i^*$ and $\sigma_i^*$ are what we expect the network to learn.
    The authors fix the mean to zero and write $q_{\phi}(\gamma_i) \sim \mathcal{N}(0, (\sigma_i^*)^2)$. I do not think this is right.
  6. The KL-divergence then decomposes per channel, with the standard closed form for Gaussians (a code sketch follows this list): $KL(q_{\phi}(\gamma)\,||\,p(\gamma)) = \sum_i KL(q_{\phi}(\gamma_i)\,||\,p(\gamma_i)) = \sum_i \left[\log\frac{\sigma_i}{\sigma_i^*} + \frac{(\sigma_i^*)^2 + (\mu_i - \mu_i^*)^2}{2\sigma_i^2} - \frac{1}{2}\right]$
  7. Finally, backpropagate the loss $\mathcal{L}$ and we obtain the estimated distribution of $\gamma$: $\mu_i^*$ and $\sigma_i^*$. When $(\mu_i^*, \sigma_i^*)$ is smaller than a threshold $(\mu_T, \sigma_T)$, we prune the corresponding filter.
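Here is a rough code sketch of the closed-form KL in step 6, under my reading of the notation ($q_{\phi}(\gamma_i) = \mathcal{N}(\mu_i^*, (\sigma_i^*)^2)$ learned, $p(\gamma_i) = \mathcal{N}(\mu_i, \sigma_i^2)$ estimated from the mini-batch); this is an illustration, not the paper's actual implementation:

```python
import torch

def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Element-wise KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), summed over channels."""
    return torch.sum(
        torch.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
        - 0.5
    )

# The KL term then enters the training objective exactly as in step 2:
# objective = L_d - gaussian_kl(mu_star, sigma_star, mu_batch, sigma_batch)
```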

In a detailed way

Introduction to variational inference

The general idea of VI (variational inference) is:

  1. Define a flexible family of distributions over the hidden variables, indexed by free parameters.
  2. Find the setting of the parameters that makes a member of this family closest to the desired posterior distribution.
  3. Thus the problem of finding a distribution becomes an optimization problem.

Let's start from Bayes' rule:
$$P(A_i|B) = \frac{P(B|A_i)P(A_i)}{\sum_j P(B|A_j)P(A_j)}$$
For a general problem, we want to find the posterior distribution $p(z|x)$. In fact,
$$p(z|x) = \frac{p(z)p(x|z)}{p(x)}$$
where $p(z)$ is the prior distribution and $p(x|z)$ is the likelihood function. The problem, however, is that $p(x) = \int p(z)p(x|z)\,dz$ is usually not computationally tractable. So VI uses another, variational distribution $q_\theta(z)$, parameterized by $\theta$, to approximate $p(z|x)$, and solves the optimization problem
$$\min_{\theta} KL(q_\theta(z)\,||\,p(z|x)) \tag{2.3.1}$$
where $KL$ is the Kullback-Leibler divergence (KL divergence).
The objective above is hard to optimize directly because we do not know $p(z|x)$ (computing it requires the intractable $p(x)$). It is therefore transformed into
$$\max_{\theta} ELBO(\theta, x) = \mathbb{E}_{q_\theta(z)}\left[\log\frac{p(x,z)}{q_\theta(z)}\right] \tag{2.3.2}$$
ELBO is short for Evidence Lower Bound. In many machine-learning settings, $p(x,z)$ is our model.
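The link between (2.3.1) and (2.3.2) is a short identity worth writing out. Since $p(z|x) = p(x,z)/p(x)$,
$$KL(q_\theta(z)\,||\,p(z|x)) = \mathbb{E}_{q_\theta(z)}\left[\log\frac{q_\theta(z)}{p(z|x)}\right] = \log p(x) - \mathbb{E}_{q_\theta(z)}\left[\log\frac{p(x,z)}{q_\theta(z)}\right] = \log p(x) - ELBO(\theta, x)$$
Because $\log p(x)$ does not depend on $\theta$, maximizing the ELBO is equivalent to minimizing the KL in (2.3.1), and the ELBO only requires the tractable joint $p(x,z)$.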
There exist many methods for optimizing Equation (2.3.2), including the mean-field assumption, stochastic variational inference (SVI), black box variational inference (BBVI), reparameterization tricks, etc.

Applying variational inference to estimate channel saliency

In this paper, the concrete situation is
$$p(\gamma|\mathcal{D}) = \frac{p(\gamma)p(\mathcal{D}|\gamma)}{p(\mathcal{D})}$$
where $\mathcal{D}$ is the dataset. However, $p(\mathcal{D}) = \int p(\gamma)p(\mathcal{D}|\gamma)\,d\gamma = \int p(\mathcal{D},\gamma)\,d\gamma$ is computationally intractable.
The remaining details are rather tedious and I do not think it necessary to analyze them here. The authors describe these mature techniques at length, which is unnecessary; the easy-to-understand explanation above is enough.
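One concrete detail that may still be useful: to train $q_\phi(\gamma)$ with backpropagation, the expectation over $q_\phi$ is typically handled with the reparameterization trick mentioned in the VI section. Below is a minimal sketch of the sampling step only, with my own variable names; it is a generic illustration of the trick, not the paper's code.

```python
import torch

def sample_gamma(mu_star: torch.Tensor, log_sigma_star: torch.Tensor) -> torch.Tensor:
    """Draw gamma ~ N(mu_star, sigma_star^2) so gradients reach mu_star and sigma_star."""
    sigma_star = torch.exp(log_sigma_star)   # parameterize sigma in log-space to keep it positive
    eps = torch.randn_like(mu_star)          # eps ~ N(0, 1), no gradient needed
    return mu_star + sigma_star * eps
```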

Experiment results

The paper reports results on the CIFAR10, CIFAR100, and ImageNet2012 datasets, but I only list the results on CIFAR10 and ImageNet2012 here.

CIFAR10 dataset

The result on VGG16 does not show much improvement. The result on ResNet differs quite a lot from my own numbers in terms of FLOPs: the FLOPs reported for the original model are much smaller than my calculation, which makes the comparison unconvincing in my eyes.

ImageNet2012 dataset

The experiments on ImageNet2012 do not achieve state-of-the-art results.

My conclusion

This method uses the batch normalization layer to determine which filters to prune. The experiments shown here do not achieve state-of-the-art results, but the idea is worth rethinking.

