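# When saving/loading multiple models in the same process, you must specify the Graph explicitly (each model gets its own tf.Graph).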
class MLP(object):
def __init__(self, id):
    """Create the model directory, then build (or reload) the model.

    Args:
        id: model identifier. It is also the name of the directory, relative
            to the current working directory, where checkpoints are stored.
            (It shadows the ``id`` builtin; the name is kept so existing
            callers passing ``id=`` keep working.)
    """
    model_dir = os.path.join('.', id)
    # Create the directory race-free: an exists()-then-makedirs() check can
    # fail if something else creates the directory in between.
    # (os.makedirs(exist_ok=True) is Python 3 only; this file is Python 2.)
    try:
        os.makedirs(model_dir)
    except OSError:
        if not os.path.isdir(model_dir):
            raise
    self.id = id
    # A private graph per instance, so several MLPs can be built and
    # loaded in one process without sharing the default graph.
    self.graph = tf.Graph()
    self.session_conf = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False)
    self.load_model()
def init_net(self):
    """Build placeholders, the two-layer network, the loss and the train op.

    Data flow: input_x (N, 1) -> affine to 50 units -> affine to 1 unit
    (the logits) -> sigmoid (``self.prediction``).

    Must run inside ``self.graph.as_default()``. Otherwise the ops are
    created in the process default graph, and ``self.sess`` (bound to
    ``self.graph``) cannot run them. The TF names used here ("mlp1/W",
    "mlp2/W1", "optimizer/global_step", ...) are the keys ``tf.train.Saver``
    writes, so they must stay stable for existing checkpoints to restore.

    NOTE(review): there is no nonlinearity between the two layers, so the
    stack collapses to a single affine function of x (effectively logistic
    regression). Inserting e.g. tf.nn.relu would not change the variable
    set, but it would change what restored checkpoints predict. Confirm
    the intent.
    """
    # Feed points: features and targets (targets presumably in [0, 1], as
    # sigmoid cross-entropy expects; TODO confirm against generate_data).
    self.input_x = tf.placeholder(tf.float32, [None, 1], name="input_x")
    self.input_y = tf.placeholder(tf.float32, [None, 1], name="input_y")

    with tf.name_scope('mlp1'):
        hidden_w = tf.Variable(tf.truncated_normal([1, 50], stddev=0.1), name="W")
        hidden_b = tf.Variable(tf.constant(0.1, shape=[50]), name="b")
        hidden = tf.nn.xw_plus_b(self.input_x, hidden_w, hidden_b, name="xwb")

    with tf.name_scope('mlp2'):
        out_w = tf.Variable(tf.truncated_normal([50, 1], stddev=0.1), name="W1")
        out_b = tf.Variable(tf.constant(0.1, shape=[1]), name="b1")
        logits = tf.nn.xw_plus_b(hidden, out_w, out_b, name="xwb1")

    # Historical attribute name: after this method, self.mlp1 holds the logits.
    self.mlp1 = logits
    self.prediction = tf.nn.sigmoid(logits)

    with tf.name_scope("loss"):
        per_example = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=self.input_y)
        self.loss = tf.reduce_mean(per_example)

    with tf.name_scope("optimizer"):
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        adam = tf.train.AdamOptimizer(1e-3)
        gradients = adam.compute_gradients(self.loss)
        self.train_op = adam.apply_gradients(
            gradients, global_step=self.global_step)
def load_model(self):
    """Create the session and graph, restoring the latest checkpoint if any.

    Sets ``self.sess`` and ``self.saver``, both bound to ``self.graph``. If
    ``./<id>`` contains a checkpoint, the latest one is restored. Otherwise
    all variables are freshly initialized.
    """
    model_dir = os.path.join('.', self.id)
    # Build everything inside this model's own graph. Otherwise the ops go
    # into the process-wide default graph and multiple models would clash.
    with self.graph.as_default():
        self.sess = tf.Session(graph=self.graph, config=self.session_conf)
        # The graph is always rebuilt from code (model.meta is never
        # imported), so it is identical whether or not we restore.
        self.init_net()
        self.saver = tf.train.Saver()
        # Ask TF which checkpoint exists instead of probing for model.meta.
        # The .meta file can exist while the 'checkpoint' state file is
        # missing; then latest_checkpoint() returns None, and
        # restore(None) raises ValueError.
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            self.saver.restore(self.sess, checkpoint)
        else:
            self.sess.run(tf.global_variables_initializer())
def train(self, steps=1000, batch_size=1000, eval_size=100):
    """Train, print loss/accuracy after every step, then save a checkpoint.

    Args:
        steps: number of optimizer steps (default 1000, as before).
        batch_size: examples requested from generate_data per training step.
        eval_size: examples requested from generate_data per step to
            report accuracy.

    The final weights are written to ``./<id>/model``.
    """
    print('training')
    # Keep the session as default: helpers not visible here (get_acc,
    # generate_data) may rely on a default session.
    with self.sess.as_default():
        for step in range(steps):
            x, y = generate_data(batch_size, self.id)
            loss, _ = self.sess.run(
                [self.loss, self.train_op],
                feed_dict={self.input_x: x, self.input_y: y})
            x_test, y_test = generate_data(eval_size, self.id)
            # prediction depends only on input_x, so labels are not fed.
            prediction = self.sess.run(
                self.prediction, feed_dict={self.input_x: x_test})
            acc = self.get_acc(prediction, y_test)
            # One pre-formatted string prints identically on Python 2 and 3.
            print('step: %s loss: %s acc: %s' % (step, loss, acc))
    self.saver.save(self.sess, os.path.join('.', self.id, 'model'))
def test(self, size=1000):
    """Evaluate accuracy on freshly generated data, then print and return it.

    Args:
        size: number of examples requested from generate_data (default 1000).

    Returns:
        The value computed by ``self.get_acc`` for this batch. (The method
        previously returned None, so existing callers are unaffected.)
    """
    print('testing')
    # Keep the session as default in case get_acc / generate_data use it.
    with self.sess.as_default():
        x_test, y_test = generate_data(size, self.id)
        # prediction depends only on input_x, so labels are not fed.
        prediction = self.sess.run(
            self.prediction, feed_dict={self.input_x: x_test})
        acc = self.get_acc(prediction, y_test)
        # One pre-formatted string prints identically on Python 2 and 3.
        print('acc: %s' % (acc,))
    return acc