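Hello everyone. This post shares my solutions to the programming exercise for week 6 of Stanford's Machine Learning course on Coursera (taught by Andrew Ng): Regularized Linear Regression and Bias vs. Variance. This week's lectures cover model selection, splitting the samples into training, validation, and test sets, bias and variance, choosing the regularization parameter, and plotting learning curves. Below are my code and some personal notes. If you spot any mistakes, please leave a comment. Thanks.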
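1. First, we need to complete linearRegCostFunction.m. Using a linear model, this function computes the cost J and the gradient grad for the samples X, y under the given parameters theta and lambda.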
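For reference, the regularized cost and gradient that this function computes can be written out as below. The notation follows the usual course convention and is an assumption on my part: m is the number of training examples, n the number of features, x_0 = 1 is the bias feature, and theta_0 corresponds to theta(1) in MATLAB.

```latex
\begin{aligned}
J(\theta) &= \frac{1}{2m}\sum_{i=1}^{m}\left(\theta^{T}x^{(i)} - y^{(i)}\right)^{2}
           + \frac{\lambda}{2m}\sum_{j=1}^{n}\theta_{j}^{2} \\
\frac{\partial J}{\partial \theta_{0}} &= \frac{1}{m}\sum_{i=1}^{m}\left(\theta^{T}x^{(i)} - y^{(i)}\right)x_{0}^{(i)} \\
\frac{\partial J}{\partial \theta_{j}} &= \frac{1}{m}\sum_{i=1}^{m}\left(\theta^{T}x^{(i)} - y^{(i)}\right)x_{j}^{(i)}
           + \frac{\lambda}{m}\theta_{j}, \qquad j \ge 1
\end{aligned}
```

The penalty sum starts at j = 1, which is why the code subtracts theta(1)^2 from the cost and removes the lambda / m * theta(1) term from the first gradient entry.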
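function [J, grad] = linearRegCostFunction(X, y, theta, lambda)

m = length(y); % number of training examples

% You need to return the following variables correctly
J = 0;
grad = zeros(size(theta));

% Cost of regularized linear regression; theta(1) (theta0) is left out of the penalty
J = sum((X * theta - y).^2) / (2 * m) + lambda / (2 * m) * (sum(theta.^2) - theta(1)^2);

% Gradient of regularized linear regression
grad = ((X * theta - y)' * X)' / m + lambda / m * theta;

% theta0 is not penalized, so subtract its regularization term again
% (theta0 has index 1 in a MATLAB array)
grad(1) = grad(1) - lambda / m * theta(1);

grad = grad(:);

end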

2. Next, learningCurve.m computes the training error and the validation error as the training set grows.

function [error_train, error_val] = ...
    learningCurve(X, y, Xval, yval, lambda)

m = size(X, 1);

% You need to return these values correctly
error_train = zeros(m, 1);
error_val   = zeros(m, 1);

% We want to plot how error_train and error_val change as the training set grows,
% so the number of training examples i increases from 1 to m.
for i = 1:m
    % Only i training examples are available, so train theta on the first i rows of X, y
    theta = trainLinearReg(X(1:i,:), y(1:i,:), lambda);
    % Training error on those same i examples, hence the divisor 2 * i
    error_train(i) = sum((X(1:i,:) * theta - y(1:i,:)).^2) / (2 * i);
    % The validation set does not change, so the whole set is used every time
    error_val(i) = sum((Xval * theta - yval).^2) / (2 * size(Xval,1));
end

end
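
The two quantities computed in the learning-curve loop can be summarized as follows. The symbols are my own shorthand rather than anything defined in the exercise files: theta^(i) denotes the parameters trained on the first i examples with the given lambda, and m_val is the number of validation examples.

```latex
\begin{aligned}
\text{error\_train}(i) &= \frac{1}{2i}\sum_{k=1}^{i}\left(\theta^{(i)T}x^{(k)} - y^{(k)}\right)^{2} \\
\text{error\_val}(i) &= \frac{1}{2m_{\text{val}}}\sum_{k=1}^{m_{\text{val}}}
    \left(\theta^{(i)T}x_{\text{val}}^{(k)} - y_{\text{val}}^{(k)}\right)^{2}
\end{aligned}
```

Neither error contains a lambda term: lambda only influences how theta^(i) is trained, while the errors themselves are plain squared errors.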

3. polyFeatures.m maps the original feature X to its first p powers.

function [X_poly] = polyFeatures(X, p)

X_poly = zeros(numel(X), p);

for i = 1:p
    % Column i holds X raised to the i-th power
    X_poly(:,i) = X.^i;
end

end
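
As a small worked example of this mapping (the input values are made up for illustration), each row of the result holds one sample raised to the powers 1 through p:

```latex
\left(X_{\text{poly}}\right)_{k,j} = x_{k}^{\,j}, \qquad
x = \begin{pmatrix} 1 \\ 2 \\ 3 \end{pmatrix},\; p = 3
\;\Longrightarrow\;
X_{\text{poly}} = \begin{pmatrix} 1 & 1 & 1 \\ 2 & 4 & 8 \\ 3 & 9 & 27 \end{pmatrix}
```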

4. Finally, validationCurve.m computes the training error and the validation error for each candidate value of lambda.

function [lambda_vec, error_train, error_val] = ...
    validationCurve(X, y, Xval, yval)

% Selected values of lambda (you should not change this)
lambda_vec = [0 0.001 0.003 0.01 0.03 0.1 0.3 1 3 10]';

% You need to return these variables correctly.
error_train = zeros(length(lambda_vec), 1);
error_val = zeros(length(lambda_vec), 1);

% lambda_vec has ten elements and each one is evaluated in turn, hence the loop
for i = 1:length(lambda_vec)
    % Train the optimal theta with regularized linear regression for this lambda
    theta = trainLinearReg(X, y, lambda_vec(i));
    % Training error error_train
    error_train(i) = sum((X * theta - y).^2) / (2 * size(X,1));
    % Validation error error_val
    error_val(i) = sum((Xval * theta - yval).^2) / (2 * size(Xval,1));
end

end
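
One common way to use the resulting curve, stated here as the standard practice from the course rather than something the code above does, is to pick the candidate lambda with the smallest validation error. The symbol \hat{\lambda} is my own notation for that choice:

```latex
\hat{\lambda} = \arg\min_{\lambda \,\in\, \text{lambda\_vec}} \; \text{error\_val}(\lambda)
```

Because error_val is computed without the penalty term, the values for different lambda are directly comparable.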