一、梯度下降简介
http://baike.baidu.com/link?url=JRP2bhxuJzeawEEEXgqNRavoZBlGxbm4QW0EEoPz1epO4DFoTRnvssRAK0i2P95yMg8EBa_jf3qU_9u1JnQ1k_
梯度下降法是一个最优化算法。是求解无约束优化问题最简单和最古老的方法之一。
例子:
求 f(x)=x^2 的最小值。
利用梯度下降的方法解题步骤:
1. 求梯度：grad = 2x。梯度即导数（多元函数则是对各变量分别求偏导得到的向量），其几何意义是函数在该点处增长最快的方向；求最小值时，就沿梯度的相反方向移动。
2. 更新：x := x - step * grad。其中 step 为步长，即学习速率：步长足够小时，能保证每一次迭代 f(x) 都在减小，但收敛慢；步长太大时，不能保证每一次迭代都减小，也不能保证收敛，可能在最优解处来回波动。
3. 循环第 2 步，直到 f(x) 在相邻两次迭代之间的差值（取绝对值）足够小，比如 0.00000001，也就是说两次迭代之间 f(x) 基本没有变化，则说明 f(x) 已经达到了（局部）最小值。
package examples.mllib; import java.text.DecimalFormat; /** * @function ==> f(x) = x^2 * @result => figure out the smallest of f(x) * @learning_rate => 0.1 * @result_conditions => difference between every change is samller than * 0.00000001 * @author Harry Wu * */ public class JavaSGDExample { final static DecimalFormat df = new DecimalFormat("0.0000"); static double func(double x) { return Math.pow(x, 2); } static double grad(double x) { return 2 * x; } public static void main(String[] args) { double step = 0.1; double x = 2; int k = 0; double change = func(x); double current = func(x); final double returnThreshold = 0.00000001; while (change > returnThreshold) { x = x - step * grad(x); change = current - func(x); current = func(x); k += 1; } System.out.println(df.format(x)); } }
二、多元线性回归
http://www.lxway.com/562680451.htm
package examples.mllib; import java.util.ArrayList; /** * spark-1.5.1-bin-hadoop2.6/data/mllib/ridge-data/lpsa.data * * @author Harry Wu * */ public class JavaLinearRegressionWithSGDExample2 { static ArrayList<LabeledPoint> datas = new ArrayList<LabeledPoint>() { { add(new LabeledPoint(-0.4307829, new double[] { -1.63735562648104, -2.00621178480549, -1.86242597251066, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(-0.1625189, new double[] { -1.98898046126935, -0.722008756122123, -0.787896192088153, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(-0.1625189, new double[] { -1.57881887548545, -2.1887840293994, 1.36116336875686, -1.02470580167082, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.155348103855541 })); add(new LabeledPoint(-0.1625189, new double[] { -2.16691708463163, -0.807993896938655, -0.787896192088153, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(0.3715636, new double[] { -0.507874475300631, -0.458834049396776, -0.250631301876899, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(0.7654678, new double[] { -2.03612849966376, -0.933954647105133, -1.86242597251066, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(0.8544153, new double[] { -0.557312518810673, -0.208756571683607, -0.787896192088153, 0.990146852537193, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.2669476, new double[] { -0.929360463147704, -0.0578991819441687, 0.152317365781542, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.2669476, new double[] { -2.28833047634983, 
-0.0706369432557794, -0.116315079324086, 0.80409888772376, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.2669476, new double[] { 0.223498042876113, -1.41471935455355, -0.116315079324086, -1.02470580167082, -0.522940888712441, -0.29928234305568, 0.342627053981254, 0.199211097885341 })); add(new LabeledPoint(1.3480731, new double[] { 0.107785900236813, -1.47221551299731, 0.420949810887169, -1.02470580167082, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.687186906466865 })); add(new LabeledPoint(1.446919, new double[] { 0.162180092313795, -1.32557369901905, 0.286633588334355, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.4701758, new double[] { -1.49795329918548, -0.263601072284232, 0.823898478545609, 0.788388310173035, -0.522940888712441, -0.29928234305568, 0.342627053981254, 0.199211097885341 })); add(new LabeledPoint(1.4929041, new double[] { 0.796247055396743, 0.0476559407005752, 0.286633588334355, -1.02470580167082, -0.522940888712441, 0.394013435896129, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.5581446, new double[] { -1.62233848461465, -0.843294091975396, -3.07127197548598, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.5993876, new double[] { -0.990720665490831, 0.458513517212311, 0.823898478545609, 1.07379746308195, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.6389967, new double[] { -0.171901281967138, -0.489197399065355, -0.65357996953534, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.6956156, new double[] { -1.60758252338831, -0.590700340358265, -0.65357996953534, -0.619561070667254, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); 
add(new LabeledPoint(1.7137979, new double[] { 0.366273918511144, -0.414014962912583, -0.116315079324086, 0.232904453212813, -0.522940888712441, 0.971228997418125, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(1.8000583, new double[] { -0.710307384579833, 0.211731938156277, 0.152317365781542, -1.02470580167082, -0.522940888712441, -0.442797990776478, 0.342627053981254, 1.61744790484887 })); add(new LabeledPoint(1.8484548, new double[] { -0.262791728113881, -1.16708345615721, 0.420949810887169, 0.0846342590816532, -0.522940888712441, 0.163172393491611, 0.342627053981254, 1.97200710658975 })); add(new LabeledPoint(1.8946169, new double[] { 0.899043117369237, -0.590700340358265, 0.152317365781542, -1.02470580167082, -0.522940888712441, 1.28643254437683, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(1.9242487, new double[] { -0.903451690500615, 1.07659722048274, 0.152317365781542, 1.28380453408541, -0.522940888712441, -0.442797990776478, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.008214, new double[] { -0.0633337899773081, -1.38088970920094, 0.958214701098423, 0.80409888772376, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.0476928, new double[] { -1.15393789990757, -0.961853075398404, -0.116315079324086, -1.02470580167082, -0.522940888712441, -0.442797990776478, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.1575593, new double[] { 0.0620203721138446, 0.0657973885499142, 1.22684714620405, -0.468824786336838, -0.522940888712441, 1.31421001659859, 1.72741139715549, -0.332627704725983 })); add(new LabeledPoint(2.1916535, new double[] { -0.75731027755674, -2.92717970468456, 0.018001143228728, -1.02470580167082, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.332627704725983 })); add(new LabeledPoint(2.2137539, new double[] { 1.11226993252773, 1.06484916245061, 0.555266033439982, 0.877691038550889, 1.89254797819741, 
1.43890404648442, 0.342627053981254, 0.376490698755783 })); add(new LabeledPoint(2.2772673, new double[] { -0.468768642850639, -1.43754788774533, -1.05652863719378, 0.576050411655607, -0.522940888712441, 0.0120483832567209, 0.342627053981254, -0.687186906466865 })); add(new LabeledPoint(2.2975726, new double[] { -0.618884859896728, -1.1366360750781, -0.519263746982526, -1.02470580167082, -0.522940888712441, -0.863171185425945, 3.11219574032972, 1.97200710658975 })); add(new LabeledPoint(2.3272777, new double[] { -0.651431999123483, 0.55329161145762, -0.250631301876899, 1.11210019001038, -0.522940888712441, -0.179808625688859, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.5217206, new double[] { 0.115499102435224, -0.512233676577595, 0.286633588334355, 1.13650173283446, -0.522940888712441, -0.179808625688859, 0.342627053981254, -0.155348103855541 })); add(new LabeledPoint(2.5533438, new double[] { 0.266341329949937, -0.551137885443386, -0.384947524429713, 0.354857790686005, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.332627704725983 })); add(new LabeledPoint(2.5687881, new double[] { 1.16902610257751, 0.855491905752846, 2.03274448152093, 1.22628985326088, 1.89254797819741, 2.02833774827712, 3.11219574032972, 2.68112551007152 })); add(new LabeledPoint(2.6567569, new double[] { -0.218972367124187, 0.851192298581141, 0.555266033439982, -1.02470580167082, -0.522940888712441, -0.863171185425945, 0.342627053981254, 0.908329501367106 })); add(new LabeledPoint(2.677591, new double[] { 0.263121415733908, 1.4142681068416, 0.018001143228728, 1.35980653053822, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.7180005, new double[] { -0.0704736333296423, 1.52000996595417, 0.286633588334355, 1.39364261119802, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.332627704725983 })); add(new LabeledPoint(2.7942279, new double[] { -0.751957286017338, 0.316843561689933, 
-1.99674219506348, 0.911736065044475, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.8063861, new double[] { -0.685277652430997, 1.28214038482516, 0.823898478545609, 0.232904453212813, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.155348103855541 })); add(new LabeledPoint(2.8124102, new double[] { -0.244991501432929, 0.51882005949686, -0.384947524429713, 0.823246560137838, -0.522940888712441, -0.863171185425945, 0.342627053981254, 0.553770299626224 })); add(new LabeledPoint(2.8419982, new double[] { -0.75731027755674, 2.09041984898851, 1.22684714620405, 1.53428167116843, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.8535925, new double[] { 1.20962937075363, -0.242882661178889, 1.09253092365124, -1.02470580167082, -0.522940888712441, 1.24263233939889, 3.11219574032972, 2.50384590920108 })); add(new LabeledPoint(2.9204698, new double[] { 0.570886990493502, 0.58243883987948, 0.555266033439982, 1.16006887775962, -0.522940888712441, 1.07357183940747, 0.342627053981254, 1.61744790484887 })); add(new LabeledPoint(2.9626924, new double[] { 0.719758684343624, 0.984970304132004, 1.09253092365124, 1.52137230773457, -0.522940888712441, -0.179808625688859, 0.342627053981254, -0.509907305596424 })); add(new LabeledPoint(2.9626924, new double[] { -1.52406140158064, 1.81975700990333, 0.689582255992796, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(2.9729753, new double[] { -0.132431544081234, 2.68769877553723, 1.09253092365124, 1.53428167116843, -0.522940888712441, -0.442797990776478, 0.342627053981254, -0.687186906466865 })); add(new LabeledPoint(3.0130809, new double[] { 0.436161292804989, -0.0834447307428255, -0.519263746982526, -1.02470580167082, 1.89254797819741, 1.07357183940747, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(3.0373539, new double[] { 
-0.161195191984091, -0.671900359186746, 1.7641120364153, 1.13650173283446, -0.522940888712441, -0.863171185425945, 0.342627053981254, 0.0219314970149 })); add(new LabeledPoint(3.2752562, new double[] { 1.39927182372944, 0.513852869452676, 0.689582255992796, -1.02470580167082, 1.89254797819741, 1.49394503405693, 0.342627053981254, -0.155348103855541 })); add(new LabeledPoint(3.3375474, new double[] { 1.51967002306341, -0.852203755696565, 0.555266033439982, -0.104527297798983, 1.89254797819741, 1.85927724828569, 0.342627053981254, 0.908329501367106 })); add(new LabeledPoint(3.3928291, new double[] { 0.560725834706224, 1.87867703391426, 1.09253092365124, 1.39364261119802, -0.522940888712441, 0.486423065822545, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(3.4355988, new double[] { 1.00765532502814, 1.69426310090641, 1.89842825896812, 1.53428167116843, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.509907305596424 })); add(new LabeledPoint(3.4578927, new double[] { 1.10152996153577, -0.10927271844907, 0.689582255992796, -1.02470580167082, 1.89254797819741, 1.97630171771485, 0.342627053981254, 1.61744790484887 })); add(new LabeledPoint(3.5160131, new double[] { 0.100001934217311, -1.30380956369388, 0.286633588334355, 0.316555063757567, -0.522940888712441, 0.28786643052924, 0.342627053981254, 0.553770299626224 })); add(new LabeledPoint(3.5307626, new double[] { 0.987291634724086, -0.36279314978779, -0.922212414640967, 0.232904453212813, -0.522940888712441, 1.79270085261407, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(3.5652984, new double[] { 1.07158528137575, 0.606453149641961, 1.7641120364153, -0.432854616994416, 1.89254797819741, 0.528504607720369, 0.342627053981254, 0.199211097885341 })); add(new LabeledPoint(3.5876769, new double[] { 0.180156323255198, 0.188987436375017, -0.519263746982526, 1.09956763075594, -0.522940888712441, 0.708239632330506, 0.342627053981254, 0.199211097885341 })); add(new 
LabeledPoint(3.6309855, new double[] { 1.65687973755377, -0.256675483533719, 0.018001143228728, -1.02470580167082, 1.89254797819741, 1.79270085261407, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(3.6800909, new double[] { 0.5720085322365, 0.239854450210939, -0.787896192088153, 1.0605418233138, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(3.7123518, new double[] { 0.323806133438225, -0.606717660886078, -0.250631301876899, -1.02470580167082, 1.89254797819741, 0.342907418101747, 0.342627053981254, 0.199211097885341 })); add(new LabeledPoint(3.9843437, new double[] { 1.23668206715898, 2.54220539083611, 0.152317365781542, -1.02470580167082, 1.89254797819741, 1.89037692416194, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(3.993603, new double[] { 0.180156323255198, 0.154448192444669, 1.62979581386249, 0.576050411655607, 1.89254797819741, 0.708239632330506, 0.342627053981254, 1.79472750571931 })); add(new LabeledPoint(4.029806, new double[] { 1.60906277046565, 1.10378605019827, 0.555266033439982, -1.02470580167082, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 })); add(new LabeledPoint(4.1295508, new double[] { 1.0036214996026, 0.113496885050331, -0.384947524429713, 0.860016436332751, 1.89254797819741, -0.863171185425945, 0.342627053981254, -0.332627704725983 })); add(new LabeledPoint(4.3851468, new double[] { 1.25591974271076, 0.577607033774471, 0.555266033439982, -1.02470580167082, 1.89254797819741, 1.07357183940747, 0.342627053981254, 1.26288870310799 })); add(new LabeledPoint(4.6844434, new double[] { 2.09650591351268, 0.625488598331018, -2.66832330782754, -1.02470580167082, 1.89254797819741, 1.67954222367555, 0.342627053981254, 0.553770299626224 })); add(new LabeledPoint(5.477509, new double[] { 1.30028987435881, 0.338383613253713, 0.555266033439982, 1.00481276295349, 1.89254797819741, 1.24263233939889, 0.342627053981254, 1.97200710658975 })); } 
}; static double[] thetas = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }; final static int numIterations = 1000; final static double stepSize = 0.0000001; public static void stochasticSGD(){ for (int i = 0; i < numIterations; i++) { for (LabeledPoint point : datas) { for (int j = 0; j < thetas.length; j++) { double total = 0.0; for (int k = 0; k < thetas.length; k++) { total = thetas[k] * point.getFeatures()[k]; } double diff = total - point.getLabel(); thetas[j] = thetas[j] - stepSize * diff * point.getFeatures()[j]; } } } for(double theta : thetas){ System.out.println(theta); } double total = 0.0; for(LabeledPoint point : datas){ double sum = 0.0; for(int i = 0;i < point.getFeatures().length;i++){ sum += thetas[i] * point.getFeatures()[i]; } total += Math.pow(sum - point.getLabel(), 2); } System.out.println(total / datas.size()); } public static void batchSGD(){ for (int i = 0; i < numIterations; i++) { for (LabeledPoint point : datas) { for (int j = 0; j < thetas.length; j++) { double total = 0.0; for (LabeledPoint point2 : datas) { for (int k = 0; k < thetas.length; k++) { total = thetas[k] * point2.getFeatures()[k]; } } double diff = total - point.getLabel(); thetas[j] = thetas[j] - stepSize * diff * point.getFeatures()[j]; } } } for(double theta : thetas){ System.out.println(theta); } double total = 0.0; for(LabeledPoint point : datas){ double sum = 0.0; for(int i = 0;i < point.getFeatures().length;i++){ sum += thetas[i] * point.getFeatures()[i]; } total += Math.pow(sum - point.getLabel(), 2); } System.out.println(total / datas.size()); } public static void main(String[] args) { stochasticSGD(); System.out.println(); batchSGD(); } } class LabeledPoint { private double label; private double[] features; public LabeledPoint(double label, double[] features) { this.label = label; this.features = features; } public double getLabel() { return label; } public void setLabel(double label) { this.label = label; } public double[] getFeatures() { return features; } public 
void setFeatures(double[] features) { this.features = features; } }