【机器学习实战】SGD 梯度下降

一、梯度下降简介

http://baike.baidu.com/link?url=JRP2bhxuJzeawEEEXgqNRavoZBlGxbm4QW0EEoPz1epO4DFoTRnvssRAK0i2P95yMg8EBa_jf3qU_9u1JnQ1k_


梯度下降法是一个最优化算法。是求解无约束优化问题最简单和最古老的方法之一。

例子:

f(x)=x^2的最小值

利用梯度下降的方法解题步骤:

  1. 求梯度;grad=2x ===> 梯度是求导(多元就是求偏导),几何意义就是某点变化最快的方向求最小值,就是梯度的相反方向移动

  2. x := x - step * grad ==> step为步长,就学习速率,如果步长足够下,保证了每一次迭代都在减小,但是收敛慢。如果步长太大,不能保证每一次迭代都减少,也不能保证收敛。可能在最优解出波动。

  3. 循环第2步.直到x的变化使得f(x)在两次的迭代之间的差值足够小。比如0.00000001,也就是说两次迭代f(x)基本没有变化。则说明f(x)已经达到了局部最小值。

package examples.mllib;

import java.text.DecimalFormat;

/**
 * @function ==> f(x) = x^2
 * @result => figure out the smallest of f(x)
 * @learning_rate => 0.1
 * @result_conditions => difference between every change is samller than
 *                    0.00000001
 * @author Harry Wu
 *
 */
public class JavaSGDExample {

	final static DecimalFormat df = new DecimalFormat("0.0000");
	
	static double func(double x) {
		return Math.pow(x, 2);
	}

	static double grad(double x) {
		return 2 * x;
	}

	public static void main(String[] args) {
		double step = 0.1;
		double x = 2;
		int k = 0;
		double change = func(x);
		double current = func(x);
		final double returnThreshold = 0.00000001;

		while (change > returnThreshold) {
			x = x - step * grad(x);
			change = current - func(x);
			current = func(x);
			k += 1;
		}
		
		System.out.println(df.format(x));
	}
}


二、多元

http://www.lxway.com/562680451.htm

package examples.mllib;

import java.util.ArrayList;

/**
 * spark-1.5.1-bin-hadoop2.6/data/mllib/ridge-data/lpsa.data
 * 
 * @author Harry Wu
 *
 */
public class JavaLinearRegressionWithSGDExample2 {

	static ArrayList<LabeledPoint> datas = new ArrayList<LabeledPoint>() {
		{
			add(new LabeledPoint(-0.4307829,
					new double[] { -1.63735562648104, -2.00621178480549, -1.86242597251066, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(-0.1625189,
					new double[] { -1.98898046126935, -0.722008756122123, -0.787896192088153, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(-0.1625189,
					new double[] { -1.57881887548545, -2.1887840293994, 1.36116336875686, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, 0.342627053981254, -0.155348103855541 }));
			add(new LabeledPoint(-0.1625189,
					new double[] { -2.16691708463163, -0.807993896938655, -0.787896192088153, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(0.3715636,
					new double[] { -0.507874475300631, -0.458834049396776, -0.250631301876899, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(0.7654678,
					new double[] { -2.03612849966376, -0.933954647105133, -1.86242597251066, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(0.8544153,
					new double[] { -0.557312518810673, -0.208756571683607, -0.787896192088153, 0.990146852537193,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.2669476,
					new double[] { -0.929360463147704, -0.0578991819441687, 0.152317365781542, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.2669476, new double[] { -2.28833047634983, -0.0706369432557794, -0.116315079324086,
					0.80409888772376, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.2669476, new double[] { 0.223498042876113, -1.41471935455355, -0.116315079324086,
					-1.02470580167082, -0.522940888712441, -0.29928234305568, 0.342627053981254, 0.199211097885341 }));
			add(new LabeledPoint(1.3480731,
					new double[] { 0.107785900236813, -1.47221551299731, 0.420949810887169, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, 0.342627053981254, -0.687186906466865 }));
			add(new LabeledPoint(1.446919,
					new double[] { 0.162180092313795, -1.32557369901905, 0.286633588334355, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.4701758, new double[] { -1.49795329918548, -0.263601072284232, 0.823898478545609,
					0.788388310173035, -0.522940888712441, -0.29928234305568, 0.342627053981254, 0.199211097885341 }));
			add(new LabeledPoint(1.4929041, new double[] { 0.796247055396743, 0.0476559407005752, 0.286633588334355,
					-1.02470580167082, -0.522940888712441, 0.394013435896129, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.5581446,
					new double[] { -1.62233848461465, -0.843294091975396, -3.07127197548598, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.5993876, new double[] { -0.990720665490831, 0.458513517212311, 0.823898478545609,
					1.07379746308195, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.6389967,
					new double[] { -0.171901281967138, -0.489197399065355, -0.65357996953534, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.6956156,
					new double[] { -1.60758252338831, -0.590700340358265, -0.65357996953534, -0.619561070667254,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.7137979, new double[] { 0.366273918511144, -0.414014962912583, -0.116315079324086,
					0.232904453212813, -0.522940888712441, 0.971228997418125, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(1.8000583, new double[] { -0.710307384579833, 0.211731938156277, 0.152317365781542,
					-1.02470580167082, -0.522940888712441, -0.442797990776478, 0.342627053981254, 1.61744790484887 }));
			add(new LabeledPoint(1.8484548, new double[] { -0.262791728113881, -1.16708345615721, 0.420949810887169,
					0.0846342590816532, -0.522940888712441, 0.163172393491611, 0.342627053981254, 1.97200710658975 }));
			add(new LabeledPoint(1.8946169, new double[] { 0.899043117369237, -0.590700340358265, 0.152317365781542,
					-1.02470580167082, -0.522940888712441, 1.28643254437683, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(1.9242487, new double[] { -0.903451690500615, 1.07659722048274, 0.152317365781542,
					1.28380453408541, -0.522940888712441, -0.442797990776478, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.008214, new double[] { -0.0633337899773081, -1.38088970920094, 0.958214701098423,
					0.80409888772376, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.0476928,
					new double[] { -1.15393789990757, -0.961853075398404, -0.116315079324086, -1.02470580167082,
							-0.522940888712441, -0.442797990776478, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.1575593, new double[] { 0.0620203721138446, 0.0657973885499142, 1.22684714620405,
					-0.468824786336838, -0.522940888712441, 1.31421001659859, 1.72741139715549, -0.332627704725983 }));
			add(new LabeledPoint(2.1916535,
					new double[] { -0.75731027755674, -2.92717970468456, 0.018001143228728, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, 0.342627053981254, -0.332627704725983 }));
			add(new LabeledPoint(2.2137539, new double[] { 1.11226993252773, 1.06484916245061, 0.555266033439982,
					0.877691038550889, 1.89254797819741, 1.43890404648442, 0.342627053981254, 0.376490698755783 }));
			add(new LabeledPoint(2.2772673,
					new double[] { -0.468768642850639, -1.43754788774533, -1.05652863719378, 0.576050411655607,
							-0.522940888712441, 0.0120483832567209, 0.342627053981254, -0.687186906466865 }));
			add(new LabeledPoint(2.2975726, new double[] { -0.618884859896728, -1.1366360750781, -0.519263746982526,
					-1.02470580167082, -0.522940888712441, -0.863171185425945, 3.11219574032972, 1.97200710658975 }));
			add(new LabeledPoint(2.3272777, new double[] { -0.651431999123483, 0.55329161145762, -0.250631301876899,
					1.11210019001038, -0.522940888712441, -0.179808625688859, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.5217206, new double[] { 0.115499102435224, -0.512233676577595, 0.286633588334355,
					1.13650173283446, -0.522940888712441, -0.179808625688859, 0.342627053981254, -0.155348103855541 }));
			add(new LabeledPoint(2.5533438,
					new double[] { 0.266341329949937, -0.551137885443386, -0.384947524429713, 0.354857790686005,
							-0.522940888712441, -0.863171185425945, 0.342627053981254, -0.332627704725983 }));
			add(new LabeledPoint(2.5687881, new double[] { 1.16902610257751, 0.855491905752846, 2.03274448152093,
					1.22628985326088, 1.89254797819741, 2.02833774827712, 3.11219574032972, 2.68112551007152 }));
			add(new LabeledPoint(2.6567569, new double[] { -0.218972367124187, 0.851192298581141, 0.555266033439982,
					-1.02470580167082, -0.522940888712441, -0.863171185425945, 0.342627053981254, 0.908329501367106 }));
			add(new LabeledPoint(2.677591, new double[] { 0.263121415733908, 1.4142681068416, 0.018001143228728,
					1.35980653053822, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.7180005, new double[] { -0.0704736333296423, 1.52000996595417, 0.286633588334355,
					1.39364261119802, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.332627704725983 }));
			add(new LabeledPoint(2.7942279,
					new double[] { -0.751957286017338, 0.316843561689933, -1.99674219506348, 0.911736065044475,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.8063861,
					new double[] { -0.685277652430997, 1.28214038482516, 0.823898478545609, 0.232904453212813,
							-0.522940888712441, -0.863171185425945, 0.342627053981254, -0.155348103855541 }));
			add(new LabeledPoint(2.8124102, new double[] { -0.244991501432929, 0.51882005949686, -0.384947524429713,
					0.823246560137838, -0.522940888712441, -0.863171185425945, 0.342627053981254, 0.553770299626224 }));
			add(new LabeledPoint(2.8419982, new double[] { -0.75731027755674, 2.09041984898851, 1.22684714620405,
					1.53428167116843, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.8535925, new double[] { 1.20962937075363, -0.242882661178889, 1.09253092365124,
					-1.02470580167082, -0.522940888712441, 1.24263233939889, 3.11219574032972, 2.50384590920108 }));
			add(new LabeledPoint(2.9204698, new double[] { 0.570886990493502, 0.58243883987948, 0.555266033439982,
					1.16006887775962, -0.522940888712441, 1.07357183940747, 0.342627053981254, 1.61744790484887 }));
			add(new LabeledPoint(2.9626924, new double[] { 0.719758684343624, 0.984970304132004, 1.09253092365124,
					1.52137230773457, -0.522940888712441, -0.179808625688859, 0.342627053981254, -0.509907305596424 }));
			add(new LabeledPoint(2.9626924,
					new double[] { -1.52406140158064, 1.81975700990333, 0.689582255992796, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(2.9729753, new double[] { -0.132431544081234, 2.68769877553723, 1.09253092365124,
					1.53428167116843, -0.522940888712441, -0.442797990776478, 0.342627053981254, -0.687186906466865 }));
			add(new LabeledPoint(3.0130809, new double[] { 0.436161292804989, -0.0834447307428255, -0.519263746982526,
					-1.02470580167082, 1.89254797819741, 1.07357183940747, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(3.0373539, new double[] { -0.161195191984091, -0.671900359186746, 1.7641120364153,
					1.13650173283446, -0.522940888712441, -0.863171185425945, 0.342627053981254, 0.0219314970149 }));
			add(new LabeledPoint(3.2752562, new double[] { 1.39927182372944, 0.513852869452676, 0.689582255992796,
					-1.02470580167082, 1.89254797819741, 1.49394503405693, 0.342627053981254, -0.155348103855541 }));
			add(new LabeledPoint(3.3375474, new double[] { 1.51967002306341, -0.852203755696565, 0.555266033439982,
					-0.104527297798983, 1.89254797819741, 1.85927724828569, 0.342627053981254, 0.908329501367106 }));
			add(new LabeledPoint(3.3928291, new double[] { 0.560725834706224, 1.87867703391426, 1.09253092365124,
					1.39364261119802, -0.522940888712441, 0.486423065822545, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(3.4355988, new double[] { 1.00765532502814, 1.69426310090641, 1.89842825896812,
					1.53428167116843, -0.522940888712441, -0.863171185425945, 0.342627053981254, -0.509907305596424 }));
			add(new LabeledPoint(3.4578927, new double[] { 1.10152996153577, -0.10927271844907, 0.689582255992796,
					-1.02470580167082, 1.89254797819741, 1.97630171771485, 0.342627053981254, 1.61744790484887 }));
			add(new LabeledPoint(3.5160131, new double[] { 0.100001934217311, -1.30380956369388, 0.286633588334355,
					0.316555063757567, -0.522940888712441, 0.28786643052924, 0.342627053981254, 0.553770299626224 }));
			add(new LabeledPoint(3.5307626, new double[] { 0.987291634724086, -0.36279314978779, -0.922212414640967,
					0.232904453212813, -0.522940888712441, 1.79270085261407, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(3.5652984, new double[] { 1.07158528137575, 0.606453149641961, 1.7641120364153,
					-0.432854616994416, 1.89254797819741, 0.528504607720369, 0.342627053981254, 0.199211097885341 }));
			add(new LabeledPoint(3.5876769, new double[] { 0.180156323255198, 0.188987436375017, -0.519263746982526,
					1.09956763075594, -0.522940888712441, 0.708239632330506, 0.342627053981254, 0.199211097885341 }));
			add(new LabeledPoint(3.6309855, new double[] { 1.65687973755377, -0.256675483533719, 0.018001143228728,
					-1.02470580167082, 1.89254797819741, 1.79270085261407, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(3.6800909, new double[] { 0.5720085322365, 0.239854450210939, -0.787896192088153,
					1.0605418233138, -0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(3.7123518, new double[] { 0.323806133438225, -0.606717660886078, -0.250631301876899,
					-1.02470580167082, 1.89254797819741, 0.342907418101747, 0.342627053981254, 0.199211097885341 }));
			add(new LabeledPoint(3.9843437, new double[] { 1.23668206715898, 2.54220539083611, 0.152317365781542,
					-1.02470580167082, 1.89254797819741, 1.89037692416194, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(3.993603, new double[] { 0.180156323255198, 0.154448192444669, 1.62979581386249,
					0.576050411655607, 1.89254797819741, 0.708239632330506, 0.342627053981254, 1.79472750571931 }));
			add(new LabeledPoint(4.029806,
					new double[] { 1.60906277046565, 1.10378605019827, 0.555266033439982, -1.02470580167082,
							-0.522940888712441, -0.863171185425945, -1.04215728919298, -0.864466507337306 }));
			add(new LabeledPoint(4.1295508, new double[] { 1.0036214996026, 0.113496885050331, -0.384947524429713,
					0.860016436332751, 1.89254797819741, -0.863171185425945, 0.342627053981254, -0.332627704725983 }));
			add(new LabeledPoint(4.3851468, new double[] { 1.25591974271076, 0.577607033774471, 0.555266033439982,
					-1.02470580167082, 1.89254797819741, 1.07357183940747, 0.342627053981254, 1.26288870310799 }));
			add(new LabeledPoint(4.6844434, new double[] { 2.09650591351268, 0.625488598331018, -2.66832330782754,
					-1.02470580167082, 1.89254797819741, 1.67954222367555, 0.342627053981254, 0.553770299626224 }));
			add(new LabeledPoint(5.477509, new double[] { 1.30028987435881, 0.338383613253713, 0.555266033439982,
					1.00481276295349, 1.89254797819741, 1.24263233939889, 0.342627053981254, 1.97200710658975 }));
		}
	};

	static double[] thetas = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
	final static int numIterations = 1000;
	final static double stepSize = 0.0000001;
	
	public static void stochasticSGD(){
		for (int i = 0; i < numIterations; i++) {
			for (LabeledPoint point : datas) {
				for (int j = 0; j < thetas.length; j++) {
					double total = 0.0;
					for (int k = 0; k < thetas.length; k++) {
						total = thetas[k] * point.getFeatures()[k];
					}
					double diff = total - point.getLabel();
					thetas[j] = thetas[j] - stepSize * diff * point.getFeatures()[j];
				}
			}
		}
		
		
		for(double theta : thetas){
			System.out.println(theta);
		}
		
		
		double total = 0.0;
		for(LabeledPoint point : datas){
			double sum = 0.0;
			for(int i = 0;i < point.getFeatures().length;i++){
				sum += thetas[i] * point.getFeatures()[i];
			}
			
			total += Math.pow(sum - point.getLabel(), 2);
		}
		
		System.out.println(total / datas.size());
	}
	
	public static void batchSGD(){
		for (int i = 0; i < numIterations; i++) {
			for (LabeledPoint point : datas) {
				for (int j = 0; j < thetas.length; j++) {
					double total = 0.0;
					for (LabeledPoint point2 : datas) {
						for (int k = 0; k < thetas.length; k++) {
							total = thetas[k] * point2.getFeatures()[k];
						}
					}
					
					double diff = total - point.getLabel();
					thetas[j] = thetas[j] - stepSize * diff * point.getFeatures()[j];
				}
			}
		}
		
		
		for(double theta : thetas){
			System.out.println(theta);
		}
		
		
		double total = 0.0;
		for(LabeledPoint point : datas){
			double sum = 0.0;
			for(int i = 0;i < point.getFeatures().length;i++){
				sum += thetas[i] * point.getFeatures()[i];
			}
			
			total += Math.pow(sum - point.getLabel(), 2);
		}
		
		System.out.println(total / datas.size());
	}

	public static void main(String[] args) {
		stochasticSGD();
		System.out.println();
		batchSGD();
	}
}

class LabeledPoint {
	private double label;
	private double[] features;

	public LabeledPoint(double label, double[] features) {
		this.label = label;
		this.features = features;
	}

	public double getLabel() {
		return label;
	}

	public void setLabel(double label) {
		this.label = label;
	}

	public double[] getFeatures() {
		return features;
	}

	public void setFeatures(double[] features) {
		this.features = features;
	}
}


你可能感兴趣的:(【机器学习实战】SGD 梯度下降)