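Lately a lot of people have been asking why SAC takes so long to run,
and why DDPG was so fast when they ran it.
I started out on tf1 and am only now moving over to torch.
In the tables I use the number of decimal places to show how large a value is, which gives a bar-chart effect!
Intuitively, the time cost of a reinforcement learning algorithm depends mainly on the following factors:
That's about it. Here is the environment I tested on: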
cpu cores : 4
model name : Intel® Core™ i5-6500 CPU @ 3.20GHz
tensorflow-gpu 1.10;
pytorch 1.7.1 (CPU);
pytorch 1.2.0 (GPU and CPU).
pytorch 1.7.1 on GPU could not be tested, because my CUDA driver version is too old to support it.
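The algorithms tested are SAC, TD3, and DDPG, three mainstream off-policy algorithms;
Each configuration is trained with a single random seed. Each experiment runs 20 epochs, with 50 cycles per epoch and 50 steps per cycle, which is the standard training procedure for the gym-fetch family of tasks;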
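To make the schedule concrete: with these settings one epoch is 50 × 50 = 2,500 environment steps, and a full run is 50,000. Below is a minimal sketch of how a per-epoch wall-clock figure like EpochTime could be collected. `run_training` and `step_fn` are hypothetical names made up for illustration, not functions from DRLib.

```python
import time


def run_training(step_fn, n_epochs=20, n_cycles=50, n_steps=50):
    """Run the nested epoch/cycle/step loop; return wall-clock seconds per epoch."""
    epoch_times = []
    for _ in range(n_epochs):
        start = time.perf_counter()
        for _ in range(n_cycles):
            for _ in range(n_steps):
                step_fn()  # hypothetical: one environment step (plus any update)
        epoch_times.append(time.perf_counter() - start)
    return epoch_times


# Tiny example with a do-nothing step, just to show the bookkeeping.
calls = []
times = run_training(lambda: calls.append(1), n_epochs=2, n_cycles=3, n_steps=4)
print(len(times), len(calls))
```

With the default arguments, `step_fn` would be called 20 × 50 × 50 times.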
Runs are on GPU by default; the number of decimal places shows how large the value is, giving a bar-chart effect.
As you can see:
----------------------------------------------------
| exp_name | EpochTime |
----------------------------------------------------
| HER_DDPG | 15.5143217330022 |
| HER_SAC_AUTO | 18.371087785469598 |
| HER_SAC | 18.5900878813823205 |
| HER_TD3 | 15.2241159357852 |
| CPU_HER_DDPG | 21.9963255758646987204 |
| CPU_HER_SAC_AUTO | 23.801520339791391478457 |
| CPU_HER_SAC | 24.3463643059630570064655 |
| CPU_HER_TD3 | 19.71410898228858599 |
The default is pytorch 1.2.0 on GPU; the CPU runs use 1.7.1 and 1.2.0;
The results are clear to see:
------------------------------------------------------------------------------------------------------------------------
| exp_name | EpochTime |
------------------------------------------------------------------------------------------------------------------------
| CPU_torch120_HER_DDPGTorch | 57.237169690294493307192169595509767532348632812500000000 |
| CPU_torch120_HER_SACTorch | 82.7758810982368657960250857286155223846435546875000000000000000000000000000000000 |
| CPU_torch120_HER_TD3Torch | 72.287067418858157452632440254092216491699218750000000000000000000000000 |
| CPU_torch171_HER_DDPGTorch | 37.15821849335783610968064749613404274 |
| CPU_torch171_HER_SACTorch | 54.333070903649378635691391536965966224670410156250000 |
| CPU_torch171_HER_TD3Torch | 48.6389267101949016591788677033036947250366210938 |
| HER_DDPGTorch | 16.45956843656848 |
| HER_SACTorch | 22.50593830983252274791 |
| HER_TD3Torch | 19.76218916779679802 |
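Analysis of the results:
torch 1.2.0 is not much slower than tf1.14.0.
This may be because I run too few updates, i.e. most of the time goes to environment interaction and only a small part to updating the networks;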
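One way to check this guess is to accumulate wall-clock time separately for environment interaction and for network updates. A minimal sketch, assuming a hypothetical `PhaseTimer` helper (not part of DRLib); the `time.sleep` calls are stand-ins for the real work.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class PhaseTimer:
    """Accumulate wall-clock seconds per named phase."""

    def __init__(self):
        self.totals = defaultdict(float)

    @contextmanager
    def track(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start

    def fractions(self):
        """Return each phase's share of the total tracked time."""
        total = sum(self.totals.values())
        if total <= 0:
            return {}
        return {name: t / total for name, t in self.totals.items()}


timer = PhaseTimer()
for _ in range(5):
    with timer.track("env"):
        time.sleep(0.002)  # stand-in for one environment step
    with timer.track("update"):
        time.sleep(0.001)  # stand-in for one gradient update
shares = timer.fractions()
print(shares)
```

If the "env" share dominates, differences between frameworks in update speed will barely show up in EpochTime.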
-------------------------------------------------
| exp_name | EpochTime |
-------------------------------------------------
| HER_DDPG | 15.5143217330022 |
| HER_SAC_AUTO | 18.371087785469598 |
| HER_SAC | 18.5900878813823205 |
| HER_TD3 | 15.2241159357852 |
| HER_DDPGTorch | 16.45956843656848 |
| HER_SACTorch | 22.50593830983252274791 |
| HER_TD3Torch | 19.76218916779679802 |
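I printed the time spent on computing gradients plus applying the gradient update. Generally speaking, tf1 is still somewhat faster (DDPG is the exception, where torch is slightly ahead).
The numbers below are the time of a single update multiplied by 10,000.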
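A rough sketch of how a per-update number like this could be measured. `mean_update_time`, `update_fn`, and `sync_fn` are hypothetical names, and seconds are assumed as the unit. On GPU, CUDA kernels launch asynchronously, so with PyTorch one would typically pass `torch.cuda.synchronize` as `sync_fn`; otherwise the timer can stop before the work has actually finished.

```python
import time


def mean_update_time(update_fn, n_updates=100, sync_fn=None, scale=10000):
    """Return the mean seconds per update_fn() call, multiplied by scale."""
    if sync_fn is not None:
        sync_fn()  # drain any pending device work before starting the clock
    start = time.perf_counter()
    for _ in range(n_updates):
        update_fn()
    if sync_fn is not None:
        sync_fn()  # wait for the last update to really finish
    elapsed = time.perf_counter() - start
    return elapsed / n_updates * scale


# CPU-only stand-in for "compute gradients + apply update".
scaled = mean_update_time(lambda: sum(range(1000)), n_updates=50)
print(scaled)
```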
---------------------------------------------------------------------------------------------
| exp_name | update_time |
---------------------------------------------------------------------------------------------
| GPU_tf1_HER_DDPG | 31.18902600000000191471372090746 |
| GPU_tf1_HER_SAC_AUTO | 44.914377000000001771695679053664207458496094 |
| GPU_tf1_HER_SAC | 45.1578286000000090893991000484675168991088867 |
| GPU_tf1_HER_TD3 | 28.69846519999999401306922664 |
| GPU_torch120_HER_DDPGTorch | 28.55126619999999348920027842 |
| GPU_torch120_HER_SACTorch | 55.2312286000000085550709627568721771240234375000000000 |
| GPU_torch120_HER_TD3Torch | 43.08850280000000054769770940765738487243652 |
---------------------------------------------------------------------------------------------
| exp_name | EpochTime |
---------------------------------------------------------------------------------------------
| HER_DDPG | 15.5232898801458 |
| HER_SAC_AUTO | 18.387668513653189 |
| HER_SAC | 18.6041103305503768 |
| HER_TD3 | 15.224283310430 |
| GPU_tf1_HER_DDPG | 15.9738348579407 |
| GPU_tf1_HER_DDPG_PER | 31.40584538380304735483150579967 |
| GPU_tf1_HER_SAC_AUTO | 18.9588948599497478 |
| GPU_tf1_HER_SAC_AUTO_PER | 33.8790638621648128037122660316527 |
| GPU_tf1_HER_SAC | 18.9937221026420602 |
| GPU_tf1_HER_SAC_PER | 34.05487372636795129210440791212022 |
| GPU_tf1_HER_TD3 | 15.5398781498273 |
| GPU_tf1_HER_TD3_PER | 30.6231472516059923805187281687 |
| GPU_torch120_HER_DDPGTorch | 16.7084743475914 |
| GPU_torch120_HER_SACTorch | 22.02115873972574888739 |
| GPU_torch120_HER_TD3Torch | 19.93665778557459589 |
| HER_DDPGTorch | 16.46546708807171 |
| HER_SACTorch | 22.52556498200615209271 |
| HER_TD3Torch | 19.74400094596662925 |
| GPU_torch120_HER_DDPGTorch_PER | 45.8728787875175498811586294323205947875976562 |
| GPU_torch120_HER_SACTorch_PER | 51.086336704095202776443329639732837677001953125000 |
| GPU_torch120_HER_TD3Torch_PER | 48.2529249827067090450327668804675340652465820312 |
DDPG, with PER vs. without PER:
Epoch time roughly doubles on tf1 (about 1.97x) and nearly triples on torch 1.2.0 (about 2.75x).
---------------------------------------------------------------------------------------
| exp_name | EpochTime |
---------------------------------------------------------------------------------------
| HER_DDPG | 15.5232898801458 |
| GPU_tf1_HER_DDPG | 15.9738348579407 |
| GPU_tf1_HER_DDPG_PER | 31.40584538380304735483150579967 |
| GPU_torch120_HER_DDPGTorch | 16.7084743475914 |
| HER_DDPGTorch | 16.46546708807171 |
| GPU_torch120_HER_DDPGTorch_PER | 45.8728787875175498811586294323205947875976562 |
TD3, with PER vs. without PER:
Again about double on tf1 (about 1.97x), and about 2.42x on torch 1.2.0.
-----------------------------------------------------------------------------------------
| exp_name | EpochTime |
-----------------------------------------------------------------------------------------
| HER_TD3 | 15.224283310430 |
| GPU_tf1_HER_TD3 | 15.5398781498273 |
| GPU_tf1_HER_TD3_PER | 30.6231472516059923805187281687 |
| GPU_torch120_HER_TD3Torch | 19.93665778557459589 |
| HER_TD3Torch | 19.74400094596662925 |
| GPU_torch120_HER_TD3Torch_PER | 48.2529249827067090450327668804675340652465820312 |
SAC, with PER vs. without PER:
Here too the cost is large: about 1.79x on tf1 and about 2.32x on torch 1.2.0.
--------------------------------------------------------------------------------------------
| exp_name | EpochTime |
--------------------------------------------------------------------------------------------
| HER_SAC_AUTO | 18.387668513653189 |
| HER_SAC | 18.6041103305503768 |
| GPU_tf1_HER_SAC_AUTO | 18.9588948599497478 |
| GPU_tf1_HER_SAC_AUTO_PER | 33.8790638621648128037122660316527 |
| GPU_tf1_HER_SAC | 18.9937221026420602 |
| GPU_tf1_HER_SAC_PER | 34.05487372636795129210440791212022 |
| GPU_torch120_HER_SACTorch | 22.02115873972574888739 |
| HER_SACTorch | 22.52556498200615209271 |
| GPU_torch120_HER_SACTorch_PER | 51.086336704095202776443329639732837677001953125000 |
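The PER overhead can be read straight off the tables by dividing each `_PER` row by its matching row without PER. A small sketch using the GPU EpochTime values from above, rounded to four decimals; `per_ratio` is just a helper name chosen for this example.

```python
# EpochTime values copied from the tables above, rounded to 4 decimals.
epoch_time = {
    "GPU_tf1_HER_DDPG": 15.9738,
    "GPU_tf1_HER_DDPG_PER": 31.4058,
    "GPU_tf1_HER_TD3": 15.5399,
    "GPU_tf1_HER_TD3_PER": 30.6231,
    "GPU_tf1_HER_SAC": 18.9937,
    "GPU_tf1_HER_SAC_PER": 34.0549,
    "GPU_tf1_HER_SAC_AUTO": 18.9589,
    "GPU_tf1_HER_SAC_AUTO_PER": 33.8791,
    "GPU_torch120_HER_DDPGTorch": 16.7085,
    "GPU_torch120_HER_DDPGTorch_PER": 45.8729,
    "GPU_torch120_HER_TD3Torch": 19.9367,
    "GPU_torch120_HER_TD3Torch_PER": 48.2529,
    "GPU_torch120_HER_SACTorch": 22.0212,
    "GPU_torch120_HER_SACTorch_PER": 51.0863,
}


def per_ratio(name):
    """Epoch time with PER divided by epoch time without PER."""
    return epoch_time[name + "_PER"] / epoch_time[name]


baselines = [k for k in epoch_time if not k.endswith("_PER")]
for name in baselines:
    print(f"{name}: {per_ratio(name):.2f}x")
```

On these numbers the ratio sits between roughly 1.8x and 2.0x for tf1, and between roughly 2.3x and 2.7x for torch 1.2.0.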
Table-generation script: https://github.com/kaixindelele/DRLib/blob/main/spinup_utils/log2table.py
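The bar-chart effect in the tables (more decimal places for larger values) is easy to reproduce. The sketch below is only an illustration of the idea, not the code from the linked script: `bar_format` and its scaling rule (decimals proportional to the value's share of the largest value) are assumptions.

```python
def bar_format(value, max_value, max_decimals=40):
    """Format value with a decimal count proportional to value / max_value."""
    decimals = max(1, round(max_decimals * value / max_value))
    return f"{value:.{decimals}f}"


# Made-up example values, only to show the effect.
rows = {"run_a": 15.5, "run_b": 31.0}
widest = max(rows.values())
for name, v in rows.items():
    print(f"| {name} | {bar_format(v, widest)} |")
```

Larger values print as longer strings, so scanning down the column gives a quick visual comparison without plotting anything.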
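In an earlier test, torch on GPU was much slower than tf1 on GPU;
There were two reasons:
First, the torch version back then was 0.4;
Second, the update frequency in that test was much higher;
So from now on it looks like I'll just stick with torch.
Leaving this here for now; I'll continue updating tomorrow.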
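PS: anyone working on reinforcement learning is welcome to join the group and learn together:
Deep Reinforcement Learning (DRL) group: 799378128
Zhihu account: 未入门的炼丹学徒
CSDN account: https://blog.csdn.net/hehedadaq
Minimal spinup + HER + PER implementation: https://github.com/kaixindelele/DRLib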