run()方法详解链接
feed_dict={y:3}
通过 feed_dict 把 3 传给占位符 y,该值只在本次 sess.run() 计算中有效。
下面是在 LSTM 和 CNN 模型中用 sess.run() 进行训练(train)和预测(predict)的例子:
def run_step(self, sess, is_train, batch):
    """
    Run one batch through the graph, either as a training step or as inference.

    :param sess: session to run the batch
    :param is_train: a flag indicating whether this is a training batch
    :param batch: a dict containing batch data
    :return: ``(global_step, loss)`` when ``is_train`` is truthy,
        otherwise ``(lengths, logits)``
    """
    feed_dict = self.create_feed_dict(is_train, batch)
    if is_train:
        # The value fetched for train_op is discarded; presumably it is
        # requested only so that the optimizer update gets executed.
        global_step, loss, _ = sess.run(
            [self.global_step, self.loss, self.train_op],
            feed_dict)
        return global_step, loss

    # Inference path: only lengths and logits are fetched, so neither the
    # loss nor train_op is requested from the session.
    lengths, logits = sess.run([self.lengths, self.logits], feed_dict)
    return lengths, logits
训练时需要计算整个图(包括损失 loss 和 train_op);预测时只取 lengths 和 logits,不需要计算损失值。
def predict(self, sent):
    """
    Predict the intent label of a single sentence.

    :param sent: raw sentence; it is passed through
        ``data_helpers.preprocess`` before being vectorised
    :return: ``(intent, score)`` where ``intent`` is
        ``label_dict.get(str(predicted_id))`` (``None`` when the id is not a
        key of ``label_dict``) and ``score`` is the largest softmax score
        returned for the sentence
    """
    x_list = [data_helpers.preprocess(sent)]
    # Use transform() rather than fit_transform(): prediction must reuse the
    # vocabulary learned during training instead of re-fitting it on the
    # incoming sentence.
    # NOTE(review): assumes vocab_processor exposes transform() (as
    # tf.contrib.learn's VocabularyProcessor does) -- confirm.
    x_batch = np.array(list(self.vocab_processor.transform(x_list)))
    feed_dict = {
        self.cnn.input_x: x_batch,
        # Keep probability 1.0 means no units are dropped at prediction time.
        self.cnn.dropout_keep_prob: 1.0,
    }
    prediction, softmax_scores = self.sess.run(
        [self.cnn.predictions, self.cnn.softmax_scores],
        feed_dict)
    # The batch holds exactly one sentence, so only row 0 is relevant.
    score = max(softmax_scores[0])
    intent = label_dict.get(str(prediction[0]))
    return intent, score
def train_step(x_batch, y_batch):
    """
    A single training step.

    Feeds one batch to the CNN, runs ``train_op`` together with the step
    counter, summary op, loss and accuracy, prints a timestamped progress
    line and records the summaries for that step.

    :param x_batch: value fed to ``cnn.input_x``
    :param y_batch: value fed to ``cnn.input_y``
    :return: None
    """
    feed_dict = {
        cnn.input_x: x_batch,
        cnn.input_y: y_batch,
        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
    }
    # The value fetched for train_op is discarded; only the other four
    # fetches are used below.
    _, step, summaries, loss, accuracy = sess.run(
        [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
        feed_dict)
    time_str = datetime.datetime.now().isoformat()
    print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
    train_summary_writer.add_summary(summaries, step)
import tensorflow as tf
# Minimal feed_dict demo: y is an input placeholder, b is a variable that
# starts at 1.0. Each run of `update` computes y + b and assigns it back to b.
y = tf.placeholder(dtype=tf.float32)
b = tf.Variable(1.0, dtype=tf.float32)
new_val = tf.add(y, b)
update = tf.assign(b, new_val)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for _ in range(3):
        # y is fed as 3 on every call, so b grows by 3 each time:
        # the printed values are 4.0, 7.0 and 10.0.
        print(sess.run(update, feed_dict={y: 3}))
    # A fed value lasts for one run() call only: running y by itself, with
    # no feed_dict, would fail because the placeholder has no value.
LSTM 和 CNN 训练时同理:把 b 设为变量(Variable),把 y 设为输入(placeholder),在循环中反复调用 sess.run() 来更新参数 b,从而达到训练模型的目的。