TensorFlow的逻辑回归实现

转载请注明出处:http://blog.csdn.net/gamer_gyt
博主微博:http://weibo.com/234654758
Github:https://github.com/thinkgamer
个人网站:http://thinkgamer.github.io


逻辑回归我们都知道是用来进行二分类处理的,里边经常用到的阶跃函数是海维塞得阶跃函数(Sigmoid函数)。

逻辑回归简介

线性回归能对连续值结果进行预测,而现实生活中常见的另外一类问题是,分类问题。最简单的情况是是与否的二分类问题。比如说医生需要判断病人是否生病,银行要判断一个人的信用程度是否达到可以给他发信用卡的程度,邮件收件箱要自动对邮件分类为正常邮件和垃圾邮件等等。

当然,我们最直接的想法是,既然能够用线性回归预测出连续值结果,那根据结果设定一个阈值是不是就可以解决这个问题了呢?然后在大多数情况下需要学习的分类数据并没有那么精准,这个时候阈值的设定就没卵用了,这时候就需要逻辑回归了,逻辑回归的核心思想就是通过对线性回归的计算结果进行一个映射,使之输出的结果为0~1之间的概率值。

这个时候就需要一个单位阶跃函数,常使用的就是Sigmoid函数。

![这里写图片描述](https://img-blog.csdn.net/20180423232907771?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0dhbWVyX2d5dA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70)
求导之后为:
![这里写图片描述](https://img-blog.csdn.net/20180423232953557?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0dhbWVyX2d5dA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70)
对应的图像为:
![这里写图片描述](https://img-blog.csdn.net/20180423233017375?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0dhbWVyX2d5dA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70)

更多关于逻辑回归的介绍可以参考之前的两篇文章:

  • 《机器学习实战》Logistic回归算法(1)
  • 《机器学习实战》Logisic回归算法(2)之从疝气病症预测病马的死亡率

关于Softmax Regression

参考之前转载的一篇文章:Softmax Regression

TF中基于LR的多分类实现

# coding: utf-8
'''
create by: Thinkgamer
create time: 2018/04/22
desc: 使用tensorflow创建逻辑回归模型 ,分类
'''
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# 加载mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
print("load finish")

mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
print(type(mnist))

train_img = mnist.train.images
train_label = mnist.train.labels
print("训练集类型:",type(train_img))
print("训练集维度:",train_img.shape)

test_img = mnist.test.images
test_label = mnist.test.labels
print("测试集类型:",type(test_img))
print("测试集维度:",test_img.shape)
print(test_label[0])
# [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] 表示数字7

输出为:

load finish
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

训练集类型: 
训练集维度: (55000, 784)
测试集类型: 
测试集维度: (10000, 784)
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])

W = tf.Variable(tf.zeros([784,10]))  # 表示权重,784个维度,10种类别
b = tf.Variable(tf.zeros([10]))      # 表示 是10分类,这里选用0值初始化,一般采用高斯初始化

# 表示预测值结果
actv = tf.nn.softmax(tf.matmul(x,W)+b)

# 构造损失函数
# 损失函数 -log p, 各个维度损失函数求和之后,求均值
cost= tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))    

# Optimizer 指定优化器 梯度下降
learning_rate = 0.1
optm = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
sess = tf.InteractiveSession()
# 函数学习 tf.argmax 返回最大值的索引
arr = np.array([
    [1,2,3,4,5],
    [6,7,8,9,10]
])
print(tf.shape(arr).eval())
print(tf.argmax(arr,0).eval())  #eval(session=sess))  0 是按列 1 是按行
print(tf.argmax(arr,1).eval()) 
# 预测
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))

# 准确率
accr = tf.reduce_mean(tf.cast(pred,"float"))

# init
init = tf.global_variables_initializer()
training_epochs = 50  # 样本迭代次数
batch_size = 100 # 每次迭代使用的样本
display_step = 50

sess = tf.Session()
sess.run(init)

# MINI-BATCH LEARENING
for epoch in range(training_epochs):
    avg_cost = 0.
    number_batch = int(mnist.train.num_examples/batch_size)
    for i in range(number_batch):
        batch_xs,batch_ys = mnist.train.next_batch(batch_size)
        sess.run(optm,feed_dict={x:batch_xs,y:batch_ys})
        feeds = {x: batch_xs,y: batch_ys}
        avg_cost += sess.run(cost,feed_dict=feeds)/number_batch
    # DISPLAY
    if epoch % display_step ==0 :
        feeds_train = {x:batch_xs,y:batch_ys}
        feeds_test = {x: mnist.test.images,y:mnist.test.labels}
        train_acc = sess.run(accr,feed_dict=feeds_train)
        test_acc = sess.run(accr,feed_dict=feeds_test)
        print(("Epoch: %03d/%03d cost:%.9f train_acc: %.3f test_acc: %.3f") % (epoch,training_epochs,avg_cost,train_acc,test_acc))
print("Done")

TensorFlow的逻辑回归实现_第1张图片
打开微信扫一扫,关注微信公众号【搜索与推荐Wiki】

你可能感兴趣的:(机器学习(Python),机器学习)