个人解析:《Weighted-Entropy-based Quantization》

原文链接

中文翻译

由于最近在做网络量化方面的工作,我读了一下这篇论文。
作者在GitHub上开源了源码,结合论文来看并不复杂。

我先总结一下,想要直接看结果的看官可以根据我的结论,判断自己是否要继续看下去。
这篇论文及其源码实现中,权重量化与激活量化并没有很好地结合在一起(两者都需要先还原成单精度浮点数才能相乘),所以不适合工程实现:虽然精度确实没有下降很多,但计算速度难以提升。

以下只是我读完之后的一些总结,并把其中算是核心的部分剥离了出来。
个人感觉这篇论文有些骗科研经费的嫌疑。

进入主题吧。

对于CNN网络,如果只讨论前向过程,需要做两方面的量化:一是权重值的量化,二是激活值的量化。如果还涉及反向传播,则还需要对梯度值进行量化。

我也就从这两个方面来解析一下。

权重值量化

个人感觉,权重值量化是这篇论文的核心内容。论文中对此有很详细的描述,Caffe实现也相当细致。
我根据其C++源码,将其中的关键部分抽离出来,改写成无第三方库依赖的Python代码,贴在下面。
其过程类似于聚类:先把权重的平方(重要性)排序并划分成若干区间,再据此对权重进行非线性量化;量化后的值通过查表得到,并且该量化表关于0对称(正、负权重共用同一组幅值)。

#!/usr/bin/env python3

import copy
import math as m
from functools import reduce

class QuantInfo:
  """State for the weighted-entropy search over sorted weight importances.

  A weight's importance is its square. Bins are contiguous runs of the
  ascending-sorted importance list: bin i covers sorted indices
  [slice[i - 1], slice[i]), and bin 0 starts at index 0.

  Attributes:
    sum: prefix sums of the sorted importances; sum[k] is the total
      importance of the k + 1 smallest weights.
    slice: exclusive end index of each bin (mutated in place by recUpdate).
    level: total number of quantization levels (both signs).
    n: number of weights.
  """

  def __init__(self, sum, slice, level, n):
    self.sum = sum
    self.slice = slice
    self.level = level
    self.n = n

  def half_level(self):
    """Return the number of magnitude bins, (level + 1) // 2."""
    return (self.level + 1) // 2

  def _bounds(self, n):
    """Return the (begin, end) sorted-index bounds of bin n."""
    begin = 0 if n == 0 else self.slice[n - 1]
    return begin, self.slice[n]

  def _mass(self, k):
    """Return the total importance of the k smallest weights (0 if k <= 0).

    The k <= 0 guard matters: sum[-1] would silently wrap around to the grand
    total when a bin starts at sorted index 0.
    """
    return self.sum[k - 1] if k > 0 else 0

  def p(self, n):
    """Return the fraction of weights that fall in bin n."""
    begin, end = self._bounds(n)
    return (end - begin) / self.n

  def wp(self, n):
    """Return the total importance of bin n divided by the weight count."""
    begin, end = self._bounds(n)
    return (self._mass(end) - self._mass(begin)) / self.n

  def w(self, n):
    """Return the mean importance of bin n (0.0 for an empty bin)."""
    begin, end = self._bounds(n)
    if end == begin:
      return 0.0
    return (self._mass(end) - self._mass(begin)) / (end - begin)

  def _term(self, i):
    """Return bin i's weighted-entropy contribution (0.0 for an empty bin)."""
    prob = self.p(i)
    if prob <= 0:
      # Convention 0 * log(0) == 0; math.log(0) would raise ValueError.
      return 0.0
    # The /2 presumably splits each magnitude bin over its +/- levels
    # (the quantizer is sign-symmetric) -- kept as in the original.
    return -self.wp(i) * m.log(prob / 2)

  def wEnt(self):
    """Return the weighted entropy over all scored magnitude bins.

    For an odd level count bin 0 is skipped; its representative is left at 0.
    """
    ent = 0
    for i in range(self.level % 2, self.half_level()):
      ent += self._term(i)
    return ent

  def pEnt(self, n):
    """Return the entropy of bins n and n + 1, the two bins moved by slice[n]."""
    ent = 0
    for i in range(max(n, self.level % 2), n + 2):
      ent += self._term(i)
    return ent
# end class

def recUpdate(
  n,
  info,
  begin = 0,
  end = 0):
  """Move boundary info.slice[n] to where info.pEnt(n) peaks.

  Positions begin..end (inclusive) are searched by comparing the entropy at
  the midpoint and midpoint + 1, keeping the uphill half; this finds the
  maximum when pEnt is unimodal over the range. When begin and end are both
  0 (the defaults) the range is derived from the neighbouring boundaries:
  [0, slice[1] - 1] for n == 0, otherwise [slice[n - 1], slice[n + 1] - 1].

  The search is iterative. A derived range that is empty or a single
  position no longer re-enters the (0, 0) sentinel and recurses forever.

  Args:
    n: index of the boundary to move; info.slice[n + 1] must exist when the
      range is derived.
    info: object with a mutable ``slice`` list and a ``pEnt(n)`` method
      (QuantInfo in this file).
    begin, end: optional explicit inclusive search range.
  """
  if begin == 0 and end == 0:
    if n == 0:
      begin, end = 0, info.slice[1] - 1
    else:
      begin, end = info.slice[n - 1], info.slice[n + 1] - 1

  # Halve the range until at most two candidate positions remain.
  while end - begin > 1:
    center = (begin + end) // 2
    info.slice[n] = center
    f_ent = info.pEnt(n)
    info.slice[n] = center + 1
    b_ent = info.pEnt(n)
    if f_ent > b_ent:
      end = center
    else:
      begin = center

  if end - begin == 1:
    info.slice[n] = begin
    b_ent = info.pEnt(n)
    info.slice[n] = end
    e_ent = info.pEnt(n)
    info.slice[n] = begin if b_ent > e_ent else end
  else:
    # Degenerate range (end <= begin): only `begin` is admissible.
    info.slice[n] = begin
# end recUpdate

def weight_quant_scale_forward(
  half_level,
  len_ch,
  num_ch,
  weight_list,
  quant_bdr,
  quant_rep):
  """Snap every weight to its sign-symmetric quantized representative.

  A weight's magnitude selects the highest bin j (0 <= j < half_level) with
  magnitude >= quant_bdr[j]. The output is quant_rep[j] carrying the weight's
  sign; when no boundary is reached the output is 0 and the bin is 0.
  Weights are read channel by channel: channel c, position k is
  weight_list[c * len_ch + k].

  Returns:
    (out, bin_out): quantized values and signed bin indices, in read order.
  """
  out = []
  bin_out = []
  for ch in range(num_ch):
    base = ch * len_ch
    for k in range(len_ch):
      weight = weight_list[base + k]
      sign = 1 if weight >= 0 else -1
      magnitude = weight * sign

      # Scan boundaries from the top down: the first hit is the highest bin
      # whose boundary the magnitude reaches.
      chosen = next(
        (j for j in reversed(range(half_level)) if magnitude >= quant_bdr[j]),
        None)
      if chosen is None:
        level_value, chosen = 0, 0
      else:
        level_value = quant_rep[chosen]

      out.append(level_value * sign)
      bin_out.append(chosen * sign)

  return (out, bin_out)

#end weight_quant_scale_forward

def eh_weight_quant_update(
  num_level,
  num_ch,
  len_ch,
  in_weight,
  slice):
  """Fit a sign-symmetric weighted-entropy quantizer to in_weight and apply it.

  Squared weights (importances) are sorted and split into
  (num_level + 1) // 2 contiguous bins. Each interior boundary is moved to an
  entropy peak (recUpdate) until a full pass stops increasing the weighted
  entropy. The resulting representative (rep) and boundary (bdr) tables are
  printed, then used to quantize in_weight.

  Args:
    num_level: total number of quantization levels (both signs).
    num_ch: number of channels in in_weight.
    len_ch: number of weights per channel.
    in_weight: flat list of weights, read as weight[ch * len_ch + idx].
    slice: initial exclusive end index of each bin in the sorted importance
      list; at least (num_level + 1) // 2 entries. The last used entry is
      presumably len(in_weight) -- confirm with callers.

  Returns:
    (out, bin_out) from weight_quant_scale_forward, or ([], []) for empty
    input. The original returned None.
  """
  from itertools import accumulate

  half_level = (num_level + 1) // 2
  n = len(in_weight)
  if n == 0:
    # Nothing to fit; QuantInfo would divide by n == 0.
    return [], []

  importance_sorted_list = sorted(w ** 2 for w in in_weight)

  # Prefix sums (despite the historical name): suffix_sum_list[k] is the total
  # importance of the k + 1 smallest weights. QuantInfo derives bin masses and
  # means from differences of its `sum` list, so it needs these sums, not the
  # raw importances.
  suffix_sum_list = list(accumulate(importance_sorted_list))

  # Private copy of the initial bin ends; recUpdate mutates it in place.
  slice_vec = list(slice[:half_level])

  info = QuantInfo(suffix_sum_list, slice_vec, num_level, n)

  # Coordinate ascent: move each interior boundary to its local entropy peak,
  # repeating while a full pass still increases the weighted entropy.
  new_ent = info.wEnt()
  prev_ent = 0.
  while new_ent > prev_ent:
    for i in range(half_level - 1):
      recUpdate(i, info)
    prev_ent = new_ent
    new_ent = info.wEnt()

  # rep[i]: RMS magnitude of the weights in bin i. For an odd level count,
  # rep[0] stays 0 (the zero level).
  rep = [0] * half_level
  for i in range(num_level % 2, half_level):
    rep[i] = m.sqrt(info.w(i))

  # bdr[i + 1]: magnitude of the smallest weight in bin i + 1, i.e. that bin's
  # lower edge. bdr[0] stays 0.
  bdr = [0] * half_level
  for i in range(half_level - 1):
    bdr[i + 1] = m.sqrt(importance_sorted_list[slice_vec[i]])

  print("[scale factor]     rep is {}".format(rep))
  print("[log distribution] bdr is {}".format(bdr))

  # Quantize the weights we were given, not a module-level global.
  out, bin_out = weight_quant_scale_forward(
    half_level,
    len_ch,
    num_ch,
    in_weight,
    bdr,
    rep)

  print("[bin ] {}".format(bin_out))
  print("[real] {}".format(out))

  return out, bin_out

#end eh_weight_quant_update

if __name__ == '__main__':
  # Demo: quantize nine sample weights (one channel) to 5 symmetric levels.
  quant_level = 5
  number_ch = 1
  # Keep `weight_list` as a module-level name. Check eh_weight_quant_update
  # for direct reads of this global before moving it into a function.
  weight_list = [0.3, 1.4, 2.7, -3.1, 0.8, -1.1, 2.3, 0.0, 4.1]
  length_ch = len(weight_list) // number_ch
  half_level = (quant_level + 1) // 2

  # Initial bins hold equal numbers of sorted weights. Entry k is the
  # exclusive end index of bin k, so the last entry equals len(weight_list).
  total = len(weight_list)
  initial_slice = [total * (k + 1) // half_level for k in range(half_level)]

  print("quant level is {}".format(quant_level))
  print(weight_list)

  eh_weight_quant_update(
    num_level=quant_level,
    num_ch=number_ch,
    len_ch=length_ch,
    in_weight=weight_list,
    slice=initial_slice)

激活值量化

激活值量化就做得比较粗糙了:它采用移位的方式(量化值都是2的幂),而且放在ReLU之后。这样做缺乏一般性,没有考虑卷积之后不接ReLU的情况(虽然目前这种情况非常少);若要处理这种情况,就需要额外用一位来表示原始数据的符号,会有些浪费。
并且激活值与权重值只能在都还原回单精度浮点数之后才能相乘,这算是这篇论文的痛点,在计算上并没有什么优势可言。
以下是我抽离出来的Python代码,进行了少量修改,感兴趣的读者可以和原始代码对比着看。

#!/usr/bin/env python3

import math as m
import copy
import sys
from functools import reduce

BASE_SIZE = 16
BIN_SIZE  = 1024

def eh_histogram(
  active_list,
  offset):
  """Histogram the positive activations on a scaled log2 axis.

  Each positive value x is counted in bin floor(log2(x) * BASE_SIZE) - offset,
  clamped into [0, BIN_SIZE - 1]. Zero, negative and NaN values are ignored.

  Returns:
    A list of BIN_SIZE counts.
  """
  hist = [0] * BIN_SIZE
  last_bin = BIN_SIZE - 1

  positives = (value for value in active_list if value > 0)
  for value in positives:
    bin_idx = m.floor(m.log2(value) * BASE_SIZE) - offset
    hist[min(max(bin_idx, 0), last_bin)] += 1

  return hist

# end eh_histogram

class LQInfo:
  """Weighted-entropy scoring for the log-domain activation quantizer.

  Positions are histogram indices on the log2(x) * BASE_SIZE axis, shifted by
  `offset`: position k stands for the value 2 ** ((k + offset) / BASE_SIZE).
  With half = step >> 1, nonzero level i (1 <= i < num_level) has its
  representative at fsr + (i - 1) * step. It collects positions in
  [fsr + (2i - 3) * half, fsr + (2i - 1) * half); the top level is
  open-ended.
  """

  def __init__(self, num_level, num_weight, offset, hist_scan):
    self.num_level = num_level    # levels 1..num_level-1 are scored by wEnt
    self.num_weight = num_weight  # total sample count; normalizes p()
    self.offset = offset          # log-domain position of hist_scan[0]
    self.hist_scan = hist_scan    # cumulative counts: samples in bins 0..k

  def _below(self, pos):
    """Return the number of samples in bins with index < pos (0 if pos <= 0).

    Guards against hist_scan[-1] silently wrapping to the total.
    """
    if pos <= 0:
      return 0
    return self.hist_scan[min(pos, len(self.hist_scan)) - 1]

  def p(self, idx, fst, step):
    """Return the fraction of samples assigned to nonzero level idx.

    `fst` is the first-level position (the same quantity w/wEnt call `fsr`).
    The original body read an undefined `fsr`, which only resolved through a
    module-level loop variable.
    """
    half = step >> 1
    begin = fst + half * (2 * idx - 3)
    if idx == self.num_level - 1:
      # Top level: everything from `begin` upward.
      count = self.num_weight - self._below(begin)
    else:
      end = fst + half * (2 * idx - 1)
      count = self._below(end) - self._below(begin)
    return count / self.num_weight

  def w(self, idx, fsr, step):
    """Return the real value represented by level idx.

    Divides by BASE_SIZE, the scale used by eh_histogram and by the
    calibration's bdr table; the original divided by BIN_SIZE.
    """
    temp = (fsr + (idx - 1) * step + self.offset) / BASE_SIZE
    return m.pow(2, temp)

  def wEnt(self, fsr, step):
    """Return the weighted entropy of levels 1..num_level-1 for (fsr, step)."""
    ent = 0
    for i in range(1, self.num_level):
      prob = self.p(i, fsr, step)
      if prob > 0:
        ent -= self.w(i, fsr, step) * prob * m.log(prob)
    return ent

# end LQInfo

def WLQReLUForward2(
  num_level,
  active_list,
  offset,
  step,
  train):
  """Quantize activations to powers of two on a calibrated log2 grid.

  Positions are log2(|x|) * BASE_SIZE. `offset` is the position of the first
  nonzero level and `step` is the spacing between levels. A nonzero input
  snaps to the nearest level index, capped at num_level - 1, and is output as
  sign * 2 ** ((offset + idx * step) // BASE_SIZE). An input below the first
  level becomes sys.float_info.min when train == 0, otherwise 0.

  Returns:
    (float_out, quant_out, sign_out)
    float_out: one quantized value per input (0 for zero inputs).
    quant_out: one entry per NONZERO input, in input order. The entry is the
      (negative) level index for below-range inputs, else the integer exponent.
    sign_out: one flag per input, 1 if the input was negative, else 0.
  """
  float_out = []
  quant_out = []
  sign_out = []

  for in_data in active_list:
    sign = 1
    if in_data < 0:
      sign = -1
    magnitude = in_data * sign
    sign_out.append(0 if sign > 0 else 1)

    o_temp = 0
    if magnitude > 0:
      idx = round((m.log2(magnitude) * BASE_SIZE - offset) / float(step + 10e-10))
      # Original note: "Reserve 2bit for isZero and isPositive number".
      # A commented-out variant capped at num_level - 2 instead.
      idx = min(num_level - 1, idx)
      if idx < 0:
        o_temp = sys.float_info.min if train == 0 else 0.
        quant_out.append(idx)
      else:
        exponent = (offset + idx * step) // BASE_SIZE
        o_temp = m.pow(2, exponent)
        # Append rather than index by input position: zero inputs add no
        # entry, so input and quant_out positions can differ.
        quant_out.append(exponent)

    float_out.append(sign * o_temp)

  return (float_out, quant_out, sign_out)

if __name__ == '__main__':
  # Demo: calibrate the log-domain activation quantizer on one batch, then
  # quantize that same batch with the chosen offset and step.
  from itertools import accumulate

  num_level = 4
  offset = -640
  active_list = [0.3, 1.4, 2.7, -3.1, 0.8, -14.1, 2.3, 0.1, 14.1]
  #active_list = [1.317857, 0.317857, 2.530624, 1.530624, 0.294160, 3.294160, 10.317857]
  print("active_list is {}".format(active_list))

  # Calibration is meant to run only for the first picture; `if True` stands
  # in for that condition here.
  if True:

    hist = eh_histogram(
      active_list,
      offset)

    # Cumulative counts (a prefix sum despite the name):
    # suffix_sum_list[i] = positive samples in bins 0..i. accumulate is
    # O(BIN_SIZE), whereas re-reducing every prefix was O(BIN_SIZE ** 2).
    suffix_sum_list = list(accumulate(hist))
    if suffix_sum_list[-1] == 0:
      sys.exit("no positive activations to calibrate on")

    # min_offset: last bin whose cumulative count is still 0, so the scan
    # handed to LQInfo starts at 0. max_offset: first bin whose cumulative
    # count reaches the total.
    min_offset = -1
    max_offset = -1
    max_value = -1
    for i, cum in enumerate(suffix_sum_list):
      if cum == 0:
        min_offset = i
      if cum > max_value:
        max_offset = i
        max_value = cum

    if min_offset < 0:
      # No empty leading bin: slicing with -1 would produce a garbage scan.
      sys.exit("lowest histogram bin is occupied; lower the initial offset")

    length = max_offset - min_offset + 1
    offset += min_offset

    max_ent = 0
    max_half_step = 0
    max_fsr = 0

    suffix_sum_index = min_offset + length - 1

    info = LQInfo(
      num_level,
      suffix_sum_list[suffix_sum_index],
      offset,
      suffix_sum_list[min_offset:])

    # Exhaustive search over half-step 1..16 and first-level position fsr.
    # NOTE(review): the lower bound 40 is a magic constant carried over from
    # the original; its rationale is not visible here.
    # `fsr` deliberately stays a module-level name (this block is not wrapped
    # in a function); check LQInfo.p for reads of it before changing that.
    for half_step in range(1, 17):
      end_temp = length - half_step * (2 * num_level - 5)
      for fsr in range(40, end_temp):
        ent = info.wEnt(fsr, 2 * half_step)

        if ent > max_ent:
          max_ent = ent
          max_half_step = half_step
          max_fsr = fsr

    if max_half_step == 0:
      sys.exit("no (fsr, step) candidate produced positive entropy")

    # bdr[i] (1 <= i < num_level): real-valued upper boundary of level i.
    # bdr[num_level + i]: real-valued representative of level i.
    # Slots num_level and 0 then hold the log-domain first-level position and
    # the step, which are passed to the forward pass.
    bdr = [0 for idx in range(256)]
    for i in range(1, num_level):
      temp1 = (max_fsr + offset + 2 * (i - 1) * max_half_step)
      bdr[num_level + i] = m.pow(2, temp1 / BASE_SIZE)
      temp2 = (max_fsr + offset + (2 * i - 1) * max_half_step)
      bdr[i] = m.pow(2, temp2 / BASE_SIZE)

    bdr[num_level] = max_fsr + offset
    bdr[0] = 2 * max_half_step

    float_out, quant_out, sign_out = WLQReLUForward2(
      num_level,
      active_list,
      bdr[num_level],
      bdr[0],
      1) # mean test

    print("float output is {}".format(float_out))
    print("quant output is {}".format(quant_out))
    print("sign  output is {}".format(sign_out))
    print("offset is {}".format(bdr[num_level]))
    print("step is {}".format(bdr[0]))

你可能感兴趣的:(前端移植)