Tensorflow 测试RL算法,保存模型 并 读取进行测试

保存模型

RL中,我们一般都把一个网络结构写在一个类里面,保存的时候也是,可以如下写一个 save_net 函数:

def save_net(self):
    saver = tf.train.Saver()
    save_path = saver.save(self.sess, "./dqn/model/file_name.ckpt")
    print("Save to path: ", save_path)

在RL算法进行完N轮的训练之后,调用该函数进行模型保存:agent.save_net()
可以看到,会在model文件夹下多出四个文件:
Tensorflow 测试RL算法,保存模型 并 读取进行测试_第1张图片
也可以输出保存前的参数,进行观察,以便确认读取模型时是否成功读取了参数:

w1 = tf.get_default_graph().get_tensor_by_name('eval_net/l1/w1:0')  # 获得variable对应的Tensor
print(self.sess.run(w1))  # run一下这个Tensor得到结果

读取模型

首先注意,读取模型用于测试时,我们需要保证用到的变量和训练时的是一样的,比如测试DQN模型的效果:

class Test4DQN:
    def __init__(self):
        self.sess = tf.Session()
        self._build_net()

    def _build_net(self):
        # 测试时,只需要建立 evaluate_net,用来选择动作
        self.s = tf.placeholder(tf.float32, [None, 11])

        with tf.variable_scope('eval_net'):
            with tf.variable_scope('l1'):
                w1 = tf.Variable(np.arange(110).reshape((11, 10)), dtype=tf.float32, name="w1")
                b1 = tf.Variable(np.arange(10).reshape((1, 10)), dtype=tf.float32, name="b1")
                l1 = tf.nn.relu(tf.matmul(self.s, w1) + b1)

            with tf.variable_scope('l2'):
                w2 = tf.Variable(np.arange(240).reshape((10, 24)), dtype=tf.float32, name="w2")
                b2 = tf.Variable(np.arange(24).reshape((1, 24)), dtype=tf.float32, name="b2")
                self.q_eval = tf.matmul(l1, w2) + b2

        # 读取模型参数
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()
        self.sess.run(init)
        saver.restore(self.sess, "./xxxxx/model/file_name.ckpt")
        print(self.sess.run(w1))  # 可以再次输出,和我们保存时的输出结果进行对比,保证正确读取
                    
                    
    def choose_action(self, observation):
        observation = observation[np.newaxis, :]
        actions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})
        action = np.argmax(actions_value)
        return action

总结一下,就是先初始化一下测试框架中定义的变量(注意层级和名称需要对应,原来叫’‘w1’‘现在也要叫’‘w1’’),然后调用saver.restore(self.sess, "./xxxxx/model/file_name.ckpt"),即可将保存的网络参数赋值给现在的网络。

之后,和原来RL的流程一样,只是不再需要保存记忆和训练而已,最后可以得到测试的效果。

你可能感兴趣的:(BUG调试,程序人生)