Calling a TensorFlow Model from Java

Note, note, note: the version matters a great deal. This article uses TensorFlow 1.13.1, and other versions are quite likely not to work.

1. Example model code

    import os
    import numpy as np
    import tensorflow as tf
    from sklearn.model_selection import train_test_split

    MAX_LEN = 25      # padded sequence length; the Java side feeds float[1][25]
    vocab_dim = 128
    batch_size = 64
    n_epoch = 5
    # load_data() is the author's own helper (not shown): padded index sequences,
    # integer labels, vocabulary size and number of classes
    pad_sequences, labels_index, word_counts, n_classes = load_data()
    x_train, x_test, y_train, y_test = train_test_split(np.array(pad_sequences), np.array(labels_index), test_size=0.1, random_state=1)
    print("Training set shape: ", np.shape(x_train))
    print("Test set shape: ", np.shape(x_test))

    print('Building model...')
    inputs = tf.keras.Input(shape=(MAX_LEN,), name='inputs')   # the "inputs" placeholder fed from Java
    x = tf.keras.layers.Embedding(input_dim=word_counts + 1,
                                  output_dim=vocab_dim,
                                  mask_zero=True,              # index 0 is padding
                                  input_length=MAX_LEN, trainable=True)(inputs)
    x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=256, dropout=0.25, recurrent_dropout=0.15))(x)
    outputs = tf.keras.layers.Dense(n_classes, activation='softmax', name='outputs')(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.summary()
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    print("Training...")
    os.makedirs('./weights', exist_ok=True)   # the checkpoint directory must exist
    middle_weights = './weights/textrnn_weights.{epoch:02d}-{val_loss:.3f}.hdf5'
    # The checkpoint only takes effect when it is passed to fit() through callbacks
    mc_middle = tf.keras.callbacks.ModelCheckpoint(middle_weights, monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='auto')
    model.fit(x_train, y_train, batch_size=batch_size, epochs=n_epoch, class_weight='auto', validation_data=(x_test, y_test), callbacks=[mc_middle], verbose=1)
    save_model_for_production(model, '1', './export_model')
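The Java caller must feed input in exactly the form used for training. Each sample is 25 word indices, with 0 reserved for padding (mask_zero=True) and the padding placed in front, as the example input in the Java program later in this article shows. The original code does not show how load_data() pads. Assuming it uses Keras' default pad_sequences settings (padding='pre', truncating='pre'), the same step could be sketched in Java as below. SequencePadding, prePad and toBatch are hypothetical names. Converting words to indices with the training vocabulary is not covered.

```java
import java.util.Arrays;

// Hypothetical helper mirroring Keras pad_sequences(padding='pre', truncating='pre')
final class SequencePadding {
    static final int MAX_LEN = 25; // must match shape=(25,) of the Keras Input layer

    // Left-pads with 0 (the mask value), or keeps only the last maxLen indices
    static float[] prePad(int[] wordIndices, int maxLen) {
        float[] padded = new float[maxLen];          // Java initialises every slot to 0.0f
        int n = Math.min(wordIndices.length, maxLen);
        int srcStart = wordIndices.length - n;       // truncating='pre' drops leading tokens
        int dstStart = maxLen - n;                   // padding='pre' puts the zeros in front
        for (int i = 0; i < n; i++) {
            padded[dstStart + i] = wordIndices[srcStart + i];
        }
        return padded;
    }

    // Wraps one sequence as a batch of size 1, the float[1][25] shape fed to "inputs"
    static float[][] toBatch(int[] wordIndices) {
        return new float[][]{prePad(wordIndices, MAX_LEN)};
    }

    public static void main(String[] args) {
        float[][] batch = toBatch(new int[]{1, 38, 508, 980, 35, 140});
        System.out.println(Arrays.toString(batch[0]));
    }
}
```

For the indices 1, 38, 508, 980, 35, 140 this yields 19 leading zeros followed by those six values. That is the same row the Java program feeds as a[0].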

2. Exporting the TensorFlow model in .pb (SavedModel) format

Note: the model variable must be a global variable. Otherwise, calling save_model_for_production fails with a "session has closed" error.

    import os
    import tensorflow as tf

    def save_model_for_production(model, version='1', path='prod_models'):
        # 0 = inference mode. In graph mode the learning phase is baked in when layers
        # are built, so this only affects layers created after this call
        tf.keras.backend.set_learning_phase(0)
        if not os.path.exists(path):
            os.mkdir(path)
        # SavedModel layout: <path>/<version>/saved_model.pb plus variables/
        export_path = os.path.join(
            tf.compat.as_bytes(path),
            tf.compat.as_bytes(version))
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)

        # Describe the Keras input/output tensors for the signature
        model_input = tf.compat.v1.saved_model.build_tensor_info(model.input)
        model_output = tf.compat.v1.saved_model.build_tensor_info(model.output)

        prediction_signature = (
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'inputs': model_input},
                outputs={'outputs': model_output},
                method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME))

        # Session holding the trained weights. Not wrapped in "with", which would
        # close Keras' global session when the block exits
        sess = tf.compat.v1.keras.backend.get_session()
        builder.add_meta_graph_and_variables(
            sess=sess, tags=[tf.compat.v1.saved_model.tag_constants.SERVING],   # "serve", loaded from Java
            signature_def_map={
                'predict': prediction_signature,
                tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
            })
        builder.save()
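SavedModelBuilder writes a directory <path>/<version> (here ./export_model/1) that contains saved_model.pb and a variables/ folder. On the Java side, SavedModelBundle.load expects that directory, not its parent. The Java program in this article loads ./src/main/resources/export_model. That only works if the folder itself holds saved_model.pb, for example if the contents of 1/ were copied there; the article does not say how the files were moved. If versioned folders are kept instead, a small helper can pick the newest one. The sketch below is an assumption and not part of the original project. ExportDirs and latestVersionDir are hypothetical names, and the helper treats the largest numeric subdirectory that contains saved_model.pb as the latest export.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.Optional;
import java.util.stream.Stream;

// Hypothetical helper: given an export root such as ./export_model, return the
// numeric version subdirectory with the highest number that holds saved_model.pb
final class ExportDirs {
    static Optional<Path> latestVersionDir(Path exportRoot) throws IOException {
        try (Stream<Path> children = Files.list(exportRoot)) {   // closes the directory stream
            return children
                    .filter(Files::isDirectory)
                    .filter(p -> p.getFileName().toString().matches("\\d{1,18}"))  // fits in a long
                    .filter(p -> Files.exists(p.resolve("saved_model.pb")))
                    .max(Comparator.comparingLong(p -> Long.parseLong(p.getFileName().toString())));
        }
    }

    public static void main(String[] args) throws IOException {
        Path root = Paths.get(args.length > 0 ? args[0] : "./export_model");
        System.out.println(latestVersionDir(root).map(Path::toString).orElse("no SavedModel found"));
    }
}
```

SavedModelBuilder refuses to export into an existing, non-empty directory. Re-exporting under a new version string ('2', '3', ...) therefore fits naturally with this lookup.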

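3. Modify the dependencies in Maven's pom.xml

    <dependencies>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>libtensorflow</artifactId>
            <version>1.13.1</version>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.13.1</version>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>libtensorflow_jni</artifactId>
            <version>1.13.1</version>
        </dependency>
    </dependencies>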

4. Java code that calls the TensorFlow model

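    import org.tensorflow.*;
    import java.util.List;

    /*
    -----------------------------------------------
    # @Time    : 2019/10/17 14:05
    # @Author  : Dong.Wang
    # @File    : Test.java
    # @ProjectName: javaTensorflow
    ------------------------------------------------

    # @Brief:
    */

    public class Test {
        public static void main(String[] args) {
            // Must be the directory that directly contains saved_model.pb and variables/;
            // "serve" is the SERVING tag used by save_model_for_production
            try (SavedModelBundle b = SavedModelBundle.load("./src/main/resources/export_model", "serve")) {
                Session tfSession = b.session();
                // The op to run; operation() returns null when the name does not exist
                Operation operationPredict = b.graph().operation("outputs/Softmax");
                if (operationPredict == null) {
                    throw new IllegalStateException("op 'outputs/Softmax' not found; check model.output.name");
                }
                Output<?> output = operationPredict.output(0);
                // One left-padded sequence of 25 word indices (float32, like the Keras Input)
                float[][] a = new float[1][25];
                a[0] = new float[]{0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,38,508,980,35,140};
                try (Tensor<?> input_x = Tensor.create(a)) {
                    List<Tensor<?>> out = tfSession.runner().feed("inputs", input_x).fetch(output).run();
                    for (Tensor<?> s : out) {
                        // Output shape is [batch, n_classes]; size the buffer from it instead of hard-coding 562
                        long[] shape = s.shape();
                        float[][] t = new float[(int) shape[0]][(int) shape[1]];
                        s.copyTo(t);
                        s.close();   // tensors hold native memory
                        for (float i : t[0])
                            System.out.println(i);
                    }
                }
            }
        }
    }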
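The program above prints all 562 probabilities. Usually only the most likely class is needed. Because the output row comes from a softmax, that class is simply the index of the largest value. Below is a minimal post-processing sketch. It assumes the class index corresponds to the integer labels in labels_index from training. Mapping an index back to a label name needs the same label table that load_data() built, which is not shown. Prediction, argMax and topK are hypothetical names.

```java
import java.util.stream.IntStream;

// Hypothetical post-processing helper for one row of the softmax output
final class Prediction {
    // Index of the largest probability; ties resolve to the lowest index
    static int argMax(float[] probs) {
        if (probs.length == 0) throw new IllegalArgumentException("empty output row");
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        return best;
    }

    // Indices of the k largest probabilities, highest first (stable sort keeps ties in index order)
    static int[] topK(float[] probs, int k) {
        return IntStream.range(0, probs.length)
                .boxed()
                .sorted((i, j) -> Float.compare(probs[j], probs[i]))
                .limit(k)
                .mapToInt(Integer::intValue)
                .toArray();
    }
}
```

Inside the loop above, Prediction.argMax(t[0]) would give the predicted class for the single input row.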

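Note, note, note: the output op is named "outputs/Softmax", even though the Python training code sets name='outputs':

    outputs = tf.keras.layers.Dense(n_classes, activation='softmax', name='outputs')(x)

Keras uses the layer name as a name scope and still generates its own name for the op inside it, here "Softmax" for the softmax activation. If this name is wrong, b.graph().operation(...) returns null and the program fails with a NullPointerException. To find the right name, debug the Python training code and print model.output.name. It prints the tensor name, which ends in an output index such as ":0"; the op name is the part before the colon.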
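In TensorFlow 1.x a tensor name has the form <op name>:<output index>. The op name is what graph().operation(...) expects, and the index is what Operation.output(...) takes. The helper below splits such a name. TensorNames is a hypothetical class, and "outputs/Softmax:0" is used only as an illustration derived from the op name above, not as a captured log.

```java
// Hypothetical helper: splits a TensorFlow tensor name "<op name>:<output index>"
// into the op name for Graph.operation(...) and the index for Operation.output(...)
final class TensorNames {
    private static boolean hasIndexSuffix(String name, int colon) {
        return colon > 0 && name.substring(colon + 1).matches("\\d+");
    }

    // "outputs/Softmax:0" -> "outputs/Softmax"; names without ":<n>" are returned unchanged
    static String opName(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return hasIndexSuffix(tensorName, colon) ? tensorName.substring(0, colon) : tensorName;
    }

    // "outputs/Softmax:0" -> 0; a bare op name refers to its first output, 0
    static int outputIndex(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return hasIndexSuffix(tensorName, colon) ? Integer.parseInt(tensorName.substring(colon + 1)) : 0;
    }
}
```

For example, TensorNames.opName("outputs/Softmax:0") returns "outputs/Softmax", the string passed to b.graph().operation(...).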
