深度学习入门——MNIST手写数字识别之神经网络的实现

  • 神经网络的输入层有784个神经元(因为图像像素28*28)
  • 输出层有10个神经元(数字0~9有十个类别)
  • 第一个隐藏层:50个神经元**(任意设置的值)**
  • 第二个隐藏层:100个神经元**(任意设置的值)**

neuralnet_mnist.py(假设学习过程已经结束)

# coding: utf-8
import pickle
import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist


# --------------------------------激活函数定义---------------------------------------------------------------------------


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def softmax(a):
    """解决softmax函数的溢出问题,利用c(输入的最大值),softmax函数减去这个最大值保证数据不溢出,softmax函数运算时加上或者
    减去某个常数并不会改变运算的结果"""
    c = np.max(a)
    exp_a = np.exp(a - c) # 溢出对策
    sum_exp_a = np.sum(exp_a)
    y = exp_a / sum_exp_a
    return y

# --------------------------神经网络相关函数----------------------------------------------------------------------------
# 假设学习过程已经结束了 有了sample_weight.pkl

def get_data():
    """获取mnist数据
    (训练图像,训练标签),(测试图像,测试标签)"""
    # normalize正则化,flatten一维化,one_hot化
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    """初始化网络"""
    with open("sample_weight.pkl", 'rb') as f: # sample_weight.pkl文件中保存了已经学习后的权重和偏置的参数
        network = pickle.load(f)

    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)

    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)

    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data() # 获取MNIST数据
network = init_network() # 生成网络

accuracy_cnt = 0
for i in range(len(x)):
    """取出保存在图像中的数据"""
    y = predict(network, x[i])

    p = np.argmax(y) # 获取概率最高的元素的索引

    if p == t[i]: # 预测值与标签值 看是否一样
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt / len(x)))) # 识别精度

预处理:
对神经网络的输入数据进行某种限定,例如利用数据的整体均值或者标准差,移动数据。使数据以0为中心分布。

  • 正规化:将数据的延展控制在某个范围内
  • 数据白化(whitening):将数据整体的分布均匀化的方法

你可能感兴趣的:(深度学习入门,python)