TensorFlow Learning Notes (3): Implementing a Logistic Regression Model in TensorFlow

I. Introduction to the MNIST Dataset and How to Obtain It

1. Introduction to the MNIST Dataset

Overview

The dataset contains 60,000 training examples (mnist.train) and 10,000 test examples (mnist.test); read_data_sets additionally splits 5,000 of the training examples off into a validation set (mnist.validation), which is why mnist.train shows up with 55,000 images in the output below.
Each MNIST example has two parts: an image of a handwritten digit and its corresponding label. We refer to the images as "xs" and the labels as "ys". Both the training set and the test set provide xs and ys; for example, the training images are mnist.train.images and the training labels are mnist.train.labels.

Official tutorial: http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

Downloading and Reading the Dataset

import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Here input_data.py is the helper script from the official tutorial. With TensorFlow 1.x the same module ships with TensorFlow itself and can be imported directly, as in the runnable example below.

Running it

from tensorflow.examples.tutorials.mnist import input_data

# one_hot: one-hot encoding, i.e. exactly one position is 1 and all the others are 0
mnist = input_data.read_data_sets("data/", one_hot=True)

train_images = mnist.train.images
train_labels = mnist.train.labels
test_images = mnist.test.images
test_labels = mnist.test.labels

print("train_images_shape:", train_images.shape)
print("train_labels_shape:", train_labels.shape)
print("test_images_shape:", test_images.shape)
print("test_labels_shape:", test_labels.shape)
print("train_images:", train_images[0]) # 获取55000张里第一张
print("train_images_length:",len(train_images[0]))
print("train_labels:", train_labels[0])

Output

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
train_images_shape: (55000, 784) # 55,000 training images, each 784 = 28*28 pixels
train_labels_shape: (55000, 10) # 10 columns, one per digit 0-9
test_images_shape: (10000, 784) # 10,000 test images: 10,000 rows, 784 columns
test_labels_shape: (10000, 10)
train_images: [ 0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.38039219  0.37647063
  0.3019608   0.46274513  0.2392157   0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.35294119  0.5411765
  0.92156869  0.92156869  0.92156869  0.92156869  0.92156869  0.92156869
  0.98431379  0.98431379  0.97254908  0.99607849  0.96078438  0.92156869
  0.74509805  0.08235294  0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.54901963  0.98431379  0.99607849  0.99607849  0.99607849  0.99607849
  0.99607849  0.99607849  0.99607849  0.99607849  0.99607849  0.99607849
  0.99607849  0.99607849  0.99607849  0.99607849  0.74117649  0.09019608
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.88627458  0.99607849  0.81568635
  0.78039223  0.78039223  0.78039223  0.78039223  0.54509807  0.2392157
  0.2392157   0.2392157   0.2392157   0.2392157   0.50196081  0.8705883
  0.99607849  0.99607849  0.74117649  0.08235294  0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.14901961  0.32156864  0.0509804   0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.13333334  0.83529419  0.99607849  0.99607849  0.45098042  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.32941177  0.99607849  0.99607849  0.91764712  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.32941177  0.99607849  0.99607849  0.91764712  0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.41568631  0.6156863   0.99607849  0.99607849  0.95294124  0.20000002
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.09803922  0.45882356  0.89411771
  0.89411771  0.89411771  0.99215692  0.99607849  0.99607849  0.99607849
  0.99607849  0.94117653  0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.26666668  0.4666667   0.86274517
  0.99607849  0.99607849  0.99607849  0.99607849  0.99607849  0.99607849
  0.99607849  0.99607849  0.99607849  0.55686277  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.14509805  0.73333335  0.99215692
  0.99607849  0.99607849  0.99607849  0.87450987  0.80784321  0.80784321
  0.29411766  0.26666668  0.84313732  0.99607849  0.99607849  0.45882356
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.44313729
  0.8588236   0.99607849  0.94901967  0.89019614  0.45098042  0.34901962
  0.12156864  0.          0.          0.          0.          0.7843138
  0.99607849  0.9450981   0.16078432  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.66274512  0.99607849  0.6901961   0.24313727  0.          0.
  0.          0.          0.          0.          0.          0.18823531
  0.90588242  0.99607849  0.91764712  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.07058824  0.48627454  0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.32941177  0.99607849  0.99607849  0.65098041  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.54509807  0.99607849  0.9333334   0.22352943  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.82352948  0.98039222  0.99607849  0.65882355  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.94901967  0.99607849  0.93725497  0.22352943  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.34901962  0.98431379  0.9450981   0.33725491  0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.01960784  0.80784321  0.96470594  0.6156863   0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.01568628  0.45882356  0.27058825  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.        ]  # one pair of square brackets encloses one whole image (one row of train_images)
train_images_length: 784
train_labels: [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]  # one-hot label: this image is the digit 7
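
Each row of train_images is just a 28x28 image flattened into 784 grayscale values in [0, 1]. A minimal sketch for viewing one sample as an image (assuming matplotlib is installed; it reuses train_images and train_labels from the snippet above):

import matplotlib.pyplot as plt

# Reshape the flat 784-vector back into a 28x28 grid and show it as a grayscale image.
img = train_images[0].reshape(28, 28)
plt.imshow(img, cmap="gray")
plt.title("label: {}".format(train_labels[0].argmax()))  # argmax turns the one-hot label back into a digit
plt.show()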

II. The softmax Function and Basic TensorFlow Syntax

Logistic regression is a special case of softmax regression: logistic regression handles binary classification, while softmax handles multi-class classification.
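For reference, softmax turns a vector of scores z into a probability distribution: softmax(z)_i = exp(z_i) / sum_j exp(z_j), so every output is positive and they sum to 1. A quick NumPy sketch (the scores below are made up purely for illustration):

import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))  # subtract the max first for numerical stability
    return e / e.sum()

scores = np.array([2.0, 1.0, 0.1])
print(softmax(scores))  # approximately [0.659, 0.242, 0.099], which sums to 1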

import tensorflow as tf
import numpy as np

# Placeholders are used when the actual values are not known until run time
x = tf.placeholder(tf.float32, shape=(4, 4))  # x will be fed a 4x4 float matrix
y = tf.add(x, x)  # y = x + x
# [1,  32, 44, 56]
# [89, 12, 90, 33]
# [35, 69, 1,  10]
argmax_paramter = tf.Variable([[1, 32, 44, 56], [89, 12, 90, 33], [35, 69, 1, 10]])  # tf.Variable creates a variable

# index of the maximum value along each column
argmax_0 = tf.argmax(argmax_paramter, 0)  # axis 0: max index within each column; axis 1: max index within each row
# index of the maximum value along each row
argmax_1 = tf.argmax(argmax_paramter, 1)

# mean values
reduce_0 = tf.reduce_mean(argmax_paramter, reduction_indices=0)  # column-wise mean; reduction_indices is the older alias of axis
reduce_1 = tf.reduce_mean(argmax_paramter, reduction_indices=1)  # row-wise mean (integer results, since the tensor is int32)

# equality test
equal_0 = tf.equal(1, 2)  # whether the two values are equal: True or False
equal_1 = tf.equal(2, 2)

# type conversion
cast_0 = tf.cast(equal_0, tf.int32)    # value to convert first, target dtype second
cast_1 = tf.cast(equal_1, tf.float32)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)

    rand_array = np.random.rand(4, 4)
    print(sess.run(y, feed_dict={x: rand_array}))  # feed the random matrix to the placeholder x

    print("argmax_0:", sess.run(argmax_0))
    print("argmax_1:", sess.run(argmax_1))
    print("reduce_0:", sess.run(reduce_0))
    print("reduce_1:", sess.run(reduce_1))
    print("equal_0:", sess.run(equal_0))
    print("equal_1:", sess.run(equal_1))
    print("cast_0:", sess.run(cast_0))
    print("cast_1:", sess.run(cast_1))

Output

[[ 0.53915852  0.77000558  1.52799249  1.89688838]
 [ 1.51494813  0.71548402  0.76938975  0.32049751]
 [ 0.48544651  0.25060761  1.42048228  0.97588754]
 [ 0.05079511  1.94627345  1.63919222  1.34640694]]
argmax_0: [1 2 1 0]
argmax_1: [3 2 1]
reduce_0: [41 37 45 33]
reduce_1: [33 56 28]
equal_0: False
equal_1: True
cast_0: 0
cast_1: 1.0
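
These small ops are the building blocks of the accuracy calculation used in Part III: tf.argmax picks the predicted class, tf.equal compares it with the true class, tf.cast turns the booleans into 0/1 floats, and tf.reduce_mean averages them. A minimal sketch with made-up predictions and labels:

import tensorflow as tf

# Hypothetical predicted probabilities (one row per example) and one-hot labels.
pred   = tf.constant([[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]])
labels = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

correct  = tf.equal(tf.argmax(pred, 1), tf.argmax(labels, 1))  # [True, False]
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))        # mean of [1.0, 0.0] = 0.5

with tf.Session() as sess:
    print(sess.run(accuracy))  # 0.5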

III. Full Implementation

import tensorflow as tf

# Load the dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("data/", one_hot=True)

# Training parameters
batch_size = 100

# Training inputs: x (images) and y (labels)
# x = tf.Variable()
# y = tf.Variable()
x = tf.placeholder(tf.float32, [None, 784])  # each image is a flat vector of 784 = 28*28 pixels
y = tf.placeholder(tf.float32, [None, 10])   # 10 classes, the digits 0-9
# With a large training set, holding all the data in Variables would waste memory, so we use
# placeholders and feed one batch (100 images here) at a time; the batch dimension is left as
# None because the batch size is not fixed in advance.

# Model parameters
# [55000, 784] * W = [55000, 10], so W must have shape [784, 10]
W = tf.Variable(tf.zeros([784, 10]))  # x ([N, 784]) times W ([784, 10]) yields logits of shape [N, 10]
b = tf.Variable(tf.zeros([10]))       # b is broadcast across the rows when added

# Build the logistic (softmax) regression model
pred = tf.nn.softmax(tf.matmul(x, W) + b)  # TensorFlow provides softmax in the tf.nn module
# pred holds the predicted label probabilities

# Loss function (cross-entropy)
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), 1))

# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

# Initialize all variables
init = tf.global_variables_initializer()

# Launch the graph in a session
with tf.Session() as sess:
    sess.run(init)  # run the variable initializer

    # Start training
    for epoch in range(25):  # 25 epochs, chosen arbitrarily
        avg_cost = 0.

        total_batch = int(mnist.train.num_examples/batch_size)  # number of batches per epoch
        for i in range(total_batch):  # train one batch at a time
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(optimizer, {x: batch_xs, y: batch_ys})
            # accumulate the average loss over all batches in this epoch
            avg_cost += sess.run(cost, {x: batch_xs, y: batch_ys}) / total_batch
        if (epoch+1) % 5 == 0:
            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))

    print("Training finished")

    # Evaluate accuracy on the test set
    correct = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

Output

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
Epoch: 0005 cost= 0.463807613
Epoch: 0010 cost= 0.390929620
Epoch: 0015 cost= 0.361378004
Epoch: 0020 cost= 0.344156513
Epoch: 0025 cost= 0.332462255
Training finished
Accuracy: 0.9132
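
As a follow-up, you can also look at individual predictions while the session is still open. A short sketch (it would go inside the same with tf.Session() block, after the accuracy print, and reuses pred, x, and mnist from the code above):

    # Predict the first 5 test images and compare against the true labels.
    probs = sess.run(pred, {x: mnist.test.images[:5]})
    print("predicted digits:", probs.argmax(axis=1))
    print("true digits:     ", mnist.test.labels[:5].argmax(axis=1))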
