Model Compression and Acceleration with Tencent PocketFlow

1. Development environment:

Ubuntu 16.04

Python 3.6

TensorFlow 1.12

 

2. Installation

PocketFlow GitHub repository: https://github.com/Tencent/PocketFlow

Official installation guide: https://pocketflow.github.io/installation/

 

The official site describes three ways to run PocketFlow: local, docker, and seven. Here I use local mode. Installation steps:

2.1 Clone the repository

$ git clone https://github.com/Tencent/PocketFlow.git

2.2 Configure the data paths

In the PocketFlow root directory, copy path.conf.template to path.conf and set the dataset paths in it. Here I use CIFAR-10 for image classification:

# data files
data_hdfs_host = None
data_dir_local_cifar10 = /home/liguiyuan/deep_learning/data/cifar-10-batches-bin
data_dir_hdfs_cifar10 = None
data_dir_seven_cifar10 = None
data_dir_docker_cifar10 = /opt/ml/data  # DO NOT EDIT
data_dir_local_ilsvrc12 = None
data_dir_hdfs_ilsvrc12 = None
data_dir_seven_ilsvrc12 = None
data_dir_docker_ilsvrc12 = /opt/ml/data  # DO NOT EDIT

# model files
model_http_url = https://api.ai.tencent.com/pocketflow
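Editing path.conf by hand is enough for one machine. If you set up several, the copy-and-edit step can be scripted. The following is a minimal sketch, assuming path.conf holds one `key = value` pair per line as in the excerpt above; write_path_conf is a hypothetical helper, and the dataset path in the usage comment is a placeholder.

```python
import re
from pathlib import Path

# Matches "key = value" lines; comment lines such as "# data files" do not match.
KEY_VALUE_RE = re.compile(r"^(\s*)([A-Za-z0-9_]+)(\s*=\s*)(.*)$")

def write_path_conf(template_path, output_path, overrides):
    """Copy path.conf.template to path.conf, replacing the values of selected keys.

    Only keys listed in `overrides` are changed; every other line
    (including the "# DO NOT EDIT" docker entries) is copied verbatim.
    """
    out_lines = []
    for line in Path(template_path).read_text().splitlines():
        match = KEY_VALUE_RE.match(line)
        if match and match.group(2) in overrides:
            indent, key, sep, _old_value = match.groups()
            line = f"{indent}{key}{sep}{overrides[key]}"
        out_lines.append(line)
    Path(output_path).write_text("\n".join(out_lines) + "\n")

# Usage (placeholder path):
# write_path_conf("path.conf.template", "path.conf",
#                 {"data_dir_local_cifar10": "/path/to/cifar-10-batches-bin"})
```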

2.3 Prepare pre-trained models

Download the pre-trained models you need from PocketFlow's model list: https://api.ai.tencent.com/pocketflow/list.html

Then extract the downloaded files into the ./PocketFlow/models subdirectory.
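The extraction step can also be done from Python. This is a minimal sketch, assuming the downloaded file is a tar archive (for example .tar.gz; the post does not name the format) and that it is run from the directory that contains PocketFlow. extract_pretrained and the archive name in the usage comment are hypothetical. The path check rejects entries that would be written outside the target directory.

```python
import os
import tarfile

def extract_pretrained(archive_path, models_dir):
    """Extract a pre-trained model archive into models_dir and list its contents."""
    os.makedirs(models_dir, exist_ok=True)
    root = os.path.realpath(models_dir)
    with tarfile.open(archive_path, "r:*") as tar:
        members = tar.getmembers()
        # Refuse entries like "../x" that would escape models_dir.
        for member in members:
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, target]) != root:
                raise ValueError("unsafe path in archive: " + member.name)
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: also apply the stdlib "data" extraction filter.
            tar.extractall(root, members=members, filter="data")
        else:
            tar.extractall(root, members=members)
    return sorted(os.listdir(root))

# Usage (hypothetical archive name):
# extract_pretrained("downloaded_model.tar.gz", "./PocketFlow/models")
```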

 

3. Training models

3.1 Train the uncompressed network

# local mode with 1 GPU
$ ./scripts/run_local.sh nets/resnet_at_cifar10_run.py

# docker mode with 8 GPUs
$ ./scripts/run_docker.sh nets/resnet_at_cifar10_run.py -n=8

# seven mode with 8 GPUs
$ ./scripts/run_seven.sh nets/resnet_at_cifar10_run.py -n=8

-n=8 sets the number of GPUs; adjust it to match your hardware. The trained model is saved in the models folder. The final loss and accuracy are:

INFO:tensorflow:loss = 4.4453e-01
INFO:tensorflow:accuracy = 9.2500e-01
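To compare runs without reading the logs by eye, the metric lines can be pulled out with a regular expression. A minimal sketch, assuming the console output was saved to a text file (for example with `tee`) and that metrics appear in the `INFO:tensorflow:name = value` form shown above; parse_metrics is a hypothetical helper.

```python
import re

# Matches lines such as "INFO:tensorflow:accuracy = 9.2500e-01".
METRIC_RE = re.compile(r"INFO:tensorflow:(\w+) = ([-+0-9.eE]+)")

def parse_metrics(log_text):
    """Return the last logged value of each metric found in log_text."""
    metrics = {}
    for name, value in METRIC_RE.findall(log_text):
        metrics[name] = float(value)  # later lines overwrite earlier ones
    return metrics
```

For a log containing only the two lines above, parse_metrics returns {'loss': 0.44453, 'accuracy': 0.925}.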

 

3.2 Train compressed models

  • Standard compressed model

This uses the discrimination-aware channel pruning (DCP) algorithm.

# local mode with 1 GPU
$ ./scripts/run_local.sh nets/resnet_at_cifar10_run.py \
    --learner dis-chn-pruned

# docker mode with 8 GPUs
$ ./scripts/run_docker.sh nets/resnet_at_cifar10_run.py -n=8 \
    --learner dis-chn-pruned

# seven mode with 8 GPUs
$ ./scripts/run_seven.sh nets/resnet_at_cifar10_run.py -n=8 \
    --learner dis-chn-pruned

The DCP-compressed model is saved in the models_dcp directory. The default pruning ratio, dcp_prune_ratio, is 0.5.
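To make the pruning ratio concrete, assume dcp_prune_ratio is the fraction of input channels removed from a pruned convolution, with its output channels left unchanged. This is a simplification: the post does not say which layers PocketFlow prunes or how it rounds channel counts. Under that assumption, one layer's weight count scales as shown below. The 3×3, 64-to-64 layer is a hypothetical example, and pruned_conv_params is an illustrative helper, not a PocketFlow API.

```python
def pruned_conv_params(c_in, c_out, kernel, prune_ratio):
    """Weight count of a k x k conv before and after pruning input channels.

    Simplified model: a fraction `prune_ratio` of the input channels is
    removed and the output channels are kept.
    """
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    kept_in = max(1, round(c_in * (1.0 - prune_ratio)))
    before = kernel * kernel * c_in * c_out
    after = kernel * kernel * kept_in * c_out
    return kept_in, before, after

for ratio in (0.5, 0.75):
    kept, before, after = pruned_conv_params(64, 64, 3, ratio)
    print(f"ratio={ratio}: keep {kept}/64 channels, weights {before} -> {after}")
# ratio=0.5: keep 32/64 channels, weights 36864 -> 18432
# ratio=0.75: keep 16/64 channels, weights 36864 -> 9216
```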

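  • Train with a custom pruning ratio (here 0.75; --enbl_dst additionally enables the distillation loss during training)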
$ ./scripts/run_local.sh nets/resnet_at_cifar10_run.py \
    --learner dis-chn-pruned \
    --enbl_dst \
    --dcp_prune_ratio 0.75

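This prunes 3/4 of the channels, so each pruned convolutional layer keeps only 1/4 of its channels. The pruned model is written to the models_dcp_eval directory. Evaluation gives a loss of about 9.63 and an accuracy of 0.846. Compared with the unpruned model (loss 0.4445, accuracy 0.925), the change is substantial: the loss is much higher and accuracy drops by about 8 percentage points (92.5% to 84.6%). That is not surprising, since 3/4 of the channels were removed. The results: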

[Figure: evaluation loss and accuracy of the pruned model]
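Because the pruning ratio trades accuracy for model size, it can be worth trying several values. The following minimal sketch rebuilds the run_local.sh command used above for each ratio, using only the flags from that command. When dry_run is False, it runs the jobs one after another from the PocketFlow root. dcp_command and sweep are hypothetical helpers. The copy to a per-ratio directory assumes, based on the run above, that every run writes its evaluation model to models_dcp_eval, so later runs would otherwise overwrite earlier ones.

```python
import os
import shutil
import subprocess

def dcp_command(prune_ratio, net="nets/resnet_at_cifar10_run.py", enbl_dst=True):
    """Build the run_local.sh command line for one DCP run (flags as used above)."""
    cmd = ["./scripts/run_local.sh", net, "--learner", "dis-chn-pruned"]
    if enbl_dst:
        cmd.append("--enbl_dst")  # distillation flag from the command above
    cmd += ["--dcp_prune_ratio", str(prune_ratio)]
    return cmd

def sweep(ratios, dry_run=True):
    """Print (dry run) or run one DCP training job per ratio, sequentially."""
    for ratio in ratios:
        cmd = dcp_command(ratio)
        print(" ".join(cmd))
        if dry_run:
            continue
        subprocess.run(cmd, check=True)
        # Keep this run's evaluation model before the next run overwrites it.
        dst = "models_dcp_eval_{}".format(ratio)
        if os.path.exists(dst):
            shutil.rmtree(dst)
        shutil.copytree("models_dcp_eval", dst)

# Usage: sweep([0.5, 0.75])                  # print the commands only
#        sweep([0.5, 0.75], dry_run=False)   # actually train
```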

PocketFlow has many more configuration options; see the official documentation for a detailed description of each.

 

4. Export to a TensorFlow Lite model

Convert the checkpoint files into a *.tflite file for deployment on mobile devices:

# convert checkpoint files into a *.tflite model
$ python tools/conversion/export_pb_tflite_models.py \
    --model_dir models_dcp_eval

This raised an error:

INFO:tensorflow:input details: [{'name': 'net_input', 'index': 0, 'shape': array([ 1, 32, 32,  3], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}]
INFO:tensorflow:output details: [{'name': 'net_output', 'index': 1, 'shape': array([ 1, 10], dtype=int32), 'dtype': , 'quantization': (0.0, 0)}]
Traceback (most recent call last):
  File "tools/conversion/export_pb_tflite_models.py", line 383, in main
    export_pb_tflite_model(net, meta_path, pb_path, tflite_path)
  File "tools/conversion/export_pb_tflite_models.py", line 356, in export_pb_tflite_model
    test_tflite_model(tflite_path, net['input_data'])
  File "tools/conversion/export_pb_tflite_models.py", line 289, in test_tflite_model
    interpreter.set_tensor(input_details[0]['index'], net_input_data)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/lite/python/interpreter.py", line 156, in set_tensor
    self._interpreter.SetTensor(tensor_index, value)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/lite/python/interpreter_wrapper/tensorflow_wrap_interpreter_wrapper.py", line 133, in SetTensor
    return _tensorflow_wrap_interpreter_wrapper.InterpreterWrapper_SetTensor(self, i, value)
ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 0

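The problem is the data type of the net_input input. The TFLite model expects float32 (type 1 in the error), but the array passed in has a dtype the interpreter does not map to any TFLite type (type 0). Casting the input to np.float32 fixes it.

Edit line 289 of tools/conversion/export_pb_tflite_models.py: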

# original line 289:
interpreter.set_tensor(input_details[0]['index'], net_input_data)
# change to:
interpreter.set_tensor(input_details[0]['index'], net_input_data.astype(np.float32))
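Hard-coding np.float32 fixes this particular model. A slightly more general variant reads the expected dtype from the model itself. The sketch below assumes that each entry returned by the interpreter's get_input_details() carries a 'dtype' field with a numpy type, as the log above shows, and casts the input to that dtype. cast_to_input_dtype is a hypothetical helper, not part of PocketFlow.

```python
import numpy as np

def cast_to_input_dtype(data, input_detail):
    """Cast `data` to the dtype declared by one TFLite input tensor.

    `input_detail` is one entry of interpreter.get_input_details().
    Arrays that already have the right dtype are returned unchanged.
    """
    expected = np.dtype(input_detail["dtype"])
    data = np.asarray(data)
    if data.dtype != expected:
        data = data.astype(expected)
    return data
```

With it, line 289 would become `interpreter.set_tensor(input_details[0]['index'], cast_to_input_dtype(net_input_data, input_details[0]))`.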

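After this change, two files, model.pb and model.tflite, are generated under PocketFlow/models_dcp_eval/. Both are small, which makes them well suited for deployment on mobile devices.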
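The post does not give the actual file sizes. A quick way to check them is sketched below. model_sizes_kb is a hypothetical helper, and its default file names are the two listed above.

```python
import os

def model_sizes_kb(model_dir, names=("model.pb", "model.tflite")):
    """Return the size in KiB of each exported model file present in model_dir."""
    sizes = {}
    for name in names:
        path = os.path.join(model_dir, name)
        if os.path.isfile(path):
            sizes[name] = os.path.getsize(path) / 1024.0
    return sizes

# Usage: print(model_sizes_kb("models_dcp_eval"))
```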

 

5. Deploying to mobile devices

TODO

 

 
