使用贝叶斯优化搜索 XGBoost 最佳参数

# Classifier with fixed, hand-picked hyperparameters
# (presumably taken from an earlier run of the Bayesian search below -- TODO confirm).
model = XGBClassifier(
    learning_rate=0.96,
    n_estimators=91,
    max_depth=2,
    min_child_weight=1,
    objective='binary:logistic',
    subsample=0.9,
    # XGBoost has no `feature_fraction` (that is LightGBM's name): passing it
    # only triggers an "unused parameter" warning and does nothing. XGBoost's
    # counterpart is `colsample_bytree`, left at its default here; set
    # colsample_bytree=0.5 if per-tree column subsampling is intended.
    # Other settings tried previously (currently disabled):
    #   reg_alpha=8, reg_lambda=12, colsample_bytree=0.99,
    #   nthread=8, scale_pos_weight=1
    random_state=1330,  # documented sklearn-API name for XGBoost's seed
)
def rf_cv(learning_rate, n_estimators, max_depth, min_child_weight, subsample, reg_alpha, reg_lambda):
    """Return the mean 5-fold cross-validated ROC AUC of an XGBClassifier.

    Used as the objective for ``BayesianOptimization``, which proposes every
    hyperparameter as a float. Only the genuinely integer-valued parameters
    (``n_estimators``, ``max_depth``) are converted with ``int()``.
    ``reg_alpha`` and ``reg_lambda`` are real-valued in XGBoost and are passed
    through unchanged. Truncating them would make the objective a step function
    in those dimensions, and the float values reported in ``rf_bo.max`` would
    not describe the model that was actually scored.

    Training data is read from the module-level ``train_x`` / ``train_y``.
    Despite the ``rf_`` prefix, the model is XGBoost, not a random forest.

    Returns:
        Mean ROC AUC over the 5 cross-validation folds.
    """
    estimator = XGBClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),
        max_depth=int(max_depth),
        min_child_weight=min_child_weight,
        # cap the sampled row fraction just below 1, as the original search did
        subsample=min(subsample, 0.999),
        reg_alpha=reg_alpha,
        reg_lambda=reg_lambda,
    )
    fold_scores = cross_val_score(estimator, train_x, train_y, scoring='roc_auc', cv=5)
    return fold_scores.mean()
# Search space: (lower, upper) bounds for every argument of rf_cv.
# bayes_opt samples each dimension as a float; rf_cv casts the integer ones.
# NOTE(review): learning_rate above 1 and max_depth up to 100 are unusually
# wide ranges for XGBoost; consider narrowing them to speed up the search.
pbounds = {
    'learning_rate': (0.1, 1.5),
    'n_estimators': (1, 100),
    'max_depth': (1, 100),
    'min_child_weight': (0.1, 0.999),
    'subsample': (0.1, 0.999),
    'reg_alpha': (1, 9),
    'reg_lambda': (5, 20),
}

# Instantiate the Bayesian optimizer. A fixed random_state makes its random
# sampling (e.g. the initial probe points) reproducible between runs.
rf_bo = BayesianOptimization(rf_cv, pbounds, random_state=1330)
# bayes_opt can only maximize its objective; to minimize a loss instead,
# have the objective return the negated loss.
rf_bo.maximize(init_points=5, n_iter=50)

# rf_bo.max holds the best probe: {'target': best_score, 'params': {...}}.
params = rf_bo.max["params"]
print(params)  # best hyperparameters found
# The target is rf_cv's mean cross-validated ROC AUC (scoring='roc_auc'),
# already computed during the search, so there is no need to re-run the CV.
val = rf_bo.max["target"]
print(val)

你可能感兴趣的:(python)