1. 注册可训练的函数或类
ray.tune.register_trainable(name, trainable)
参数:
name (str) - 注册的方法或函数名。
trainable (obj) - 函数或tune.Trainable类。函数必须采用(config, status_reporter)作为参数,并且在注册的过程中自动转换为类。
2. 构造experiment对象
ray.tune.Experiment(name, run, stop=None, config=None, trial_resources=None, repeat=1, local_dir=None, upload_dir='', checkpoint_freq=0, max_failures=3)
参数:
name (str) – 名字。
run (str) – 要训练的算法或模型。 这可以指内置算法的名称(例如RLLib的DQN或PPO),或者在tune注册表中注册的用户定义的可训练函数或类。
stop (dict) - 停止标准。 值可以是TrainingResult中的任何字段,以先到达者为准。 默认为空字典。
config (dict) – 特定于算法的配置(例如env,hyperparams)。 默认为空字典。
trial_resources (dict) – 每次试验分配的机器资源,例如: {"cpu":64,"gpu":8}。 请注意,除非您在此处指定GPU,否则不会分配GPU。 默认为1个CPU和0个GPU。
repeat (int) – 重复每次试验的次数。 默认为1。
local_dir (str) – 将训练结果保存到的本地目录。 默认为〜/ ray_results。
upload_dir (str) – 同步训练结果的可选URI地址(例如s3:// bucket)。
checkpoint_freq (int) – 设置检查点间的训练迭代次数。 值0(默认值)禁用设置检查点。
max_failures (int) – 设置尝试从最后一个检查点恢复试验的最多次数。 仅在启用了检查点时适用。 默认为3。
3. 运行实验程序
ray.tune.run_experiments(experiments, scheduler=None, with_server=False, server_port=4321, verbose=True, queue_trials=False)
参数:
experiments (Experiment | list | dict) - 要运行的实验。
scheduler (TrialScheduler) - 用于执行实验的调度程序。在FIFO(默认),MedianStopping,AsyncHyperBand,HyperBand或HyperOpt中进行选择。
with_server (bool) - 启动后台Tune服务器。 使用客户端API需要。
server_port (int) - 启动TuneServer的端口号。
verbose (bool) - 每次试验应打印多少输出。
queue_trials (bool) - 当群集当前没有足够的资源来启动试验时,是否对试验进行排队。 在自动扩展群集上运行时,应将其设置为True以启用自动向上扩展。
返回值:
4. 调度程序HyperOptScheduler
ray.tune.hpo_scheduler.HyperOptScheduler(max_concurrent=None, reward_attr='episode_reward_mean')
参数:
reward_attr (str) – TrainingResult目标值属性。 这是指一个递增的值,在与HyperOpt交互时在内部被否定,以便HyperOpt可以“最大化”该值。
5. 调度程序AsyncHyperBandScheduler
ray.tune.async_hyperband.AsyncHyperBandScheduler(time_attr='training_iteration', reward_attr='episode_reward_mean', max_t=100, grace_period=10, reduction_factor=3, brackets=3)
参数:
time_attr (str) – 用于比较时间的TrainingResult 。请注意,你可以传递非时间性的东西,例如training_iteration作为进度的度量,唯一的要求是属性应该单调增加。
reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。 终止进程时将使用此属性。
max_t (float) – 每次试验的最大时间单位。 经过max_t时间单位(由time_attr确定)后,试验将停止。
grace_period (float) – 终止试验的条件。单位与time_attr命名的属性相同。
reduction_factor (float) – 用于设置减半的速率和数量。
brackets (int) – bracket的数量。 每个bracket具有不同的减半率,由减少系数指定。
6. 调度程序HyperBandScheduler
ray.tune.hyperband.HyperBandScheduler(time_attr='training_iteration', reward_attr='episode_reward_mean', max_t=81)
参数:
reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。 终止进程时将使用此属性。
max_t (int) – 每次试验的最大时间单位。 经过max_t时间单位(由time_attr确定)后,试验将停止。该调度器将在时间结束后终止试验。
7. 调度程序MedianStoppingRule
ray.tune.median_stopping_rule.MedianStoppingRule(time_attr='time_total_s', reward_attr='episode_reward_mean', grace_period=60.0, min_samples_required=3, hard_stop=True, verbose=True)
参数:
reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。
grace_period (float) – 终止试验的条件。单位与time_attr命名的属性相同。
min_samples_required (int) – 最小样本计算中位数。
hard_stop (bool) – 如果为False,则暂停试验而不是停止试验。 当所有其他试验完成后,暂停试验将恢复并允许以FIFO运行。
verbose (bool) – 如果为True,则每次试验报告时都会输出中位数和最佳 结果。 默认为True。
8. 可训练模型,函数等的抽象类
ray.tune.trainable.Trainable(config=None, logger_creator=None)
参数:
logdir(str) - 放置训练输出的目录。
重载方法:
_save(checkpoint_dir) - 重载该方法实现save():调用save()将可训练的训练状态保存到磁盘,并且restore(path)应该将训练状态恢复到给定状态。
_restore(checkpoint_path) - 重载该方法实现restore()。
_setup() - 重载该方法实现自定义初始化。
_stop() - 重载该方法实现清理和关闭程序。
注:
1)通常,在继承Trainable时,你只需要在这里实现_train,_save和_restore。
2)如果你不需要checkpoint/restore,那么你也可以通过提供my_train(config,reporter)函数并调用以下内容来实现,而不是实现此类。
register_trainable(“my_func”,train)
注册它以便与Tune一起使用。该功能将自动转换为该接口(无检查点功能)。
9. 实现客户端与正在进行的Tune实验进行交互,需要服务器已开始运行
ray.tune.web_server.TuneClient(tune_address)
方法:
get_trial(trial_id) - 返回查询试验的最后结果。
add_trial(name, trial_spec) - 给相应名字的试验添加配置。
stop_trial(trial_id) - 关闭相应id的试验。
10. PopulationBasedTraining(PBT)算法
ray.tune.pbt.PopulationBasedTraining(time_attr='time_total_s', reward_attr='episode_reward_mean', perturbation_interval=60.0, hyperparam_mutations={}, resample_probability=0.25, custom_explore_fn=None)
参数:
reward_attr (str) – TrainingResult目标值属性。 与time_attr一样,这可以指任何客观值。 终止进程时将使用此属性。
perturbation_interval (float) – 模型将在time_attr的这个时间间隔内考虑扰动。 请注意,扰动会导致检查点开销,因此你不应将此设置为过于频繁。
hyperparam_mutations (dict) – Hyperparams变异。 格式如下:对于每个键,可以提供列表或函数。 列表指定一组允许的分类值。 函数指定连续参数的分布。 你必须至少指定hyperparam_mutations或custom_explore_fn中的一个。
resample_probability (float) – 应用hyperparam_mutations时从原始分布重新采样的概率。 如果不重新采样,如果连续,则值将被因子1.2或0.8扰动,或者如果离散则将值更改为相邻值。
custom_explore_fn (func) – 你还可以指定自定义探索功能。 在应用来自hyperparam_mutations的内置扰动后,此函数被调用为f(config),并应根据需要返回更新的配置。 你必须至少指定hyperparam_mutations或custom_explore_fn中的一个。