使用tune.run怎么获取model summary的信息

使用rllib的时候,如果使用PPOTrainer或者某一个其他的trainer,在执行trainer.train()的时候,会打印model summary,也可以显式调用获取model summary的API,就像下面这样

>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(env="CartPole-v0", config={
     "eager": True, "num_workers": 0})
>>> policy = trainer.get_policy()
>>> policy.model.base_model.summary()
Model: "model"
_____________________________________________________________________
Layer (type)               Output Shape  Param #  Connected to
=====================================================================
observations (InputLayer)  [(None, 4)]   0
_____________________________________________________________________
fc_1 (Dense)               (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense)             (None, 2)     514      fc_2[0][0]
_____________________________________________________________________
value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0

但是,上面的程序正常运行有两个前提:
1)必须是使用的tf2或者tf1的后端。因为base_model这个属性只在tf的policy中定义了。如果使用pytorch后端,那该怎么做?
2)必须是使用trainer,使用tune.run的话是不行的。


解决一:如果使用pytorch后端

from ray.rllib.agents.ppo import PPOTrainer
trainer=PPOTrainer(env='CartPole-v0', config={
     "num_workers": 0, "framework": "torch"}
print(trainer.get_policy().model)

解决二:使用tune,使用tune的话有两种方法解决,
法一:新添加一个trainer, 使用trainer调用

from ray.rllib.agents.ppo import PPOTrainer
from ray import tune

config={
     
	env='CartPole-v0',
	"num_workers": 0, 
	"framework": "torch"
}

config_for_trainer={
     
	"num_workers": 0, 
	"framework": "torch"
}

# 两个config的区别就是env

trainer = PPOTrainer(env= 'CartPole-v0', config=config_for_trainer)
print(trainer.get_policy().model)

results = tune.run(config=config, verbose=1)

法二:修改rllib的源代码,修改的文件为 ray/rllib/policy/torch_policy.py,修改行数为 164行,在第164行添加

print(self.model)

在这里插入图片描述


解决二也是既使用tune同时使用pytorch后端的解决方法。在训练时,优先是使用tune的,它集成了更方便的调参功能。


参考:https://discuss.ray.io/t/how-to-get-model-summary-using-pytorch-backend/2064

你可能感兴趣的:(#,Ray(RLlib),rllib,tune,model,summary)