机器学习基础(二)多元线性回归模型 分类: 机器学习 ...

变量多于两个时,线性回归模型就变成了多元线性回归模型:

\begin{displaymath}h_{\theta}(x) = \theta^Tx = \sum_{i=0}^n \theta_i x_i, \nonumber\end{displaymath}

代价函数为:

\begin{displaymath}J(\theta)=\frac{1}{2m}\sum_{i=1}^{m}\left(h_{\theta}(x^{(i)})-y^{(i)}\right)^{2}. \nonumber\end{displaymath}

线性回归模型的训练(就是用梯度下降法求解最小代价函数)需要注意一些问题:

1.$\theta$的值

$\theta$1和$\theta$2的值相差太多时,梯度下降法难以收敛

2.学习速率

代价函数应该是递减的,如果代价函数不减反增,那么很可能是学习速率太大,跳过了极小值。

3.代价函数

代价函数必须是凸的,否则一样存在难以收敛的问题

4.过拟合

特征参数选的过多了,也就是n过大,会使模型对于新的数据预测准确性降低。

解决方法:减少特征参数;合并线性相关项;正则化;


多元线性回归模型的求解也可以通过Normal Equation的求解(其实就是通过矩阵变换求解线性方程),求解Normal Equation的好处是不用担心特征参数相差过多,不过变量矩阵必须是可逆的。

\begin{displaymath}\theta=\left(X^{T}X\right)^{-1}X^{T}\vec{y}.\end{displaymath}


参考练习:http://openclassroom.stanford.edu/MainFolder/DocumentPage.php?course=MachineLearning&doc=exercises/ex3/ex3.html

x = load('ex3x.dat');
y = load('ex3y.dat');

x = [ones(size(x,1),1) x];
%Normal equation
theta = (x'*x)\x'*y
x1=[1 1650 3];
y=x1*theta


meanx = mean(x);
sigmax = std(x);
x(:,2) = (x(:,2)-meanx(2))./sigmax(2);
x(:,3) = (x(:,3)-meanx(3))./sigmax(3);

figure
itera_num = 100;
sample_num = size(x,1);
alpha = [0.01, 0.03, 0.1, 0.3, 1, 1.3];
plotstyle = {'b', 'r', 'g', 'k', 'b--', 'r--'};

theta_grad_descent = zeros(size(x(1,:)));
for alpha_i = 1:length(alpha) 
    theta = zeros(size(x,2),1);
    Jtheta = zeros(itera_num, 1);
    for i = 1:itera_num     
        Jtheta(i) = (1/(2*sample_num)).*(x*theta-y)'*(x*theta-y);
        grad = (1/sample_num).*x'*(x*theta-y);
        theta = theta - alpha(alpha_i).*grad;
    end
    plot(0:49, Jtheta(1:50),char(plotstyle(alpha_i)),'LineWidth', 2)
    hold on
end
legend('0.01','0.03','0.1','0.3','1','1.3');

  theta = zeros(size(x,2),1);
  for i = 1:itera_num     
       grad = (1/sample_num).*x'*(x*theta-y);
       theta = theta - 1.*grad;
  end
theta
x2=[1 1650 3];
x2(:,2)=x2(:,2)*sigmax(2)+meanx(2);
x2(:,3)=x2(:,3)*sigmax(3)+meanx(3);
y=x2*theta

机器学习基础(二)多元线性回归模型 分类: 机器学习 ..._第1张图片


参考资料http://openclassroom.stanford.edu/MainFolder/CoursePage.php?course=MachineLearning


版权声明:本文为博主原创文章,未经博主允许不得转载。

转载于:https://www.cnblogs.com/learnordie/p/4656971.html

你可能感兴趣的:(数据结构与算法,人工智能,php)