Link to the original paper
Chinese translation
Since I have recently been working on network quantization, I took a look at this paper.
The author has open-sourced the code on GitHub; read alongside the paper, it is not hard to follow.
Let me start with a summary, so readers who only want the conclusion can decide whether to keep reading.
In this paper and its released implementation, weight quantization and activation quantization are not well integrated, so the method is not well suited to practical deployment: although accuracy indeed drops very little, the computation is hard to actually speed up.
What follows is just my own summary after reading the paper, plus the parts I consider its core, extracted into standalone code.
Personally, the paper feels a bit like it was written to chase research funding.
Let's get into it.
For a CNN, considering only the forward pass, there are two kinds of quantization to perform: quantizing the weight values and quantizing the activation values. For the backward pass, the gradient values would also have to be quantized.
I will analyze the paper from these two angles.
In my view, weight quantization is the core of this paper. The paper describes it in detail, and the Caffe implementation is quite careful.
Based on the C++ source, I extracted the key parts into Python code with no third-party dependencies, which is posted below.
The procedure is essentially a clustering process: the weights are quantized non-linearly, each weight's quantized value is looked up in a quantization table, and the table is symmetric around zero.
#!/usr/bin/env python3
import copy
import math as m
from itertools import accumulate
class QuantInfo:
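    # sum:   running (prefix) sums of the sorted importance (squared weight) values
    # slice: 1-based end index of each cluster within the sorted array
    # level: total number of quantization levels
    # n:     number of weights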
def __init__(self, sum, slice, level, n):
self.sum = sum
self.slice = slice
self.level = level
self.n = n
def half_level(self):
return (self.level + 1) // 2
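    # p(n): fraction of the weights that fall into cluster n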
def p(self, n):
if n == 0:
return self.slice[0] / self.n
else:
return (self.slice[n] - self.slice[n - 1]) / self.n
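    # wp(n): importance mass of cluster n divided by the number of weights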
def wp(self, n):
if n == 0:
index = self.slice[0] - 1
return self.sum[index] / self.n
else:
index1 = self.slice[n] - 1
index2 = self.slice[n - 1] - 1
return (self.sum[index1] - self.sum[index2]) / self.n
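    # w(n): mean importance of the weights in cluster n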
def w(self, n):
if n == 0:
index = self.slice[0] - 1
return self.sum[index] / self.slice[0]
else:
index1 = self.slice[n] - 1
index2 = self.slice[n - 1] - 1
return (self.sum[index1] - self.sum[index2]) / (self.slice[n] - self.slice[n - 1])
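    # wEnt(): total weighted entropy of the current clustering
    # (the /2 accounts for the symmetric negative half of the table)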
def wEnt(self):
ent = 0
for i in range(self.level % 2, self.half_level()):
ent -= self.wp(i) * m.log(self.p(i) / 2)
return ent
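    # pEnt(n): partial weighted entropy of the two clusters affected
    # when boundary n moves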
def pEnt(self, n):
ent = 0
for i in range(max(n, self.level % 2), n+2):
ent -= self.wp(i) * m.log(self.p(i) / 2)
return ent
# end class
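# recUpdate: binary-search the position of boundary n (within the range
# allowed by its neighbouring boundaries) that maximizes the partial
# weighted entropy.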
def recUpdate(
n,
info,
begin = 0,
end = 0):
if begin == 0 and end == 0:
if n == 0:
new_begin = 0
new_end = info.slice[1] - 1
else:
new_begin = info.slice[n-1]
new_end = info.slice[n+1] - 1
recUpdate(n, info, new_begin, new_end)
elif end - begin == 1:
info.slice[n] = begin
b_ent = info.pEnt(n)
info.slice[n] = end
e_ent = info.pEnt(n)
info.slice[n] = begin if b_ent > e_ent else end
else:
center = (begin + end) // 2
info.slice[n] = center
f_ent = info.pEnt(n)
info.slice[n] = center + 1
b_ent = info.pEnt(n)
if f_ent > b_ent:
recUpdate(n, info, begin, center)
else:
recUpdate(n, info, center, end)
# end recUpdate
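# weight_quant_scale_forward: map each weight to the representative value of
# its cluster by scanning the boundary table; the table is symmetric, so only
# |w| is looked up and the sign is re-applied afterwards.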
def weight_quant_scale_forward(
    half_level,
    len_ch,
    num_ch,
    weight_list,
    quant_bdr,
    quant_rep):
out = []
bin_out = []
for ch in range(num_ch):
for idx in range(len_ch):
cur_weight = weight_list[ch * len_ch + idx]
sign = 1 if cur_weight >= 0 else -1
value = cur_weight * sign
output_temp = 0
bin_index = 0
for j in range(half_level):
if value >= quant_bdr[j]:
output_temp = quant_rep[j]
bin_index = j
output_value = output_temp * sign
out.append(output_value)
bin_out.append(bin_index * sign)
return (out, bin_out)
#end weight_quant_scale_forward
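# eh_weight_quant_update: top-level weight quantizer. Sort the squared
# weights (their "importance"), optimize the cluster boundaries by weighted
# entropy, derive the boundary (bdr) and representative (rep) tables, then
# quantize the weights with them.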
def eh_weight_quant_update(
num_level,
num_ch,
len_ch,
in_weight,
slice):
    half_level = (num_level + 1) // 2
    n = len(in_weight)
    # importance of a weight is its square; sort ascending
    importance_sorted_list = sorted(w * w for w in in_weight)
    # running (prefix) sums of the sorted importance values
    suffix_sum_list = list(accumulate(importance_sorted_list))
#print(suffix_sum_list)
    # working copy of the initial cluster boundaries
    slice_vec = list(slice[:half_level])
#print(slice_vec)
info = QuantInfo(
copy.deepcopy(importance_sorted_list),
slice_vec,
num_level,
n)
new_ent = info.wEnt()
prev_ent = 0.
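    # re-optimize each boundary in turn until the weighted entropy stops improving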
while new_ent > prev_ent:
for i in range(half_level - 1):
recUpdate(i, info)
prev_ent = new_ent
new_ent = info.wEnt()
#print(slice_vec)
    rep = [0 for i in range(half_level)]
    bdr = [0 for i in range(half_level)]
    # representative value of each cluster: sqrt of its mean importance
    for i in range(num_level % 2, half_level):
        rep[i] = m.sqrt(info.w(i))
    # boundary between clusters i and i+1: midpoint of the largest weight
    # magnitude in cluster i and the smallest one in cluster i+1
    # (keeps the boundary table monotonically increasing)
    for i in range(half_level - 1):
        bdr[i + 1] = (m.sqrt(importance_sorted_list[slice_vec[i] - 1]) +
                      m.sqrt(importance_sorted_list[slice_vec[i]])) / 2
print("[scale factor] rep is {}".format(rep))
print("[log distribution] bdr is {}".format(bdr))
    out, bin_out = weight_quant_scale_forward(
        half_level,
        len_ch,
        num_ch,
        in_weight,
        bdr,
        rep)
print("[bin ] {}".format(bin_out))
print("[real] {}".format(out))
#end eh_weight_quant_update
if __name__ == '__main__':
quant_level = 5
number_ch = 1
weight_list = [0.3, 1.4, 2.7, -3.1, 0.8, -1.1, 2.3, 0.0, 4.1]
length_ch = len(weight_list) // number_ch
half_level = (quant_level + 1) // 2
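    # initial boundaries: split the sorted weights evenly into half_level clusters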
slice = [ int(len(weight_list)*(idx+1)/half_level) for idx in range(half_level)]
print("quant level is {}".format(quant_level))
print(weight_list)
eh_weight_quant_update(
quant_level,
number_ch,
length_ch,
weight_list,
slice)
The activation quantization is done much more crudely: it is a shift-style (log-domain) scheme, and it is applied after the ReLU. This sacrifices generality, because it does not handle a convolution that is not followed by a ReLU (admittedly a rare situation nowadays); in that case an extra bit would be needed to carry the sign of the original data, which is somewhat wasteful.
Moreover, the quantized activations can only be combined with the quantized weights after both are restored to single-precision values (see the sketch below); this is the paper's sore point, since it leaves no real computational advantage.
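To make that concrete, here is a minimal sketch (my own illustration, not part of the paper's code; the table layouts follow the two scripts in this post) of how both quantized representations must be restored to floating point before they can be multiplied:

def dequant_weight(bin_index, rep):
    # signed bin index from bin_out -> representative value (symmetric table)
    sign = 1 if bin_index >= 0 else -1
    return sign * rep[abs(bin_index)]

def dequant_act(level, sign_bit, offset, step, base_size=16):
    # log-domain level index -> linear value; the sign travels as its own bit
    value = 2 ** ((offset + level * step) / base_size)
    return -value if sign_bit else value

# a weight-activation product can only be formed in floating point, e.g.:
# prod = dequant_weight(b, rep) * dequant_act(l, s, offset, step)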
Below is the Python code I extracted, with minor modifications; if you are interested, compare it against the original code.
#!/usr/bin/env python3
import math as m
import sys
from itertools import accumulate
BASE_SIZE = 16
BIN_SIZE = 1024
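# BASE_SIZE: log2-domain resolution (sub-steps per octave)
# BIN_SIZE:  number of histogram bins in the log domain

# eh_histogram: histogram of the positive activations in the log2 domain,
# shifted by offset; out-of-range values are clamped into the end bins.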
def eh_histogram(
active_list,
offset):
hist = [0 for idx in range(BIN_SIZE)]
for index in range(len(active_list)):
cur_active = active_list[index]
#print("cur_active is {}".format(cur_active))
if cur_active > 0:
quant = m.floor(m.log2(cur_active) * BASE_SIZE) - offset
#print("quant is {}".format(quant))
if quant < 0:
hist[0] += 1
elif quant >= BIN_SIZE:
hist[BIN_SIZE - 1] += 1
else:
hist[quant] += 1
#print(hist)
return hist
# end eh_histogram
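# LQInfo: weighted-entropy bookkeeping for the activation quantizer.
#   num_level:  number of activation levels
#   num_weight: total number of positive activations in the histogram
#   offset:     log-domain offset of the first histogram bin
#   hist_scan:  running (prefix) sums of the histogram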
class LQInfo:
def __init__(self, num_level, num_weight, offset, hist_scan):
self.num_level = num_level
self.num_weight = num_weight
self.offset = offset
self.hist_scan = hist_scan
    # p(idx): fraction of the activations that fall into level idx, given the
    # first boundary fsr and the (log-domain) step
    def p(self, idx, fsr, step):
        if idx == self.num_level - 1:
            begin = fsr + (step >> 1) * (2 * idx - 3)
            temp = self.num_weight - self.hist_scan[begin - 1]
            return temp / self.num_weight
        else:
            begin = fsr + (step >> 1) * (2 * idx - 3)
            end = fsr + (step >> 1) * (2 * idx - 1)
            temp = self.hist_scan[end - 1] - self.hist_scan[begin - 1]
            return temp / self.num_weight
    # w(idx): representative activation value of level idx in the linear
    # domain (BASE_SIZE sub-steps per octave in the log2 domain)
    def w(self, idx, fsr, step):
        temp = (fsr + (idx - 1) * step + self.offset) / BASE_SIZE
        return m.pow(2, temp)
def wEnt(self, fsr, step):
ent = 0
for i in range(1, self.num_level):
prob = self.p(i, fsr, step)
if prob > 0:
ent -= self.w(i, fsr, step) * prob * m.log(prob)
return ent
# end LQInfo
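# WLQReLUForward2: log-domain fake-quantization applied after the ReLU.
# |x| is mapped to round((log2|x| * BASE_SIZE - offset) / step), clamped to
# num_level - 1, and reconstructed as 2^((offset + idx*step) / BASE_SIZE);
# the sign is carried separately in sign_out.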
def WLQReLUForward2(
num_level,
active_list,
offset,
step,
train):
float_out = []
quant_out = []
sign_out = []
    for index in range(0, len(active_list)):
        oTemp = 0
        in_data = active_list[index]
        sign = 1
        if in_data < 0:
            sign = -1
        in_data *= sign
        # the sign travels separately as one bit; quantization acts on |x|
        sign_out.append(0 if sign > 0 else 1)
        if in_data > 0:
            # nearest log-domain level (the epsilon guards against step == 0)
            temp = round((m.log2(in_data) * BASE_SIZE - offset) / (step + 1e-9))
            # clamp to the highest level; one level is effectively reserved
            # for zero, and the sign bit is stored separately
            mod_idx = min(num_level - 1, temp)
            quant_out.append(mod_idx)
            if mod_idx < 0:
                # below the lowest boundary: flush to (almost) zero
                oTemp = sys.float_info.min if train == 0 else 0.
            else:
                # reconstruct 2^((offset + idx*step) / BASE_SIZE)
                oTemp = m.pow(2, (offset + mod_idx * step) / BASE_SIZE)
        else:
            # exact zeros reuse the "flushed to zero" sentinel index
            quant_out.append(-1)
        float_out.append(sign * oTemp)
    return (float_out, quant_out, sign_out)
if __name__ == '__main__':
    # only the first picture goes through this calibration pass
num_level = 4
offset = -640
active_list = [0.3, 1.4, 2.7, -3.1, 0.8, -14.1, 2.3, 0.1, 14.1]
#active_list = [1.317857, 0.317857, 2.530624, 1.530624, 0.294160, 3.294160, 10.317857]
print("active_list is {}".format(active_list))
if True:
hist = eh_histogram(
active_list,
offset)
#print(hist)
        # running (prefix) sums of the histogram bins
        suffix_sum_list = list(accumulate(hist))
#print(suffix_sum_list)
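        # min_offset: last leading bin that is still empty;
        # max_offset: last bin where the running sum still grows (last non-empty bin)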
min_offset = -1
max_offset = -1
max_value = -1
for i in range(BIN_SIZE):
if suffix_sum_list[i] == 0:
min_offset = i
if suffix_sum_list[i] > max_value:
max_offset = i
max_value = suffix_sum_list[i]
length = max_offset - min_offset + 1
#hist += min_offset
offset += min_offset
#print("length is {}".format(length))
#print("offset is {}".format(offset))
max_ent = 0
max_half_step = 0
max_fsr = 0
suffix_sum_index = min_offset + length - 1
info = LQInfo(
num_level,
suffix_sum_list[suffix_sum_index],
offset,
suffix_sum_list[min_offset:])
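        # exhaustive search over the level spacing (half_step = step/2, 1..16)
        # and the first boundary fsr for the pair maximizing the weighted entropy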
for half_step in range(1, 17):
end_temp = length - half_step * (2 * num_level - 5)
for fsr in range(40, end_temp):
ent = info.wEnt(fsr, 2 * half_step)
if ent > max_ent:
max_ent = ent
max_half_step = half_step
max_fsr = fsr
        # pack everything into one table, mirroring the original C++ layout:
        #   bdr[1 .. num_level-1]             : decision boundaries (linear domain)
        #   bdr[num_level+1 .. 2*num_level-1] : representative values (linear domain)
        #   bdr[num_level]                    : log-domain offset
        #   bdr[0]                            : log-domain step
        bdr = [0 for idx in range(256)]
        for i in range(1, num_level):
            temp1 = (max_fsr + offset + 2 * (i - 1) * max_half_step)
            bdr[num_level + i] = m.pow(2, temp1 / BASE_SIZE)
            temp2 = (max_fsr + offset + (2 * i - 1) * max_half_step)
            bdr[i] = m.pow(2, temp2 / BASE_SIZE)
        bdr[num_level] = max_fsr + offset
        bdr[0] = 2 * max_half_step
float_out, quant_out, sign_out = WLQReLUForward2(
num_level,
active_list,
bdr[num_level],
bdr[0],
            1)  # 1 means "test" mode (exact zeros allowed)
print("float output is {}".format(float_out))
print("quant output is {}".format(quant_out))
print("sign output is {}".format(sign_out))
print("offset is {}".format(bdr[num_level]))
print("step is {}".format(bdr[0]))