tensorflow加载多个固化pb模型

使用tensorflow加载多个pb模型时,会引起变量冲突,解决方法按照如下方法加载模型可解决:

class Model:
    def __init__(self, model_file):
        self.graph = tf.Graph()
        self.graph_def = tf.GraphDef()
        with gfile.FastGFile(model_file, 'rb') as f:
            self.graph_def.ParseFromString(f.read())
        with self.graph.as_default():
            tf.import_graph_def(self.graph_def, name='')

        self.sess = tf.Session(graph=self.graph, config=config)

    def predict(self, images: list):
        output_node = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[-1].name)
        input_x = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[0].name)

        w = input_x.shape[1]
        h = input_x.shape[2]
        data = []

        for img in images:
            img = img.resize((w, h))
            img = np.array(img).astype(float)
            data.append(img)

        feed = {input_x: data}
        out = self.sess.run(output_node, feed)
        return out

由于之前没有定义输入、输出节点name,故采用节点索引的方式获取输入、输出节点tensor

你可能感兴趣的:(Machine,Learning)