Ray Tune是一个可扩展的超参数优化框架,用于强化学习和深度学习。 从在单台计算机上运行一个实验到使用高效搜索算法在大型集群上运行,而无需更改代码。
本篇博客中所提及的函数。
首先需要安装Ray,使用命令 pip install ray
简单示例:
import ray
import ray.tune as tune
ray.init()
tune.register_trainable("train_func", train_func)
all_trials = tune.run_experiments({
"my_experiment": {
"run": "train_func",
"stop": {"mean_accuracy": 99},
"config": {
"lr": tune.grid_search([0.2, 0.4, 0.6]),
"momentum": tune.grid_search([0.1, 0.2]),
}
}
})
对于想要调整的函数,添加两行修改(请注意,我们使用PyTorch作为示例,但Ray Tune适用于任何深度学习框架,PyTorch中文文档)
def train_func(config, reporter): # add a reporter arg
model = NeuralNet()
optimizer = torch.optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
dataset = ( ... )
for idx, (data, target) in enumerate(dataset):
# ...
output = model(data)
loss = F.MSELoss(output, target)
loss.backward()
optimizer.step()
accuracy = eval_accuracy(...)
reporter(timesteps_total=idx, mean_accuracy=accuracy) # report metrics
这个PyTorch脚本使用Ray Tune在train_func函数上运行一个小的网格搜索,在命令行上报告状态,直到达到mean_accuracy> = 99的停止条件:
== Status ==
Using FIFO scheduling algorithm.
Resources used: 4/8 CPUs, 0/0 GPUs
Result logdir: ~/ray_results/my_experiment
- train_func_0_lr=0.2,momentum=1: RUNNING [pid=6778], 209 s, 20604 ts, 7.29 acc
- train_func_1_lr=0.4,momentum=1: RUNNING [pid=6780], 208 s, 20522 ts, 53.1 acc
- train_func_2_lr=0.6,momentum=1: TERMINATED [pid=6789], 21 s, 2190 ts, 100 acc
- train_func_3_lr=0.2,momentum=2: RUNNING [pid=6791], 208 s, 41004 ts, 8.37 acc
- train_func_4_lr=0.4,momentum=2: RUNNING [pid=6800], 209 s, 41204 ts, 70.1 acc
- train_func_5_lr=0.6,momentum=2: TERMINATED [pid=6809], 10 s, 2164 ts, 100 acc
为了报告增量进度,train_func定期调用Ray Tune传入的报告函数,以返回当前时间步长和ray.tune.result.TrainingResult中定义的其他度量。 增量结果将同步到群集的头节点上的本地磁盘。
tune.run_experiments返回一个Trial对象列表,你可以通过trial.last_result检查结果。
Ray Tune有如下特点:
灵活的试验性变量生成,包括网格搜索,随机搜索和条件参数分布。
Ray Tune在集群中调度了许多trials。 每个trial都运行一个用户定义的Python函数或类,并通过传递用户代码的配置变量进行参数化。
要运行任何给定的函数,你需要运行register_trainable注册一个名称。 这使得所有Ray上的worker都意识到这一功能的存在。
ray.tune.register_trainable(name, trainable)
Ray Tune提供run_experiments函数,用于生成和运行实验规范描述的trials。 trials由实施搜索算法的试验调度程序安排和管理(默认为FIFO)。
ray.tune.run_experiments(experiments, scheduler=None, with_server=False, server_port=4321, verbose=True, queue_trials=False)
Ray Tune可以在Ray的任何地方使用,例如 在电脑上使用嵌入在Python脚本中的ray.init()或用于大规模并行的自动缩放集群。
具体示例参考。
默认情况下,Ray Tune使用FIFOScheduler类按顺序调度trials。 但是,你还可以指定自定义计划算法,该算法可以提前停止试验,扰动参数或合并来自外部服务的建议。 当前实施的试验调度器包括基于群体的训练(PBT),中值停止规则,基于模型的优化(HyperOpt)和HyperBand。
run_experiments({...}, scheduler=AsyncHyperBandScheduler())
经常需要在驱动程序上计算大对象(例如,训练数据,模型权重)并在每个trial中使用该对象。 Ray Tune提供了一个pin_in_object_store实用程序函数,可用于广播此类大对象。 以这种方式固定的对象在驱动程序进程运行时永远不会从Ray对象存储库中逐出,并且可以通过get_pinned_object从任何任务中有效地检索。
import ray
from ray.tune import register_trainable, run_experiments
from ray.tune.util import pin_in_object_store, get_pinned_object
import numpy as np
ray.init()
# X_id can be referenced in closures
X_id = pin_in_object_store(np.random.random(size=100000000))
def f(config, reporter):
X = get_pinned_object(X_id)
# use X
register_trainable("f", f)
run_experiments(...)
HyperOptScheduler是一个trial调度程序,由HyperOpt支持执行基于顺序模型的超参数优化。 要使用此调度程序,需要通过以下命令安装HyperOpt:
$ pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
一个示例。
注意:
HyperOptScheduler在奖励属性中采用了增加的度量标准。 如果试图最小化损失,请务必在函数/类报告中指定mean_loss,并在HyperOptScheduler初始化程序中指定reward_attr = neg_mean_loss。
要启用检查点,你必须实现一个Trainable类(可训练的函数不是可检查的,因为它们永远不会将控制权返回给它们的调用者)。 最简单的方法是子类化预定义的Trainable类并实现其_train,_save和_restore抽象方法(示例):需要实现此接口以支持调度程序(如HyperBand和PBT)中的资源多路复用。
对于TensorFlow模型训练,这看起来像这样(完整tensorflow示例):
class MyClass(Trainable):
def _setup(self):
self.saver = tf.train.Saver()
self.sess = ...
self.iteration = 0
def _train(self):
self.sess.run(...)
self.iteration += 1
def _save(self, checkpoint_dir):
return self.saver.save(
self.sess, checkpoint_dir + "/save",
global_step=self.iteration)
def _restore(self, path):
return self.saver.restore(self.sess, path)
另外,检查点可用于为实验提供容错。 设置checkpoint_freq:N和max_failures:M,即每N次迭代的试验设置checkpoint,每次试验最多M次崩溃就进行恢复,例如:
run_experiments({
"my_experiment": {
...
"checkpoint_freq": 10,
"max_failures": 5,
},
})
必须实现以下的类接口才能启用检查点:
class ray.tune.trainable.Trainable(config=None, logger_creator=None)
你可以使用Tune客户端API添加或删除试验来修改正在进行的实验。为此,请验证是否安装了请求库:
$ pip install requests
要使用客户端API,您可以使用with_server = True开始实验:
run_experiments({...}, with_server=True, server_port=4321)
然后,在客户端,您可以使用以下类。 服务器地址默认为localhost:4321。 如果在群集上,你可能希望转发此端口(例如ssh -L
class ray.tune.web_server.TuneClient(tune_address)
一个Client API示例。
你可以在此处找到使用Ray Tune及其各种功能的示例列表,包括使用Keras,TensorFlow和基于人口训练的示例。