目标: 预测值 = 真实值
方法: 最小化 dist(预测值, 真实值) ,这里的dist一般指的是均方误差,即二范数的平方。
目标: 最大化 benchmark, 例如,accuracy
方法1: 最小化 dist(p_theta(y|x), p_r(y|x))
方法2: 最小化 divergence(p_theta(y|x), p_r(y|x)), 其目的是为了让两个概率分布更加接近。
首先,假设一个二分类问题,大于0.5就是1,小于0.5就是0,。如果对accuracy进行maximize,那么会出现以下两个问题:
issues1: 如果我的accuracy值计算出来是0.6,其中有一个x对应的label是1,但是预测输出的概率为0.4,label值为0,此时更新通过更新梯度,使得输出的概率值为0.45,使其输出分布更加接近于真实的概率分布,但是其输出的label的值并没有改变,还是0,于是我的accuracy并不会变化,这就导致了梯度为0,于是后面就不会再进行更新,即,我的权重改变了,但是我的梯度并没有改变。
gradient = 0, if accuracy unchanged but weights changed
issues2:同上假设,假设误分类中输出有一个概率是0.499,那么只要权重的值有一点点的增加,概率变为0.501,大于0.5,于是输出的accuracy的值从0.6会跳变到0.8,会导致梯度的不连续。
gradient not continuous since the number of correct is not continuous
一方面,我们对分类问题输出的结果是一系列概率值,如果使用MSE(均方误差函数)对概率值,例如0.7,和0,1标签值,例如,标签值为1,计算loss,则有那么一点点回归的意思,即从0.7 --> 1。所以也可以称作回归。
另一方面,如果是cross entropy, 一般称为分类问题。
因此,一般对于二分类问题,如果最后的使用的是MSE进行回归那么可以说这是一个回归问题。对于多分类问题需要满足两个约束:1,每一个概率值必须是0~1之间;2,所有的概率值之和为1。
输出的logits 会经过softmax函数enlarge the larger ,使用的损失函数通常是cross entropy,所以这才是人们所说的真正意义上的分类问题。
熵越小,说明精喜度越高,也就是信息量越大,熵越大,就表示没有什么信息。
Entropy定义为:
E n t r o p y = − ∑ i P ( i ) l o g P ( i ) Entropy = -\sum_{i}P(i)logP(i) Entropy=−i∑P(i)logP(i)
Cross Entropy定义为:Entropy 衡量的是一个分布的不稳定度,但是交叉熵Cross Entropy 是衡量两个分布之间的不稳定度,定义为:
H ( p , q ) = − ∑ p ( x ) l o g q ( x ) H ( p , q ) = H ( p ) + D K L ( p ∣ q ) H(p,q)=-\sum p(x)\,log\, q(x)\\ H(p,q)=H(p)+D_{KL}\,(p|q) H(p,q)=−∑p(x)logq(x)H(p,q)=H(p)+DKL(p∣q)
其中D_KL是KL散度的意思。如果两个分布完全相似,那么他们的交叉熵几乎为0;反之如果只有不多的交集,那么他们的交叉熵就会非常大。
如果 P=Q 则,交叉熵就是熵,因为P,Q相等,所以他们的散度为0,所以交叉熵近似为熵。
对于one-hot encoding,H§其实就是0。因为根据定义,H§ = 1log(1) = 0.所以两个分布之间的交叉熵间接转化成了两个分布之间的散度
H ( p , q ) = − ∑ i = ( c a t , d o g ) P ( i ) l o g ( Q ( i ) ) = − P ( c a t ) l o g Q ( c a t ) − P ( d o g ) l o g Q ( d o g ) = − y l o g ( p ) − ( 1 − y ) l o g ( 1 − p ) H(p,q)=-\sum_{i=(cat,dog)}P(i)\,log\,(Q(i))\\ =-P(cat)log\,Q(cat) - P(dog)log\,Q(dog)\\ =-ylog(p)-(1-y)log(1-p) H(p,q)=−i=(cat,dog)∑P(i)log(Q(i))=−P(cat)logQ(cat)−P(dog)logQ(dog)=−ylog(p)−(1−y)log(1−p)
sigmoid + MSE: 会导致出现梯度弥散的现象;
cross entropy 的收敛速度会更快;
但有时候也会使用mse,因为它求导会非常的容易。一开始最好先试下最简单的方式,没有绝对的好。