最近在做关于数据的一些东西,就是研究游戏在线人数的变换曲线,看了梯度下降。
梯度下降的介绍百度一下说的都很好,我也是百度学习的 可以@refer 这篇 http://blog.csdn.net/woxincd/article/details/7040944。
我的理解就是先随意找一个点,然后求出这个点在各个方向的切线,顺着这个方向认为是下降最快的方向,然后根据步长(自己设定的一个值,这里是0.1)调节走的速度。最终找到极值。关键在于迭代,关于迭代,牛顿方法的迭代会更迅速。
下面是我用php写的一个小demo。
预期多元一次方程是 ax1 + bx2 = y , 给定了N组 x 和对应y的值 ,求 a b 分别是多少。
先转换成求极值 => 误差最小。
$dataset = [[1,4],[2,5],[5,1],[4,2]]; //初始的三组x(每组x包括x1 x2)
$dataret = [19,26, 19, 20]; // 对应三组x 的y 值
$expect = [10, 10]; //随意找到的开始点 这里是指 预测 a = 10 b = 10
$step = 0.001; //步长
$times = 1000000;
/*
*梯度下降 *
* @auther menmei
* @date 2017/03/24
*
*/
/*
* 梯度下降求多元一次方程的多元参数
*
*
* @param 原始数据 Array
* @param 原始数据结果 Array
* @param 初始参数 Arrayθ
* @param 步长 double
* @param 循环次数 int
*
* @return 参数数组 Arrayθ
*
*/
function gradientDescent($dataset, $dataret, $expect, $step, $times){
//check given params
$setTotal = count($dataset);
$paramsTotal = count($dataset[0]);
if($setTotal < 2 || (count($expect) != $paramsTotal) || count($dataret) != $setTotal ) return False;
//$deviation = array_fill(0, $paramsTotal, 0);
for($i = 0; $i < $times; $i ++){
$h = 0;
$index = $i % $setTotal;
for($j = 0; $j < $paramsTotal; $j ++){
$h += ($expect[$j] * $dataset[$index][$j]);
}
$error = $h - $dataret[$index];
for($k = 0; $k < $paramsTotal; $k ++){
//这里是关键 这里 $error * $dataset[$index][$k] 是J(θ) 按梯度方向减少的量
//是对J(θ) 求偏导得到的 => 按梯度每个方向的斜率 * 步长
$expect[$k] -= $step * $error * $dataset[$index][$k];
}
//calculate new deviation
$deviation = 0 ;
for($l = 0; $l < $setTotal; $l ++){
$h = 0;
for($m = 0; $m < $paramsTotal; $m ++){
$h += ($expect[$m] * $dataset[$l][$m]);
}
$deviation += ($h - $dataret[$l]) * ($h - $dataret[$l]);
}
if($deviation < 0.001) break;
}
echo "误差是{$deviation}";
return $expect;
}
/***************** EXAMPLE ******************/
//sample 1
//这里步长设置成 0.1 就会越过最低点 然后继续向上。所以步长选择很 !重 !要 !
$dataset = [[1,4],[2,5],[5,1],[4,2]];
$dataret = [19,26, 19, 20];
$expect = [10, 10];
$step = 0.001;
$times = 1000000;
//sample 2
$dataset = [[1, 1, 2], [1, 2, 3], [1, 2, 5], [1, 8, 3], [1, 4, 7]];
$dataret = [13, 19, 27, 31, 39];
$expect = [0, 0, 0];
//sample 3
$dataset = [[1, 2], [2, 3], [2, 5], [8, 3], [4, 7]];
$dataret = [1, 4, 0, 34, 6];
$expect = [0,0];
$ret = gradientDescent($dataset, $dataret, $expect, $step, $times);
var_dump($ret);