logistic 回归,虽然名字里有 “回归” 二字,但实际上是解决分类问题的一类线性模型。在某些文献中,logistic 回归又被称作 logit 回归,maximum-entropy classification(MaxEnt,最大熵分类),或 log-linear classifier(对数线性分类器)。该模型利用函数 logistic function 将单次试验(single trial)的可能结果输出为概率。
scikit-learn 中 logistic 回归在 LogisticRegression 类中实现了二分类(binary)、一对多分类(one-vs-rest)及多项式 logistic 回归,并带有可选的 L1 和 L2 正则化。
注意,scikit-learn的逻辑回归在默认情况下使用L2正则化,这样的方式在机器学习领域是常见的,在统计分析领域是不常见的。正则化的另一优势是提升数值稳定性。scikit-learn通过将C设置为很大的值实现无正则化。
作为优化问题,带 L2罚项的二分类 logistic 回归要最小化以下代价函数(cost function):
m i n w , c 1 2 w T w + C ∑ i = 1 n log ( exp ( − y i ( X i T w + c ) ) + 1 ) . \underset{w, c}{min\,} \frac{1}{2}w^T w + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) . w,cmin21wTw+Ci=1∑nlog(exp(−yi(XiTw+c))+1).
类似地,带 L1 正则的 logistic 回归解决的是如下优化问题:
m i n w , c ∥ w ∥ 1 + C ∑ i = 1 n log ( exp ( − y i ( X i T w + c ) ) + 1 ) . \underset{w, c}{min\,} \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) . w,cmin∥w∥1+Ci=1∑nlog(exp(−yi(XiTw+c))+1).
Elastic-Net正则化是L1 和 L2的组合,来使如下代价函数最小:
min w , c 1 − ρ 2 w T w + ρ ∥ w ∥ 1 + C ∑ i = 1 n log ( exp ( − y i ( X i T w + c ) ) + 1 ) . \min_{w, c} \frac{1 - \rho}{2}w^T w + \rho \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) . w,cmin21−ρwTw+ρ∥w∥1+Ci=1∑nlog(exp(−yi(XiTw+c))+1).
其中ρ控制正则化L1与正则化L2的强度(对应于l1_ratio参数)。
注意,在这个表示法中,假定目标y_i在测试时应属于集合[-1,1]。我们可以发现Elastic-Net在ρ=1时与L1罚项等价,在ρ=0时与L2罚项等价
在 LogisticRegression 类中实现了这些优化算法: liblinear, newton-cg, lbfgs, sag 和 saga。
liblinear应用了坐标下降算法(Coordinate Descent, CD),并基于 scikit-learn 内附的高性能 C++ 库 LIBLINEAR library 实现。不过 CD 算法训练的模型不是真正意义上的多分类模型,而是基于 “one-vs-rest” 思想分解了这个优化问题,为每个类别都训练了一个二元分类器。因为实现在底层使用该求解器的 LogisticRegression 实例对象表面上看是一个多元分类器。 sklearn.svm.l1_min_c 可以计算使用 L1时 C 的下界,以避免模型为空(即全部特征分量的权重为零)。
lbfgs, sag 和 newton-cg 求解器只支持 L2罚项以及无罚项,对某些高维数据收敛更快。这些求解器的参数 multi_class设为 multinomial 即可训练一个真正的多项式 logistic 回归 [5] ,其预测的概率比默认的 “one-vs-rest” 设定更为准确。
sag 求解器基于平均随机梯度下降算法(Stochastic Average Gradient descent) [6]。在大数据集上的表现更快,大数据集指样本量大且特征数多。
saga 求解器 [7] 是 sag 的一类变体,它支持非平滑(non-smooth)的 L1 正则选项 penalty=“l1” 。因此对于稀疏多项式 logistic 回归 ,往往选用该求解器。saga求解器是唯一支持弹性网络正则选项的求解器。
lbfgs是一种近似于Broyden–Fletcher–Goldfarb–Shanno算法[8]的优化算法,属于准牛顿法。lbfgs求解器推荐用于较小的数据集,对于较大的数据集,它的性能会受到影响。[9]
总的来说,各求解器特点如下:
罚项 | liblinear | lbfgs | newton-cg | sag | saga |
---|---|---|---|---|---|
多项式损失+L2罚项 | × | √ | √ | √ | √ |
一对剩余(One vs Rest) + L2罚项 | √ | √ | √ | √ | √ |
多项式损失 + L1罚项 | × | × | × | × | √ |
一对剩余(One vs Rest) + L1罚项 | √ × | × | × | √ | |
弹性网络 | × | × | × | × | √ |
无罚项 | × | √ | √ | √ | √ |
表现 | |||||
惩罚偏置值(差) | √ | × | × | × | × |
大数据集上速度快 | × | × | × | √ | √ |
未缩放数据集上鲁棒 | √ | √ | √ | × | × |
默认情况下,lbfgs求解器鲁棒性占优。对于大型数据集,saga求解器通常更快。对于大数据集,还可以用 SGDClassifier ,并使用对数损失(log loss)这可能更快,但需要更多的调优。
示例:
- Logistic回归中的L1罚项和稀疏系数
- L1罚项-logistic回归的路径
- 多项式和OVR的Logistic回归
- newgroups20上的多类稀疏Logistic回归
- 使用多项式Logistic回归和L1进行MNIST数据集的分类
与 liblinear 的区别:
当 fit_intercept=False 拟合得到的 coef_ 或者待预测的数据为零时,用 solver=liblinear 的 LogisticRegression 或 LinearSVC 与直接使用外部 liblinear 库预测得分会有差异。这是因为, 对于 decision_function 为零的样本, LogisticRegression 和 LinearSVC 将预测为负类,而 liblinear 预测为正类。
注意,设定了 fit_intercept=False ,又有很多样本使得 decision_function 为零的模型,很可能会欠拟合,其表现往往比较差。建议您设置 fit_intercept=True 并增大 intercept_scaling 。
注意:利用稀疏 logistic 回归进行特征选择
带 L1罚项的 logistic 回归 将得到稀疏模型(sparse model),相当于进行了特征选择(feature selection),详情参见 基于 L1 的特征选取。
LogisticRegressionCV 对 logistic 回归 的实现内置了交叉验证(cross-validation),可以找出最优的 C和l1_ratio参数 。newton-cg, sag, saga 和 lbfgs 在高维数据上更快,这是因为采用了热启动(warm-starting)。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
data = pd.read_csv('ex2data1.txt',sep = ',',names = ['exam1', 'exam2','admitted'])
data.head()
exam1 | exam2 | admitted | |
---|---|---|---|
0 | 34.623660 | 78.024693 | 0 |
1 | 30.286711 | 43.894998 | 0 |
2 | 35.847409 | 72.902198 | 0 |
3 | 60.182599 | 86.308552 | 1 |
4 | 79.032736 | 75.344376 | 1 |
# 选取exam1和exam2列,并转为array类型
X = data.iloc[:,[0,1]].values
y = data.iloc[:,2].values
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 0, test_size = 0.25)
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)
LogisticRegressio官方指导文档
from sklearn.linear_model import LogisticRegression
classfier = LogisticRegression()
classfier.fit(X_train, y_train)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
intercept_scaling=1, l1_ratio=None, max_iter=100,
multi_class='warn', n_jobs=None, penalty='l2',
random_state=None, solver='warn', tol=0.0001, verbose=0,
warm_start=False)
y_pred = classfier.predict(X_test)
我们预测了测试集。 现在我们将评估逻辑回归模型是否正确的学习和理解。因此这个混淆矩阵将包含我们模型的正确和错误的预测。
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
from matplotlib.colors import ListedColormap
X_set, y_set = X_train, y_train
X1, X2 = np.meshgrid(np.arange(start=X_set[:,0].min()-1, stop = X_set[:,0].max()+1, step=0.01),
np.arange(start=X_set[:,1].min()-1, stop = X_set[:,1].max()+1, step=0.01))
plt.contourf(X1, X2, classfier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
c = ListedColormap(('red','green'))(i),label=j)
plt.title('LOGISTIC(Training set)')
plt.xlabel('Exam 1 score')
plt.ylabel('Exam 2 score')
plt.legend()
plt.show()
X_set, y_set = X_test, y_test
X1, X2 = np.meshgrid(np.arange(start=X_set[:,0].min()-1, stop = X_set[:,0].max()+1, step=0.01),
np.arange(start=X_set[:,1].min()-1, stop = X_set[:,1].max()+1, step=0.01))
plt.contourf(X1, X2, classfier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape),
alpha = 0.75, cmap = ListedColormap(('red', 'green')))
plt.xlim(X1.min(), X1.max())
plt.ylim(X2.min(), X2.max())
for i, j in enumerate(np.unique(y_set)):
plt.scatter(X_set[y_set == j, 0], X_set[y_set == j, 1],
c = ListedColormap(('red','green'))(i),label=j)
plt.title('LOGISTIC(Test set)')
plt.xlabel('Exam 1 score')
plt.ylabel('Exam 2 score')
plt.legend()
plt.show()