机器学习模型交叉验证脚本

机器学习模型交叉验证脚本

本文以阿里云机器学习平台上的 ps_smart (GBDT)算法为例,提供一个搜索最佳超参数的交叉验证任务的bash脚本。

机器学习模型超参数网格搜索脚本 提供了超参数网格搜索的能力。然而,当验证集的数量较少时,网格搜索的最优超参数非常容易过拟合,在实际的生产环境中,往往效果不如预期。为了缓解数据量少的问题,我们把网格搜索的Top N最优超参数保存下来,对这组超参数继续使用交叉验证的方式评估每组超参数对应的模型的实现效果指标。

本文提供的示例是一个LTV预测的回归任务,计算MAE、RMSE、WAPE 三个评估指标。

#!/bin/bash
#set -x
odps='.odpscmd/bin/odpscmd --config=odps_config.ini'
hyper_params_file='hyper_params.txt'

function log_info()
{
    if [ "$LOG_LEVEL" != "WARN" ] && [ "$LOG_LEVEL" != "ERROR" ]
    then
        echo "`date +"%Y-%m-%d %H:%M:%S"` [INFO] ($$)($USER): $*";
    fi
}

function prepare()
{
    log_info "function [$FUNCNAME] begin"
    if [ ! -d ".odpscmd" ]; then
        wget https://odps-repo.oss-cn-hangzhou.aliyuncs.com/odpscmd/latest/odpscmd_public.zip
        unzip -d .odpscmd odpscmd_public.zip
    fi
    log_info "function [$FUNCNAME] end"
}

function gen_partition() {
    log_info "function [$FUNCNAME] begin"
    local n=$1
    local k=$2
    local i
    pt=""
    for ((i=0;i<$n;i++))
    do
        if [ "$i" -eq "$k" ]; then
            continue
        fi
        pt=${pt}",'"${i}"'"
    done
    exclude_pt=${pt#,}
    log_info "function [$FUNCNAME] end"
}

function prepare_cv_data() {
    log_info "function [$FUNCNAME] begin"
    $odps -e "CREATE TABLE IF NOT EXISTS ps_smart_ltv
    (
        mae DOUBLE,
        rmse DOUBLE,
        wape DOUBLE
    )
    PARTITIONED BY (pt STRING COMMENT '实验参数', k STRING);"


    $odps -e "CREATE TABLE IF NOT EXISTS userfeature_v2_googleplay_mergekv_freedom_day3_dataset
    (
        dt  STRING,
        uid STRING,
        kv  STRING,
        targetprice DOUBLE,
        ispay BIGINT
    )
    COMMENT '训练数据集'
    PARTITIONED BY (pt STRING COMMENT '分区')
    LIFECYCLE 7;"


    local n=10
    $odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt)
    SELECT *
    FROM (
        SELECT dt,uid,kv,targetprice,ispay, FLOOR(rand() * ${n}) as pt
        FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_train_20220905_jp_m1
        UNION ALL
        SELECT dt,uid,replace(kv,',',' ') kv,targetprice,ispay, FLOOR(rand(20220826) * ${n}) as pt
        FROM rg_ai_bj.tmp_userfeature_v2_googleplay_mergekv_freedom_day3_test_20220905_jp_m1
    ) T;"


    local k
    for ((k=0;k<${n};k++))
    do
    {
        gen_partition $n $k
        $odps -e "INSERT OVERWRITE TABLE userfeature_v2_googleplay_mergekv_freedom_day3_dataset PARTITION(pt='exclude_${k}')
        SELECT \`(pt)?+.+\`
        FROM userfeature_v2_googleplay_mergekv_freedom_day3_dataset
        WHERE pt IN (${exclude_pt});"

    } &
    done
    wait
    log_info "function [$FUNCNAME] end"
}

function run_job() {
    log_info "function [$FUNCNAME] begin"
    local k_fold=$1
    local tree_count=$2
    local max_depth=$3
    local l1=$4
    local l2=$5
    local lr=$6
    local eps=$7
    local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
    log_info "run model: $model, k_fold: ${k_fold}"

    $odps -e "PAI -name ps_smart
    -project algo_public
    -DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
    -DinputTablePartitions='pt=exclude_${k_fold}'
    -DmodelName='smart_${k_fold}_${model}'
    -DoutputTableName='smart_table_${k_fold}_${model}'
    -DoutputImportanceTableName='smart_imp_${k_fold}_${model}'
    -DlabelColName='targetprice'
    -DfeatureColNames='kv'
    -DenableSparse='true'
    -Dobjective='reg:tweedie'
    -Dmetric='tweedie-nloglik'
    -DfeatureImportanceType='gain'
    -DtreeCount='${tree_count}'
    -DmaxDepth='${max_depth}'
    -Dshrinkage='${lr}'
    -Dl2='${l2}'
    -Dl1='${l1}'
    -Dlifecycle='31'
    -DsketchEps='${eps}'
    -DsampleRatio='1.0'
    -DfeatureRatio='1.0'
    -DbaseScore='0.0'
    -DminSplitLoss='0'
    "

    if [ $? -ne 0 ]; then
        return $?
    fi

    $odps -e "drop table if exists smart_output_${k_fold}_${model};"
    $odps -e "PAI -name prediction
    -project algo_public
    -DinputTableName='userfeature_v2_googleplay_mergekv_freedom_day3_dataset'
    -DinputTablePartitions='pt=${k_fold}'
    -DmodelName='smart_${k_fold}_${model}'
    -DoutputTableName='smart_output_${k_fold}_${model}'
    -DfeatureColNames='kv'
    -DappendColNames='targetprice'
    -DenableSparse='true'
    -DitemDelimiter=' '
    -Dlifecycle='128'
    "

    if [ $? -ne 0 ]; then
        return $?
    fi
    
    $odps -e "INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='${k_fold}')
    SELECT AVG(ABS(targetprice-prediction_result)) MAE,
        SQRT(AVG((targetprice-prediction_result)*(targetprice-prediction_result))) RMSE,
        SUM(ABS(targetprice-prediction_result))/SUM(ABS(targetprice)) WAPE
    FROM smart_output_${k_fold}_${model};"

    log_info "function [$FUNCNAME] end"
}


function run_cross_validation()
{
    log_info "function [$FUNCNAME] begin"
    local args=$@
    local tree_count=$1
    local max_depth=$2
    local l1=$3
    local l2=$4
    local lr=$5
    local eps=$6
    local model=${tree_count}_${max_depth}_${l1/0./p}_${l2/0./p}_${lr/0./p}_${eps/0./p}
 
    local n=10
    local i 
    for ((i=0;i<$n;i++))
    do
    {
        run_job ${i} $args  
    } &
    done
    wait


    $odps -e "
    INSERT OVERWRITE TABLE ps_smart_ltv PARTITION(pt='${model}', k='mean')
    select avg(MAE), avg(RMSE), avg(WAPE)
    from ps_smart_ltv
    where pt='${model}' and k!='mean';
    "

    log_info "function [$FUNCNAME] end"
}

function run_from_file()
{
    log_info "function [$FUNCNAME] begin"
    threadTask=1 #并发数
    fifoFile="test_fifo"
    rm -f ${fifoFile}
    mkfifo ${fifoFile}  #创建fifo管道
    exec 9<> ${fifoFile}
    rm -f ${fifoFile}
    # 预先向管道写入数据
    for ((i=0;i<${threadTask};i++))
    do
        echo "" >&9
    done
    
    log_info "wait all task finish,then exit!!!"
    while read line
    do
        read -u9
        {
            run_cross_validation $line
            echo "" >&9
        } &
    done < $1
    wait

    exec 9<&-  # 关闭文件描述符的读
    exec 9>&-  # 关闭文件描述符的写
    log_info "function [$FUNCNAME] end"
}

prepare
prepare_cv_data
run_from_file ${hyper_params_file}
#run_from_file $1

备注:请结合机器学习模型超参数网格搜索脚本使用,网格搜索的Top N最优超参数需要预先保存到hyper_params.txt文件中。

本文由 mdnice 多平台发布

你可能感兴趣的:(机器学习)