Pytorch学习笔记5:Logistic回归(分类问题)

分类问题

之前举例中学习时间与学习成绩的问题,是回归问题,也就是输入和输出数值之间的关系

x(hours) y(points)
1 2
2 4
3 6
4 ?

但是如果经y的值标记为是否合格,此问题将变成一个二分类问题:
 

x(hours) y(pass or fail)

1

fail(0)
2 fail(0)
3 pass(1)
4 ?

实际上是计算在4工时下是否合格的概率,即对概率的计算与比较,而非类别之间的数值比较。

逻辑回归:

二分类问题是非0即1的问题,由于隐藏条件的限制:P(\widehat{y} = 1)+P(\widehat{y} = 0)=1,对于二分类问题结果的预测,仅需要计算在0或1的条件下即可得到答案。

逻辑函数:

在原先的回归问题中,所利用的模型为 \widehat{y}=wx+b,此时的\widehat{y}∈R,但当问题为分类问题时,所求结果的值域应当发生改变,变为一个概率即\widehat{y}​​​​​​​∈[0,1],因此,需要引入逻辑函数(sigmoid)来实现:

本函数原名为logistics函数,属于sigmod类函数,由于其特性优异,代码中的sigmod函数就指的是本函数。其函数图像为:

                                                    Pytorch学习笔记5:Logistic回归(分类问题)_第1张图片

特点:

  1. 函数值在0到1之间变化明显(导数大)
  2. 在趋近于0和1处函数逐渐平滑(导数小)
  3. 函数为饱和函数
  4. 单调增函数

其他类型的sigmoid函数:

                                          Pytorch学习笔记5:Logistic回归(分类问题)_第2张图片

模型的变化:

模型结构变化:

原线性回归模型变为二分类模型

                    Pytorch学习笔记5:Logistic回归(分类问题)_第3张图片

loss变化:

原先是计算两个标量数值间的差距,也就是数轴上的距离。

现在为了计算两个概率之间的差异,需要利用到交叉熵的理论。

                          Pytorch学习笔记5:Logistic回归(分类问题)_第4张图片

mini-batch的loss示例:

import torch.nn.functional as F
import torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0.0],[0.0],[1.0]])

#改用LogisticRegressionModel 同样继承于Module
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1,1)

    def forward(self, x):
        #对原先的linear结果进行sigmod激活
        y_pred = F.sigmoid(self.linear(x))
        return y_pred
model = LogisticRegressionModel()

#构造的criterion对象所接受的参数为(y',y) 改用BCE
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)

print('y_pred = ',y_test.data)

                                                     Pytorch学习笔记5:Logistic回归(分类问题)_第5张图片

 

 

 

 

 

 

 

 

 

 

你可能感兴趣的:(Pytorch学习笔记5:Logistic回归(分类问题))