Recognizing audio files with TensorFlow

This article is a repost.

 

Overview

With the wide adoption of deep learning and the open-sourcing of TensorFlow, more and more models are being deployed on mobile devices. This article describes what I learned while setting this up, and I hope it is useful to you.

Installing the CPU-only TensorFlow on a Mac

If you do not have a capable GPU, you can install the CPU-only build of TensorFlow. On Linux and macOS, TensorFlow is available for both Python 2 and Python 3; on Windows, only Python 3 is supported.

1. Install Bazel, the build tool TensorFlow depends on; it is used later to build the Android jar and .so library. On a Mac you can install it with Homebrew (brew install bazel), or follow the official Bazel documentation to install a specific version.
2. Install the CPU-only TensorFlow with pip: pip install tensorflow.
3. Verify that TensorFlow was installed correctly:


import tensorflow as tf
# Print the installed TensorFlow version
print(tf.__version__)

Generating the jar and the .so library

1. Download the TensorFlow source from GitHub to your machine.
2. Edit the WORKSPACE file in the tensorflow directory and point its sdk and ndk entries at your local paths. The SDK API level must be >= 23, and NDK r12b is recommended (newer NDK versions can run into problems when building with Bazel); set build_tools_version to whatever you actually have installed. A sketch of the relevant entries is shown below.
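The exact contents depend on your machine; every path and version number in this snippet is a placeholder, so treat it only as a rough sketch of what the edited WORKSPACE entries might look like:

# WORKSPACE (excerpt) -- adjust paths and versions to match your setup
android_sdk_repository(
    name = "androidsdk",
    api_level = 23,                    # must be >= 23
    build_tools_version = "25.0.2",    # whichever build-tools you have installed
    path = "/path/to/android-sdk",
)

android_ndk_repository(
    name = "androidndk",
    api_level = 21,
    path = "/path/to/android-ndk-r12b",  # r12b is recommended
)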
3. Generate the jar and the .so library with the commands below.
Reference:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android


Command to build the .so (you can pick the target CPU/ABI):
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
  --crosstool_top=//external:android/crosstool \
  --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
  --cpu=armeabi-v7a

Location of the generated .so:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

Command to build the jar:
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java

Location of the generated jar:
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

4. Alternatively, you can use the prebuilt .so published by TensorFlow's nightly Android builds:
http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/libtensorflow_inference.so/

Setting up the Android side

1. Put the jar into the app -> libs directory and add the dependency to build.gradle: compile files('libs/libandroid_tensorflow_inference_java.jar').
2. Create a jniLibs folder under src -> main and put the generated .so there (in an ABI subfolder such as armeabi-v7a).
3. Save the model trained on the PC as a frozen .pb graph:


# Freeze the trained graph: convert variables to constants and write it out as a .pb file
from tensorflow.python.framework import graph_util

output_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile("path/to/xxx.pb", "wb") as f:
    f.write(output_graph_def.SerializeToString())
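The node names in the frozen graph must match the names the Android code asks for later ('input' and 'output' in this article). If you are not sure what your graph exposes, you can list every operation in the .pb; a minimal sketch, reusing the placeholder path from above:

import tensorflow as tf

# Load the frozen graph and print every operation name,
# so you can confirm that 'input' and 'output' are present.
graph_def = tf.GraphDef()
with tf.gfile.FastGFile("path/to/xxx.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
for node in graph_def.node:
    print(node.name)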

4. Put the .pb file into the src -> main -> assets directory.
5. Finally, let's walk through the Android code that wires TensorFlow up on the device:


public class TensorFlowAudioClassifier implements Classifier{

    private static final String TAG = "TensorFlowAudioClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 3;
    private static final float THRESHOLD = 0.0f;

    // Config values.
    // Name of the input node (just the op name without the trailing ':0', e.g. 'input').
    private String inputName;
    // Name of the output node (same convention as the input node).
    private String outputName;
    // Side length of the (square) input matrix.
    private int inputSize;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;
    // The constructor is private; instances are created through the create() factory below.

    private TensorFlowAudioClassifier() {
    }

    /**
     * Initializes a native TensorFlow session for classifying audio.
     *
     * @param assetManager  The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize     The input size. A square feature matrix of inputSize x inputSize is assumed.
     * @param inputName     The label of the audio input node.
     * @param outputName    The label of the output node.
     * @throws IOException
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            String inputName,
            String outputName)
            throws IOException {
        TensorFlowAudioClassifier c = new TensorFlowAudioClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        // TODO(andrewharp): make this handle non-assets.
        // Load the label file; the labels are used later to build the result beans.
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br = null;
        br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        String line;
        while ((line = br.readLine()) != null) {
            c.labels.add(line);
        }
        br.close();

        c.inferenceInterface = new TensorFlowInferenceInterface();
        if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
            throw new RuntimeException("TF initialization failed");
        }
        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        int numClasses =
                (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        // Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.floatValues = new float[inputSize * inputSize * 1];
        c.outputs = new float[numClasses];

        return c;
    }

    @Override
    // The recognition flow for a single audio file.
    public List<Recognition> recognizeAudio(String fileName) {
        // Log this method so that it can be analyzed with systrace.
        Trace.beginSection("recognizeAudio");

        Trace.beginSection("preprocessAudio");
        // Preprocess the audio data to normalized float based
        // on the provided parameters.
        // Build the input from the audio file. TensorFlow expects a flat 1-D float
        // array, so the 2-D feature matrix (40 values per frame) is flattened here.
        double[][] data = RFFT.inputData(fileName);
        for (int i = 0; i < data.length; ++i) {
            for (int j = 0; j < data[0].length; ++j) {
                floatValues[i * 40 + j] = (float) data[i][j];
            }
        }
        Trace.endSection();

        // Copy the input data into TensorFlow.
        Trace.beginSection("fillNodeFloat");
        // Feed the flattened features into the input node, with shape [1, 40, 40, 1].
        inferenceInterface.fillNodeFloat(
                inputName, new int[]{1,40,40,1}, floatValues);
        Trace.endSection();


        // Run the inference call.
        Trace.beginSection("runInference");
        // Run the model.
        inferenceInterface.runInference(outputNames);
        Trace.endSection();

        // Copy the output Tensor back into the output array.
        Trace.beginSection("readNodeFloat");
        // Read the confidence for each class into the outputs array.
        inferenceInterface.readNodeFloat(outputName, outputs);
        Trace.endSection();

        // Find the best classifications.
        // Use a PriorityQueue to keep the top results. Recognition is a simple
        // bean class defined in the Classifier interface.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        3,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        // Build the result bean: id, label name, confidence.
                        new Recognition(
                                "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i]));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        if (recognitionsSize == 0) {
            Trace.endSection(); // "recognizeAudio"
            return null;        // nothing scored above the confidence threshold
        }
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        Trace.endSection(); // "recognizeAudio"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean debug) {
        inferenceInterface.enableStatLogging(debug);
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}

We now have a TensorFlowAudioClassifier that performs the classification. Here is how to use it, for example from an Activity that has access to getAssets():


    public static TensorFlowAudioClassifier classifier;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";

    private static final String MODEL_FILE = "file:///android_asset/acoustic.pb";
    private static final String LABEL_FILE =
            "file:///android_asset/eventLabel.txt";
    private static int INPUT_SIZE = 40;
        // For example, inside onCreate() or any other place where a Context is available:
        try {
            // Create the classifier from the model and label files in assets.
            classifier = (TensorFlowAudioClassifier) TensorFlowAudioClassifier.create(
                    getAssets(),
                    MODEL_FILE,
                    LABEL_FILE,
                    INPUT_SIZE,
                    INPUT_NAME,
                    OUTPUT_NAME
            );
            // Classify the audio file; recognizeAudio returns the top results.
            List<Recognition> results = classifier.recognizeAudio(fileName);
        } catch (IOException e) {
            e.printStackTrace();
        }
