Bank-Balanced剪枝算法的MATLAB实现

文章目录

  • 前言
  • 一、Bank-Balanced剪枝算法简介
  • 二、算法的思想
  • 三、算法描述
  • 四、MATLAB实现
  • 五、简单的测试
  • 六、在具体网络中的应用
    • 1、准备工作
    • 2、剪枝
    • 3、再训练
  • 总结


前言

最近在做基于FPGA的LSTM加速,在架构的选择上犹豫不定,直到看到知乎上一位赛灵思工程师写的文章:稀疏LSTM硬件架构,豁然开朗,决定参考这个架构来做,在这里记录一下做的过程。
这篇博客写一下文章讲的剪枝算法Bank-Balanced剪枝算法,并用MATLAB实现。
关键的参考文章:efficient-and-effective-sparse-lstm-on-fpga-with-bank-balanced-sparsity


一、Bank-Balanced剪枝算法简介

这个剪枝算法是面向硬件加速的。
在这之前常用的剪枝方法有细粒度剪枝和粗粒度剪枝。其中,细粒度剪枝会得到一个非结构化的稀疏矩阵(也就是0的分布是不均匀的),将其实施到硬件上时会有负载不均衡的问题;粗粒度剪枝可以得到结构化的稀疏矩阵,但是对精度的影响会比较大。

而Bank-Balanced剪枝兼顾了负载的均衡性与精度,并且可以很方便地实现行之间的并行运算与行内的并行运算。

顺便提一下,使用细粒度剪枝然后将矩阵分块重排也是可以做到一定程度上的负载均衡的,这也是我之前打算使用的方法。有一篇硕士论文用到了这种方法:LSTM硬件加速器的运算单元优化。

二、算法的思想

算法的思想很简单,就是将矩阵的每一行均匀分成N个Bank,然后对每个Bank进行相同稀疏度的细粒度剪枝。
但是剪枝的结果很漂亮:每行的非0元素数都是相同的,并且行内的每个Bank的非0元素数也是相同的。
这对硬件实施时的负载均衡十分有利。并且这样的剪枝方式相比于粗粒度剪枝降低了较大的权值被剪掉的可能性。

三、算法描述

可以看原文中给出的算法描述:

Bank-Balanced剪枝算法的MATLAB实现_第1张图片

四、MATLAB实现

我在实现该算法时对上述流程稍微做了下修改,将利用阈值剪枝改成了用排序后的索引剪枝。举个极端的例子,一个Bank内的所有元素都是相同的,那怎么根据阈值进行剪枝呢?(于是就变成了Top-k剪枝)

function Mp = BankBalancedPruning(M, BankNum, k)%k是每个bank保留的数据数目

    Mp = zeros(size(M));
    row_num = size(M, 1);
    col_num = size(M, 2);%col_num应当是BankNum的整数倍,以保证分块的均匀
    step = col_num / BankNum;%每个bank的长度
    index = 1 : step : col_num;%每个bank第一个数字的位置

    for i = 1:row_num
        for j = index
            bank = M(i, j : j + step - 1);%暂存1个bank的内容
            [~, sort_index] = sort(abs(bank), 'descend');%降序排列
            keep_index = sort_index(1 : k);%高k位的相对索引
            Mp(i, j + keep_index - 1) = M(i, j + keep_index - 1);%保留
        end
    end

end

五、简单的测试

测试程序:

M = rand(4, 8);
BankNum = 2;
k= 1;
Mp = BankBalancedPruning(M, BankNum, k);

即原始矩阵维度为4x8,每行分为两个Bank,每个Bank保留一个元素。
运行结果如下:
原始矩阵:

Bank-Balanced剪枝算法的MATLAB实现_第2张图片
剪枝后:

Bank-Balanced剪枝算法的MATLAB实现_第3张图片
每行均匀分为两个Bank,每个Bank保留了1个最大值,与预期相符。

六、在具体网络中的应用

1、准备工作

第一,要有一个训练好的网络:使用MATLAB的trainNetwork设计一个简单的LSTM神经网络
第二,我们需要能编辑这个网络:MATLAB手动修改神经网络权值的方法

2、剪枝

分为三步:加载、修改、保存。

clear
clc

load('.\net_data\net')%加载一个训练好的网络
load('.\mnist_data_mat\XTest')%加载测试数据
load('.\mnist_data_mat\YTest')

%LSTM
RecurrentWeights = net.Layers(2,1).RecurrentWeights;%加载
%128个元素分为8个Bank,每个Bank保留2个
RecurrentWeightsSp = BankBalancedPruning(RecurrentWeights, 8, 2);%修改

InputWeights = net.Layers(2,1).InputWeights;%加载
%28个元素分为4个Bank,每个Bank保留2个
InputWeightsSp = BankBalancedPruning(InputWeights, 4, 2);%修改

%FC
Weights = net.Layers(3,1).Weights;%加载
%128个元素分为8个Bank,每个Bank保留1个
WeightsSp = BankBalancedPruning(Weights, 8, 1);%修改

modify_able_net = net.saveobj;%保存
modify_able_net.Layers(2,1).RecurrentWeights = single(RecurrentWeightsSp);
modify_able_net.Layers(2,1).InputWeights = single(InputWeightsSp);
modify_able_net.Layers(3,1).Weights = single(WeightsSp);

Modified_net = net.loadobj(modify_able_net);

Y_pred = classify(Modified_net, XTest);%计算准确度
accy = sum(Y_pred == YTest) / length(YTest);

剪枝的稀疏度如下:

lstm层InputWeights lstm层RecurrentWeights 全连接层Weights
71.4% 87.5% 93.8%

剪枝后的准确率降低到了33.92%,显然是不能接受的。

3、再训练

参照韩松的著名论文Learning both Weights and Connections for Efficient Neural Networks ,使用再训练的方式可以恢复精度。示意图:

Bank-Balanced剪枝算法的MATLAB实现_第4张图片
即进行剪枝-训练-剪枝的迭代。在本例中,迭代过程如下:

Bank-Balanced剪枝算法的MATLAB实现_第5张图片
经过27次迭代,准确率提高到了97.96%,同时保持了与表格中相同的稀疏度。
剪枝后的稀疏LSTM也放到资源里了,和mnist数据集放在一个压缩包里。
对于这个模型而言,准确度还可以提,稀疏度也可以提,这里只是做个演示。

总结

简单介绍了Bank-Balanced剪枝算法,并使用MATLAB实现。在手写识别任务中,用这种剪枝算法实现了稀疏度超过70%的LSTM模型,同时相比于稠密LSTM模型没有精度损失。

你可能感兴趣的:(硬件加速,机器学习,剪枝,matlab,fpga)