神大学习入门——MNIST手写字识别之神经网络实现——批处理

神大学习入门——MNIST手写字识别之神经网络实现——批处理_第1张图片

神大学习入门——MNIST手写字识别之神经网络实现——批处理_第2张图片

neural_network.py(加上批量数据)


# coding: utf-8
import pickle
import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist


# --------------------------------激活函数定义---------------------------------------------------------------------------


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def softmax(a):
    """解决softmax函数的溢出问题,利用c(输入的最大值),softmax函数减去这个最大值保证数据不溢出,softmax函数运算时加上或者
    减去某个常数并不会改变运算的结果"""
    c = np.max(a)
    exp_a = np.exp(a - c) # 溢出对策
    sum_exp_a = np.sum(exp_a)
    y = exp_a / sum_exp_a
    return y

# --------------------------神经网络相关函数----------------------------------------------------------------------------
# 假设学习过程已经结束了 有了sample_weight.pkl

def get_data():
    """获取mnist数据
    (训练图像,训练标签),(测试图像,测试标签)"""
    # normalize正则化,flatten一维化,one_hot化
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    """初始化网络"""
    with open("sample_weight.pkl", 'rb') as f: # sample_weight.pkl文件中保存了已经学习后的权重和偏置的参数
        network = pickle.load(f)

    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)

    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)

    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data() # 获取MNIST数据
network = init_network() # 生成网络

batch_size = 100 # 批数量

accuracy_cnt = 0
for i in range(len(x)):
    """取出保存在图像中的数据"""
    x_batch = x[i:i+batch_size] # 划分为每一个区间,从每个区间中提取区数据
    y_batch = predict(network, x_batch)

    p = np.argmax(y_batch, axis=1) # 获取概率最高的元素的索引,axis=1表示沿着第一维方向找到最大值(每行最大值),axis=0则表示
    # 按照列找最大值

    accuracy_cnt += np.sum(p == t[i:i+batch_size]) # 判断相等

print("Accuracy:" + str(float(accuracy_cnt / len(x)))) # 识别精度

你可能感兴趣的:(深度学习入门)