逻辑斯谛回归模型其实是一种分类模型,这里实现的是参考李航的《统计机器学习》以及周志华的《机器学习》两本教材来整理实现的。
假定我们的输入为 x x x, x x x 可以是多个维度的,我们想要根据 x x x 去预测 y y y, y ∈ { 0 , 1 } y\in \{0,1\} y∈{0,1}。逻辑斯谛的模型如下:
p ( Y = 1 ∣ x ) = e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) (1) p(Y=1|x)=\frac{exp(w\cdot x)}{1+exp(w\cdot x)}\tag{1} p(Y=1∣x)=1+exp(w⋅x)exp(w⋅x)(1)
其中的参数 w w w就是我们要进行学习的,注意:它是包含了权重系数和偏置(bias)b的。在书写程序时,这样表示更加简洁。
参数 w w w是我们需要学习的,我们采用极大似然法估计模型参数。
设:
P ( Y = 1 ∣ x ) = π ( x ) , P ( Y = 0 ∣ x ) = 1 − π ( x ) (2) P(Y=1|x)=\pi(x),\quad P(Y=0|x)=1-\pi(x)\tag{2} P(Y=1∣x)=π(x),P(Y=0∣x)=1−π(x)(2)
似然函数为:
∏ i = 1 N [ π ( x i ) ] y i [ 1 − π ( x i ) ] 1 − y i (3) \prod_{i=1}^N[\pi(x_i)]^{y_i}[1-\pi(x_i)]^{1-y_i} \tag{3} i=1∏N[π(xi)]yi[1−π(xi)]1−yi(3)
因为这种指数的形式不利于求导我们需要将它们转化为对数的形式,如下:
L ( w ) = ∑ i = 1 N [ y i l o g π ( x i ) + ( 1 − y i ) l o g ( 1 − π ( x i ) ) ] = ∑ i = 1 N [ y i l o g ( π ( x i ) 1 − π ( x i ) ) + l o g ( 1 − π ( x i ) ) ] = ∑ i = 1 N [ y i ( w ⋅ x i ) − l o g ( 1 + e x p ( w ⋅ x i ) ) ] (4) \begin{aligned} L(w)=&\sum_{i=1}^N[y_ilog\pi(x_i)+(1-y_i)log(1-\pi(x_i))] \\ =&\sum_{i=1}^N [y_ilog(\frac{\pi(x_i)}{1-\pi(x_i)})+log(1-\pi(x_i))]\\ =&\sum_{i=1}^{N}[y_i(w\cdot x_i)-log(1+exp(w\cdot x_i))] \end{aligned} \tag{4} L(w)===i=1∑N[yilogπ(xi)+(1−yi)log(1−π(xi))]i=1∑N[yilog(1−π(xi)π(xi))+log(1−π(xi))]i=1∑N[yi(w⋅xi)−log(1+exp(w⋅xi))](4)
对 L ( w ) L(w) L(w)求极大值,得到 w w w的估计值。
梯度下降法是求极小值的,而我们想要得到的是 L ( w ) L(w) L(w)的最大值,因此,我们取 L ( w ) L(w) L(w)的相反数,即:
arg min w − L ( w ) (5) \argmin_{w}-L(w) \tag{5} wargmin−L(w)(5)
对 L ( w ) L(w) L(w)关于 w w w求导,如下:
( − L ( w ) ) ′ = − ∑ i = 1 N [ ( y i ⋅ x i ) − e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) ⋅ x i ] = − ∑ i = 1 N [ ( y i − e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) ) ⋅ x i ] = ∑ i = 1 N [ ( e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) − y i ) ⋅ x i ] (6) \begin{aligned} (-L(w))'=&-\sum_{i=1}^N[(y_i\cdot x_i)-\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)}\cdot x_i]\\ =&-\sum_{i=1}^N[(y_i-\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)})\cdot x_i]\\ =&\sum_{i=1}^N[(\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)}-y_i)\cdot x_i] \end{aligned} \tag{6} (−L(w))′===−i=1∑N[(yi⋅xi)−1+exp(w⋅x)exp(w⋅xi)⋅xi]−i=1∑N[(yi−1+exp(w⋅x)exp(w⋅xi))⋅xi]i=1∑N[(1+exp(w⋅x)exp(w⋅xi)−yi)⋅xi](6)
然后我们就得到了参数 w w w的更新公式,如下:
w ′ = w − l r ⋅ ( − L ( w ) ′ ) = w − l r ⋅ ( ∑ i = 1 N [ ( e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) − y i ) ⋅ x i ] ) (7) \begin{aligned} w'=&w-lr\cdot(-L(w)')\\ =&w-lr\cdot(\sum_{i=1}^N[(\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)}-y_i)\cdot x_i]) \end{aligned} \tag{7} w′==w−lr⋅(−L(w)′)w−lr⋅(i=1∑N[(1+exp(w⋅x)exp(w⋅xi)−yi)⋅xi])(7)
关于优化方法的选择,最开始是选择西瓜书上提供的牛顿法来实现的,牛顿法的好处是,可以获得较快的收敛速度,但是坏处是,当海森矩阵为奇异矩阵时,会出现无法求解的情况。
因此,可以采用拟牛顿法进行优化,在解决这个问题的同时,也可以很快的收敛。
但是,自己对拟牛顿法并不熟悉,而梯度下降法虽然收敛可能较慢,但是实现起来较为简单,因此这里采用了梯度下降法来优化似然函数。
package weka.classifiers.myf;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Standardize;
import java.util.Arrays;
/**
* @author YFMan
* @Description 自定义的 Logistic 回归分类器
* @Date 2023/6/13 11:02
*/
public class myLogistic extends Classifier {
// 用于存储 线性回归 系数 的数组
private double[] m_Coefficients;
// 类别索引
private int m_ClassIndex;
// 牛顿法的迭代次数
private int m_MaxIterations = 1000;
// 属性数量
private int m_numAttributes;
// 系数数量
private int m_numCoefficients;
// 梯度下降步长
private double m_lr = 1e-4;
// 标准化数据的过滤器
public static final int FILTER_STANDARDIZE = 1;
// 用于标准化数据的过滤器
protected Filter m_StandardizeFilter = null;
// 用于将 normal 转为 binary 的过滤器
protected Filter m_NormalToBinaryFilter = null;
/*
* @Author YFMan
* @Description 采用牛顿法来训练 logistic 回归模型
* @Date 2023/5/9 22:08
* @Param [data] 训练数据
* @return void
**/
public void buildClassifier(Instances data) throws Exception {
// 设置类别索引
m_ClassIndex = data.classIndex();
// 设置属性数量
m_numAttributes = data.numAttributes();
// 系数数量 = 输入属性数量 + 1(截距参数b)
m_numCoefficients = m_numAttributes;
// 初始化 系数数组
m_Coefficients = new double[m_numCoefficients];
Arrays.fill(m_Coefficients, 0);
// 将输入数据进行标准化
m_StandardizeFilter = new Standardize();
m_StandardizeFilter.setInputFormat(data);
data = Filter.useFilter(data, m_StandardizeFilter);
// 将类别属性转为二值属性
m_NormalToBinaryFilter = new NominalToBinary();
m_NormalToBinaryFilter.setInputFormat(data);
data = Filter.useFilter(data, m_NormalToBinaryFilter);
// 梯度下降法
for(int curPerformIteration = 0; curPerformIteration < m_MaxIterations;curPerformIteration++){
double[] deltaM_Coefficients = new double[m_numCoefficients];
// 计算 l(w) 的一阶导数
for(int i = 0;i<data.numInstances();i++){
double yi = data.instance(i).value(m_ClassIndex);
double wxi = 0;
int column = 0;
for(int j=0;j<m_numAttributes;j++){
if(j!=m_ClassIndex){
wxi += m_Coefficients[column] * data.instance(i).value(j);
column++;
}
}
// 加上截距参数 b
wxi += m_Coefficients[column];
double pi1 = Math.exp(wxi) / (1 + Math.exp(wxi));
for(int k=0;k<m_numCoefficients - 1;k++){
deltaM_Coefficients[k] += m_lr * (pi1 - yi) * data.instance(i).value(k);
}
// 这里计算 bias b 对应的更新量
deltaM_Coefficients[m_numCoefficients - 1] += m_lr * (pi1 - yi);
}
// 进行参数更新
for(int k=0;k<m_numCoefficients;k++){
m_Coefficients[k] -= deltaM_Coefficients[k];
}
// 如果参数更新量小于阈值,则停止迭代
double delta = 0;
for(int k=0;k<m_numCoefficients;k++){
delta += deltaM_Coefficients[k] * deltaM_Coefficients[k];
}
if(delta < 1e-6){
break;
}
}
}
/*
* @Author YFMan
* @Description // 分类实例
* @Date 2023/6/16 11:17
* @Param [instance]
* @return double[]
**/
public double[] distributionForInstance(Instance instance) throws Exception {
// 将输入数据进行标准化
m_StandardizeFilter.input(instance);
instance = m_StandardizeFilter.output();
// 将输入属性二值化
m_NormalToBinaryFilter.input(instance);
instance = m_NormalToBinaryFilter.output();
double[] result = new double[2];
result[0] = 0;
result[1] = 0;
int column = 0;
for(int i=0;i<m_numAttributes;i++){
if(m_ClassIndex != i){
result[0] += instance.value(i) * m_Coefficients[column];
column++;
}
}
result[0] += m_Coefficients[column];
result[0] = 1 / (1 + Math.exp(result[0]));
result[1] = 1 - result[0];
return result;
}
/*
* @Author YFMan
* @Description 主函数 生成一个线性回归函数预测器
* @Date 2023/5/9 22:35
* @Param [argv]
* @return void
**/
public static void main(String[] argv) {
runClassifier(new myLogistic(), argv);
}
}