最近在研究关于强化学习的部分工作,首先从OpenAI的Baseline中的小型GAIL算法出发。
首先参考了大神的文章从《西部世界》到GAIL(Generative Adversarial Imitation Learning)算法。
原文链接:https://blog.csdn.net/jinzhuojun/article/details/85220327#commentBox
对大神写的文章做一些补充和细节解释。
在baseline 的文件夹中运行即可以进行模型的训练
python3 -m baselines.gail.run_mujoco
在run_mujoco.py代码中写到
parser.add_argument('--task', type=str, choices=['train', 'evaluate', 'sample'], default='train')
可以在命令行后面添加 --task 改变任务为train 和evaluate。evaluate后面要加上存储的模型的地址
# 假设训练模型放在/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/
python3 -m baselines.gail.run_mujoco --task=evaluate --load_model_path=/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0
在baseline 中使用tensorflow方式存储模型: 在trpo_mpi.py 232行。
# Save model
if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
fname = os.path.join(ckpt_dir, task_name)
#U.save_variables(fname)
#print("the save path is ",fname)
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(tf.get_default_session(), fname)
所以在checkpoint中存储了可以用tensorflow方式读取模型的三个文件,而在运行评估模型时读取模型的方式采用的是baseline 中common自己定义的 U.load_variables(load_model_path)来读取文件,读取文件的类型是上面由tensorflow生成的文件的集合体。
U.load_variables(load_model_path)
因此在存储模型的时候也应该采用common中的定义的save_variables来存储模型生成集成文件:
if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
fname = os.path.join(ckpt_dir, task_name)
U.save_variables(fname)
print("the save path is ",fname)
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(tf.get_default_session(), fname)
然后运行train的命令行,在训练100次迭代之后就可以在保存模型的文件夹中发现一个无.data/.index/.meta后缀的集成文件。
此时再运行evaluate命令行就可以出现对模型的评估返回数据
在run_mujoco.py中的traj_1_generator函数中的while函数中插入env.render()就可以渲染出模型可视化结果。
while True:
ac, vpred = pi.act(stochastic, ob)
obs.append(ob)
news.append(new)
acs.append(ac)
ob, rew, new, _ = env.step(ac)
rews.append(rew)
env.render()
cur_ep_ret += rew
cur_ep_len += 1
if new or t >= horizon:
break
t += 1
感谢大佬的分享,同时在遇到困难的时候还是要敢于挑战权威呀。