在学习了前面讲到的将logistic函数用到分类问题中的文章后,你可能想知道为什么会冒出这样的模型,为什么这种模型是合理的。接下来,我们会答疑解惑,证明logistic回归和softmax回归只是广义线性模型(Generalized Linear Model,GLM)的一种特例,它们都是在广义线性模型的定义和指数族分布(Exponential Family Distribution)的基础上推导出来的。文章整体划分为三个部分,1)指数族分布;2)广义线性模型;3)由广义线性模型和指数族分布推导出logist回归和softmax回归用于分类问题的合理性。
Gauss分布、Bernoulli分布和泊松分布等更一般的形式就是指数族分布,属于指数族分布的各种概率分布有很多重要的共性。指数分布通常可写成如下形式: \begin{equation} p(y;\eta)=b(y)h(\eta)\exp\left(\eta^TT(y)\right) \end{equation} 其中\(\eta\)为自然参数(Natural Parameter);\(T(y)\)为充分统计量(Sufficient Statistics);\(h(\eta)\)为归一化常量(Normalization Constant),使得上式满足概率分布的条件,即\(p(y;\eta)\in[0,1]\)并且 \begin{equation} h(\eta)\int b(y)\exp\left(\eta^TT(y)\right)dy=1 \end{equation} 如果\(y\)为离散型变量,上式由积分形式变为求和形式即可。 下面,我们对Bernoulli分布和Gauss分布的数学表示形式进行变形,证明它们实际上都属于指数族分布:
在看到指数族分布的定义后,大家可能存在的一个问题就是\(T(y)\)为什么被称为充分统计量呢?下面来解释这个问题。我们将概率加和为1法则对应的公式左右两边同时对\(\eta\)求导,可得 \begin{equation} \begin{array}{l} \nabla h(\eta)\int b(y)\exp\left(\eta^TT(y)\right)dy\\ \quad+h(\eta)\int b(y)\exp\left(\eta^TT(y)\right)T(y)dy=0 \end{array} \end{equation} 对上式变形,并再次利用概率加和为1法则,得到下式 \begin{equation} \frac{\nabla h(\eta)}{h(\eta)}=-h(\eta)\int b(y)\exp\left(\eta^TT(y)\right)T(y)dy \end{equation} 我们用更精简的形式来表述: \begin{equation} \nabla \ln h(\eta)=-\mathbb{E}\left[T(y)\right] \end{equation} 假设我们现在有\(N\)个样本组成的数据集\(\mathcal{Y}=\{y_1,y_2,\cdots,y_N\}\),我们用最大似然的方法来估计参数\(\eta\),其对数似然函数形式如下: \begin{equation} \begin{array}{cl} \mathcal{L}&=\ln\left(\left(\prod_{i=1}^Nb(y_i)\right)h(\eta)^N\exp\left(\eta^T\sum_{i=1}^NT(y_i)\right)\right)\\ &=\sum_{i=1}^N\ln b(y_i)+N\ln h(\eta)+\eta^T\sum_{i=1}^NT(y_i) \end{array} \end{equation} 将\(\mathcal{L}\)对参数\(\eta\)求导并令其为0,得到 \begin{equation} \nabla\ln h(\eta_{ML})=-\frac{1}{N}\sum_{i=1}^NT(y_i) \end{equation} 根据上式我们可以求解出\(\eta_{ML}\)。我们可以看到最大似然估计仅仅通过\(\sum_iT(y_i)\)依赖样本点,因此被称为充分统计量。我们只需要存储充分统计量\(T(y)\)而不是数据本身。在Bernoulli分布中\(T(y)=y\),我们只需保存所有样本的加和\(\sum_iy_i\);在Gauss分布中,\(T(y)=(y,y^2)^T\),因此我们只要保持\(\sum_iy_i\)和\(\sum_iy_i^2\)即可。当\(N\rightarrow\infty\)时,上式的右侧就等价于\(\mathbb{E}\left[T(y)\right]\),\(\eta_{ML}\)此时也就等于\(\eta\)的真实值。实际上,该充分特性仅仅适用于贝叶斯推理(Bayesian Inference),详情请见《Pattern Recognition and Machine Learning》的第八章内容。
为了推导出一个广义线性模型用于分类或回归问题,我们得先做出三个假设:
在二分类问题中\(y\in\{0,1\}\),选择Bernoulli分布是很自然的。将Bernoulli分布写成指数族分布的形式,我们有\(\phi=1/\left(1+\exp(-\eta)\right)\) \begin{equation} \begin{array}{ll} h_{\theta}(x)&=E[y|x;\theta]\Longrightarrow\text{(满足假设2)}\\ &=\phi\Longleftarrow(y|x;\theta\sim Bernoulli(\phi))\\ &=1/\left(1+\exp(-\eta)\right)\Longrightarrow\text{(满足假设1)}\\ &=1/\left(1+\exp(-\theta^Tx)\right)\Longrightarrow\text{(满足假设3)} \end{array} \end{equation}
在多分类问题中\(y\in\{1,2,\cdots,k\}\),我们利用多项分布(Multinomial Distribution)对其进行建模。我们用\(k\)个参数\(\phi_1,\phi_2,\cdots,\phi_k\)分别表示\(y\)属于每一类的概率\(p(y=i;\phi)=\phi_i\)。但这些参数间并非完全独立的,因为存在关系式\(\sum_{i=1}^k\phi_i=1\)。因此,我们仅用\(k-1\)个参数\(\phi_1,\phi_2,\cdots,\phi_{k-1}\),\(\phi_k=p(y=k;\phi)=1-\sum_{i=1}^{k-1}\)在此并不是参数。为了把多项分布表示为指数族分布的形式,我们先定义\(T(y)\in\mathbb{R}^{k-1}\)如下: \begin{eqnarray} T(1)=\left[\begin{array}{c}1\\0\\\vdots \\0\end{array}\right], T(2)=\left[\begin{array}{c}0\\1\\\vdots \\0\end{array}\right], \cdots, T(k-1)=\left[\begin{array}{c}0\\0\\\vdots \\1\end{array}\right], T(k)=\left[\begin{array}{c}0\\0\\\vdots \\0\end{array}\right] \end{eqnarray} 其中,\(T(y)\)的第\(i\)个元素\(\left(T(y)\right)_i=1\{y=i\}\)。进一步,我们有\(\mathbb{E}\left[\left(T(y)\right)_i\right]=p(y=i)=\phi_i\)。 多项分布也是指数族分布的一员,推导如下: \begin{equation} \begin{array}{ll} p(y)&=\phi_1^{1\{y=1\}}\phi_2^{1\{y=2\}}\cdots\phi_k^{1\{y=k\}}\\ &=\phi_1^{1\{y=1\}}\phi_2^{1\{y=2\}}\cdots\phi_k^{1-\sum_{i=1}^{k-1}1\{y=i\}}\\ &=\phi_1^{\left(T(y)\right)_1}\phi_2^{\left(T(y)\right)_2}\cdots\phi_k^{1-\sum_{i=1}^{k-1}\left(T(y)\right)_i}\\ &=\exp\left(\left(T(y)\right)_1\log\phi_1+\cdots+\left(1-\sum_{i=1}^{k-1}\left(T(y)\right)_i\right)\ln\phi_k\right)\\ &=\exp\left(\left(T(y)\right)_1\log\frac{\phi_1}{\phi_k}+\left(T(y)\right)_2\log\frac{\phi_2}{\phi_k}+\cdots+\ln\phi_k\right)\\ &=b(y)h(\eta)\exp\left(\eta^TT(y)\right) \end{array} \end{equation} \begin{equation} b(y)=1 \end{equation} \begin{equation} \eta=\left[\begin{array}{c} \ln(\phi_1/\phi_k)\\ \ln(\phi_2/\phi_k)\\ \vdots\\ \ln(\phi_{k-1}/\phi_k)\\ \end{array}\right]\in\mathbb{R}^{k-1} \end{equation} \begin{equation} h(\eta)=\phi_k \end{equation} 为了方便,我们定义\(\eta_k=\ln(\phi_k/\phi_k)=0\)。由\(\eta\)的表达式可知 \begin{equation} \begin{array}{ll} \eta_i=\ln\frac{\phi_i}{\phi_k}&\Rightarrow\phi_k\exp(\eta_i)=\phi_i\\ &\Rightarrow\phi_k\sum_{i=1}^k\exp(\eta_i)=\sum_{i=1}^k\phi_i=1\\ &\Rightarrow\phi_i=\frac{\exp(\eta_i)}{\sum_{j=1}^k\exp(\eta_j)}=\frac{\exp(\eta_i)}{1+\sum_{j=1}^{k-1}\exp(\eta_j)} \end{array} \end{equation} 上式称为softmax函数,完成了自然参数\(\eta\)到多项分布参数\(\phi\)之间的映射。 在广义线性模型的第三个假设下,我们有\(\eta_i=\theta_i^Tx\)。其中,在\(x\)中引入了\(x_0=1\),以便把截距项考虑进来后以更紧凑的形式表述,那么\(x\in\mathbb{R}^{n+1}\),\(\theta_i\in\mathbb{R}^{n+1}\)。为了满足\(\eta_k=\theta_k^Tx=0\),我们定义\(\theta_k=0\)。如此一来,给定观测值\(x\),\(y\)在此模型下的条件概率形式如下: \begin{equation} p(y=i|x;\theta)=\phi_i=\frac{\exp(\eta_i)}{\sum_{j=1}^k\exp(\eta_j)}=\frac{\exp(\theta_i^Tx)}{1+\sum_{j=1}^{k-1}\exp(\theta^T_jx)} \end{equation} 那么模型的最终输出结果为 \begin{equation} h_{\theta}(x)=\mathbb{E}\left[T(y)|x;\theta\right]=\left[\begin{array}{c}\phi_1\\\phi_2\\\vdots\\\phi_{k-1}\end{array}\right]=\left[\begin{array}{c}\exp(\theta_1^Tx)/\sum_{i=1}^k\exp(\theta_i^Tx)\\\exp(\theta_2^Tx)/\sum_{i=1}^k\exp(\theta_i^Tx)\\\vdots\\\exp(\theta_{k-1}^Tx)/\sum_{i=1}^k\exp(\theta_i^Tx)\end{array}\right] \end{equation} 该模型称为softmax回归,可应用到多分类问题中,是logistic回归的推广。 给定由\(m\)个样本组成的训练集\(\{(x^{(i)},y^{(i)}),i=1,\cdots,m\}\),如果我们想用softmax回归实现分类任务,那么模型的参数\(\theta_i\)如何学习到呢?训练集上的对数似然函数如下: \begin{equation} \begin{array}{cl} \ell(\theta)&=\sum_{i=1}^m\log p(y^{(i)}|x^{(i)};\theta)\\ &=\sum_{i=1}^m\log\prod_{l=1}^k\left(\frac{\exp(\theta_{l}^Tx)}{\sum_{j=1}^k\exp(\theta_j^Tx)}\right)^{1\{y^{(i)}=l\}}\\ &=\sum_{i=1}^m\sum_{l=1}^k1\{y^{(i)}=l\}\left(\theta_l^Tx{(i)}-\log\sum_{j=1}^k\exp(\theta_j^Tx\right) \end{array} \end{equation} 我们仍然用最大似然的方法来估计参数\(\theta_i\),将对数似然函数\(\ell(\theta)\)对\(\theta_l\)求导得 \begin{equation} \frac{\partial\ell(\theta)}{\theta_l}=\sum_{i=1}^mx^{(i)}\left(1\{y^{(i)}=l\}-p(y^{(i)}=l|x^{(i)};\theta)\right) \end{equation}