(一)Ray:Tune整定模型超参数

Ray Tune 模块

Tune

Tune是一个超参数整定模块,他以’trials’来构建起每一次尝试。为’trials’利用Scheduler作为调度器。可以使用包括PBT,AsyncHyperBand在内的多种超参数整定方法。

如何使用?

根据上述所述,分为以下几步:

  1. 根据自己的需求构建一个trials,可理解为一个训练epoch,该trials需继承Tune.Trainable类
  2. 选择合适的Schedulers
  3. 调用ray.tune.run(),其中trials作为run的run_or_experiment 传入

例子:mnist-pytorch模型训练超参数整定

借助mnist-pytorch官方例子进行解释

1.构建一个trial

在本例中,一个trial具有以下几步

  1. 训练一轮
  2. 评估此轮模型效果
  3. 返回评估指标
class TrainMNIST(tune.Trainable):
    def _setup(self, config):
        # 类似于__init__函数,用于初始化相关配置
        # 1.读数据:self.data_loader = ... 
        # 2.构建模型 : self.model = ...
        # 3.优化器: self.optimizer = ...
        #... 具体源码见官方教程

    def _train(self):
        #训练模型
        train(
            self.model, self.optimizer, self.train_loader, device=self.device)
        acc = test(self.model, self.test_loader, self.device)
        return {
     "mean_accuracy": acc}

    def _save(self, checkpoint_dir):
        #存模型
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        #读模型
        self.model.load_state_dict(torch.load(checkpoint_path))

其中traintest分别实现了模型一次训练和测试,具体参见源码吧

一般来说实现一个trial均需要实现以上四个方法

class TrainMNIST(tune.Trainable):
    def _setup(self, config):
        # 类似于__init__函数,用于初始化相关配置
        ...
    def _train(self):
        #训练模型
        # model.train() 
        # metric = model.eval()
        # return {'metric' : metric}
        ...

    def _save(self, checkpoint_dir):
        #存模型
        ...

    def _restore(self, checkpoint_path):
        #读模型
        ...

2.选择并初始化一个Schedulers

在本例中采用的是ASHAScheduler (论文PDF,论文Blog)

 sched = ASHAScheduler(metric="mean_accuracy")

3.run trial

    analysis = tune.run(
        TrainMNIST,
        scheduler=sched, 
        # 设置停止条件
        stop={
     
            "mean_accuracy": 0.95,
            ...
        },
        resources_per_trial={
     
            ...
        },
        
        num_samples=1 if args.smoke_test else 20,
        checkpoint_at_end=True,
        checkpoint_freq=3,

        #需要整定的参数
        config={
     
            "args": args,
            "lr": tune.uniform(0.001, 0.1),  
            "momentum": tune.uniform(0.1, 0.9),
        })

tune.run 输入以上定义的trial,并定义stop 及其他参数
其中需要整定的超参数以tune.uniform函数采样,当然还有其他采样超参数方法

分析数据

得到最优模型的超参数配置

 print("Best config is:", analysis.get_best_config(metric="mean_accuracy"))

结果

(一)Ray:Tune整定模型超参数_第1张图片
可以一次性看出所有参数的效果,选取最优的超参数。

参考
ray官方教程

你可能感兴趣的:(ray,强化学习,深度学习,python)