tensorflow保存加载多个模型

# 保存/加载多个模型时要注意：必须为每个模型指定独立的 Graph（each model needs its own tf.Graph）
class MLP(object):
    """A small two-layer network that owns a private tf.Graph and tf.Session.

    Each instance builds its ops inside its own ``tf.Graph`` and keeps its
    checkpoints under ``./<id>/``. Several MLP instances (several models) can
    therefore be created, trained, saved and restored in one process without
    their variables colliding in the default graph.

    Module-level names used here but not defined in this class: ``os``,
    ``tf`` (TensorFlow 1.x API) and ``generate_data(n, id)``.

    NOTE(review): ``train`` and ``test`` call ``self.get_acc(prediction, labels)``,
    which is not defined in this class. A subclass or the caller must provide
    it. TODO: confirm where it comes from.
    """

    def __init__(self, id, hidden_units=50, learning_rate=1e-3):
        """Create the graph/session for model ``id`` and restore or initialize it.

        Args:
            id: model name. Also used as the checkpoint directory ``./<id>``
                and passed as the second argument to ``generate_data``.
            hidden_units: width of the hidden layer (default 50, as before).
                Must match the value used when an existing checkpoint was saved.
            learning_rate: Adam learning rate (default 1e-3, as before).
        """
        self.id = id
        self.hidden_units = hidden_units
        self.learning_rate = learning_rate
        self.model_dir = os.path.join('.', id)
        # Checkpoint prefix; saver.save writes model.meta, model.index, ... here.
        self.model_prefix = os.path.join(self.model_dir, 'model')
        # exists()+makedirs() races when two processes start together.
        # Tolerate "already exists" but re-raise any other failure.
        try:
            os.makedirs(self.model_dir)
        except OSError:
            if not os.path.isdir(self.model_dir):
                raise

        # A private graph per instance is what lets several models coexist.
        self.graph = tf.Graph()
        self.session_conf = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False)
        self.load_model()

    def init_net(self):
        """Build the placeholders, both dense layers, the loss and the train op.

        Must run with ``self.graph`` as the default graph (``load_model`` does
        this). Variable names match the original layout so that existing
        checkpoints still restore.
        """
        # Inputs and labels are fed as (batch, 1) float32 tensors.
        self.input_x = tf.placeholder(tf.float32, [None, 1], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, 1], name="input_y")

        with tf.name_scope('mlp1'):
            W = tf.Variable(
                tf.truncated_normal([1, self.hidden_units], stddev=0.1), name="W")
            b = tf.Variable(
                tf.constant(0.1, shape=[self.hidden_units]), name="b")
            # NOTE(review): no activation is applied to this hidden layer, so
            # mlp1 followed by mlp2 is still a single affine map of input_x.
            # Adding e.g. tf.nn.relu here would change what previously saved
            # models compute, so it is left as-is. Confirm whether that was
            # intended.
            hidden = tf.nn.xw_plus_b(self.input_x, W, b, name="xwb")

        with tf.name_scope('mlp2'):
            W1 = tf.Variable(
                tf.truncated_normal([self.hidden_units, 1], stddev=0.1), name="W1")
            b1 = tf.Variable(tf.constant(0.1, shape=[1]), name="b1")
            self.logits = tf.nn.xw_plus_b(hidden, W1, b1, name="xwb1")
            self.prediction = tf.nn.sigmoid(self.logits)
        # Backward-compatible alias: the original code left the final logits
        # in ``self.mlp1`` (it reassigned that attribute in the second layer).
        self.mlp1 = self.logits

        with tf.name_scope("loss"):
            # The loss takes raw logits; the sigmoid is applied inside.
            losses = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(losses)

        with tf.name_scope("optimizer"):
            self.global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            grads_and_vars = optimizer.compute_gradients(self.loss)
            self.train_op = optimizer.apply_gradients(
                grads_and_vars, global_step=self.global_step)

    def load_model(self):
        """Open a session on ``self.graph``, build the net, then restore or init.

        If ``./<id>`` holds a checkpoint (according to its ``checkpoint`` index
        file), all variables are restored from it. Otherwise they are freshly
        initialized. The ``Saver`` is created after ``init_net`` so that it
        covers every variable, including Adam's slot variables.
        """
        with self.graph.as_default():
            self.sess = tf.Session(graph=self.graph, config=self.session_conf)
            self.init_net()
            self.saver = tf.train.Saver()
            # Ask TF for the checkpoint path instead of only testing for
            # model.meta. That file can exist while the index is missing, and
            # then latest_checkpoint returns None and restore(None) would crash.
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            if checkpoint is not None:
                self.saver.restore(self.sess, checkpoint)
            else:
                self.sess.run(tf.global_variables_initializer())

    def train(self, steps=1000, batch_size=1000, eval_size=100):
        """Run ``steps`` optimizer updates, logging progress, then save.

        Each step draws a fresh training batch of ``batch_size`` and a fresh
        evaluation batch of ``eval_size`` from ``generate_data(n, self.id)``.
        After the loop the checkpoint is written to ``./<id>/model``.
        """
        print('traning')
        with self.sess.as_default():
            for step in range(steps):
                x, y = generate_data(batch_size, self.id)
                loss, _ = self.sess.run(
                    [self.loss, self.train_op],
                    feed_dict={self.input_x: x, self.input_y: y})
                x_test, y_test = generate_data(eval_size, self.id)
                # The prediction depends only on input_x; labels are not fed.
                prediction = self.sess.run(
                    self.prediction, feed_dict={self.input_x: x_test})
                acc = self.get_acc(prediction, y_test)
                # A single formatted string prints identically on Python 2 and 3.
                print('step: %s loss: %s acc: %s' % (step, loss, acc))
            self.saver.save(self.sess, self.model_prefix)

    def test(self, size=1000):
        """Evaluate on one fresh batch of ``size`` samples and print the accuracy."""
        print('testing')
        with self.sess.as_default():
            x_test, y_test = generate_data(size, self.id)
            prediction = self.sess.run(
                self.prediction, feed_dict={self.input_x: x_test})
            acc = self.get_acc(prediction, y_test)
            print('acc: %s' % (acc,))

    def close(self):
        """Release this model's tf.Session. Safe to call more than once."""
        sess = getattr(self, 'sess', None)
        if sess is not None:
            sess.close()
            self.sess = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the session; never swallow the caller's exception.
        self.close()
        return False

你可能感兴趣的:(deeplearning,tensorflow)