tensorflow1.14.0代码适配tensorflow2.5.0遇到的坑

tensorflow1.1.4.0版本代码适配tensorflow2.5环境修改的代码:

import tensorflow as tf
from tensorflow.python.keras.backend import set_session

tf.compat.v1.disable_eager_execution()  # TensorFlow2.X下关闭eager mode
sess = tf.compat.v1.Session()
graph = tf.compat.v1.get_default_graph()

set_session(sess)
with graph.as_default():
    set_session(sess)
    x, s = self.tokenizer.encode(t, maxlen=self.maxlen)

with graph.as_default():
    set_session(sess)
    Z = self.encoder.predict([X, S])

上述代码只截取了相关代码

在测试环境运行时,上述代码在tf2.5.0环境下,是正常运行的。但是到线上环境,和其他相关服务部署到一起后,就不行。在下面这一步卡住,同时也导致其他模块加载模型出现卡住现象。

from bert4keras.models import build_transformer_model
bert = build_transformer_model(config_path, checkpoint_path, model=model_type)

奇怪的是,在容器内部的python环境中,逐行执行上面加载预训练模型操作,是可以正常执行的。

各种排查,无果,也没任何报错。

最后推理,代码应该是没问题,流程本地也都多次测试,也没问题,从影响面来看,不只是当前任务进程卡住,也影响到了有同样加载预训练模型的其他模块流程,那应该是共用的某些部分产生了影响。最后带着这样的思路,排查代码,发现是下面几行代码影响到了tensorflow环境所致:

tf.compat.v1.disable_eager_execution()  # TensorFlow2.X下关闭eager mode

 而这行代码就是当初为了把框架从tf1.14升级到tf2添加的。

故针对当前代码,重新适配tf2,把session和graph相关的都直接删除。如下:

import tensorflow as tf

x, s = self.tokenizer.encode(t, maxlen=self.maxlen)
Z = self.encoder.predict([X, S])

经测试,可行。 

你可能感兴趣的:(python,tensorflow,深度学习,机器学习)