逻辑回归中的sigmoid函数:
g ( θ 0 + θ 1 X 1 + ⋯ θ n X n ) = h ( θ X ) = p = e θ X 1 + e θ X g(θ_0+θ_1X_1+\cdotsθ_nX_n)=h(θX)=p=\frac{e^{θX}}{1+e^{θX}} g(θ0+θ1X1+⋯θnXn)=h(θX)=p=1+eθXeθX
也可写成(分子分母同时除以eθX):
g ( θ 0 + θ 1 X 1 + ⋯ θ n X n ) = h ( θ X ) = p = 1 1 + e − θ X g(θ_0+θ_1X_1+\cdotsθ_nX_n)=h(θX)=p=\frac{1}{1+e^{-θX}} g(θ0+θ1X1+⋯θnXn)=h(θX)=p=1+e−θX1
得到的sigmoid函数图像为:
可知,当θX>=0,即θ0+θ1X1+…+θnXn>=0时,p>=0.5,预测y=1
当θX<0,即θ0+θ1X1+…+θnX<0时,p<0.5,预测y=0
为什么不使用像线性回归那样的代价函数?
如图,因为所得到的代价函数为非凸函数,会有很多个局部最优解,使用梯度下降时不一定能够收敛到全局最小值。
sigmoid函数:
h θ ( x ) = 1 1 + e − θ x h_\theta(x)=\frac{1}{1+e^{-\theta x}} hθ(x)=1+e−θx1
使用梯度下降求θ:
∂ J ( θ ) ∂ θ j = ( h θ ( x ) − y ) x j \frac{\partial J(\theta)}{\partial\theta_j}= (h_\theta(x)-y)x_j ∂θj∂J(θ)=(hθ(x)−y)xj
可得:
θ j : = θ j − α ∑ i = 1 m ( h θ ( x ) ( i ) − y ( i ) ) x j ( i ) \theta_j:=\theta_j-\alpha\sum_{i=1}^m(h_\theta(x)^{(i)}-y^{(i)})x_j^{(i)} θj:=θj−αi=1∑m(hθ(x)(i)−y(i))xj(i)
其中i表示第几个样本,j表示第几个特征值
思想:将多个类别分成独立的二元问题,然后分别求出各个分类的代价函数,再将测试集的数据代入所获得的结果中,取结果最大的预测概率,就能判断出这个数据属于哪一个类别。
使用逻辑回归将鸢尾花数据进行分类。鸢尾花数据为3类,属于多元分类问题。可以将其分解成多个二元问题,即对于每个类别,进行模型训练得到模型,然后对于测试集中的每个数据,使用得到的参数对其进行计算,得到的最大值即为这个数据属于的类别。
"""
@Author: sanshui
@Time:2021/11/15 15:15
@Software: PyCharm
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
def get_data():
"""
获取数据集
:return:
"""
iris = load_iris()
return iris.data, iris.target
def split_data(data, target):
"""
划分数据集
:param data:
:param target:
:return:
"""
data = np.array(data)
target = np.array(target)
target = np.reshape(target, (target.shape[0], 1))
# 正则化数据,防止数据大小本身对结果造成影响
sd = StandardScaler()
data = sd.fit_transform(data)
# 拼接特征值与类别
dataset = np.hstack((data, target))
n = dataset.shape[0]
# 打乱数据
np.random.shuffle(dataset)
# 划分数据集,返回训练集与测试集
train = dataset[:int(0.7 * n), :]
test = dataset[int(0.7 * n):, :]
return train, test
def sigmoid(z):
"""
sigmoid函数
:param z:
:return:
"""
return 1 / (1 + np.exp(-z))
def draw_sigmoid():
"""
画出sigmoid函数
:return:
"""
fig, ax = plt.subplots()
x_data = np.arange(-10, 10, 0.1)
ax.plot(x_data, sigmoid(x_data))
plt.show()
def calCost(dataset, theta):
"""
计算代价函数
:param dataset:
:param theta:
:return:
"""
x = dataset[:, :-1]
y = dataset[:, -1:]
z = x @ theta.T
# 训练数据个数,或者用m = y.shape[1]
m = y.size
para1 = np.multiply(-y, np.log(sigmoid(z)))
para2 = np.multiply((1 - y), np.log(1 - sigmoid(z)))
# 代价函数Y
J = 1 / m * np.sum(para1 - para2)
return J
def gradient(dataset, theta, iters, alpha):
"""
梯度下降
:param dataset:
:param theta:
:param iters:
:param alpha:
:return:
"""
# 存放每次梯度下降后的损失值
x = dataset[:, :-1]
y = dataset[:, -1:]
for i in range(iters):
h_x = sigmoid(x @ theta.T)
theta = theta - alpha / len(x) * (h_x - y).T @ x
return theta
def get_per_classify_data(data, i):
"""
返回第i类的数据
:param data:数据集
:param i:类别
:return:
"""
return data[data[:, -1] == i]
def get_final_theta(data, i, theta, iters, alpha):
"""
获取梯度下降后的theta值
:param data:
:param i:
:param theta:
:param iters:
:param alpha:
:return:
"""
dataset = get_per_classify_data(data, i)
return gradient(dataset, theta, iters, alpha)
def predict(dataset, theta_list):
"""
预测结果
:param dataset:
:param theta_list:
:return:
"""
x = dataset[:, :-1]
per_theta_list = [i[0] for i in theta_list]
per_theta_list = np.array(per_theta_list)
per_prob = sigmoid(np.dot(x, per_theta_list.T))
# 返回每行最大值所在的索引,即概率最大的类别
return np.argmax(per_prob, axis=1)
if __name__ == '__main__':
plt.rcParams['font.sans-serif'] = 'SimHei' # 黑体
plt.rcParams['axes.unicode_minus'] = False # 显示负号
data, target = get_data()
train, test = split_data(data, target)
draw_sigmoid()
iters = 1000 # 迭代次数
alpha = 0.5 # 学习率
theta_list = []
for i in range(data.shape[1]):
theta = np.zeros((1, data.shape[1]))
theta_list.append(theta)
final_theta_list = []
target_list = list(set(target))
for i in target_list:
theta_i = get_final_theta(train, i, theta_list[target_list.index(i)], iters, alpha)
final_theta_list.append(theta_i)
y_predict = predict(test, final_theta_list)
# 查看预测准确度
print(classification_report(y_predict, test[:, -1]))