AlexNet识别mnist手写数据集

AlexNet识别mnist手写数据集

  • 一、下载mnist手写数据集
  • 二、Alexnet实现
  • 三、结果
  • 四、总结

一、下载mnist手写数据集

链接:https://pan.baidu.com/s/1wne6e2rCivNdF4wS6MSBoQ
提取码:220w
复制这段内容后打开百度网盘手机App,操作更方便哦

二、Alexnet实现

用的是tensoflow1.x实现的AlexNet,简单的实现了一下
代码如下:

import tensorflow as tf
import numpy as np
import  matplotlib.pyplot as plt
import os

#导入mnist数据集

from tensorflow.examples.tutorials.mnist import  input_data
mnist=input_data.read_data_sets("D:/神经网络/Alexnet/MNIST_data/",one_hot=True)


#定义网络超参数
learning_rate=0.001
training_iters=1000  #训练次数
batch_size=128
display_step=10
#定义网络参数

n_input=784   #输入维度
n_classes=10  #标记维度  0-9
dropout=0.85   #dropout概率s


#占位符
x=tf.placeholder(tf.float32,[None,n_input])
y=tf.placeholder(tf.float32,[None,n_classes])
keep_prob=tf.placeholder(tf.float32)  

#构建模型

#卷积操作
def conv2d(x,W,b,strides=1):
    x=tf.nn.conv2d(x,W, strides=[1,strides,strides,1], padding='SAME')
    x=tf.nn.bias_add(x, b)
    return tf.nn.relu(x)
#池化操作
def maxpool2d(x,k=2):
    return tf.nn.max_pool(x,ksize=[1,k,k,1],strides=[1,k,k,1],padding='SAME')



weights={
     
    'wc1':tf.Variable(tf.random_normal([11,11,1,96])),
    'wc2':tf.Variable(tf.random_normal([5,5,96,256])),
    'wc3':tf.Variable(tf.random_normal([3,3,256,384])),
    'wc4':tf.Variable(tf.random_normal([3,3,384,384])),
    'wc5':tf.Variable(tf.random_normal([3,3,384,256])),
    'wd1':tf.Variable(tf.random_normal([4*4*256,4096])),
    'wd2':tf.Variable(tf.random_normal([4096,1024])),
    'out':tf.Variable(tf.random_normal([1024,n_classes]))
    
    }
biases={
     
        'bc1':tf.Variable(tf.random_normal([96])),
        'bc2':tf.Variable(tf.random_normal([256])),
        'bc3':tf.Variable(tf.random_normal([384])),
        'bc4':tf.Variable(tf.random_normal([384])),
        'bc5':tf.Variable(tf.random_normal([256])),
        'bd1':tf.Variable(tf.random_normal([4096])),
        'bd2':tf.Variable(tf.random_normal([1024])),
        'out':tf.Variable(tf.random_normal([n_classes]))
        }
def Alexnet(x,weights,biases,dropout):
    x=tf.reshape(x,shape=[-1,28,28,1])
    conv1=conv2d(x, weights['wc1'], biases['bc1'])
    conv1=maxpool2d(conv1,k=2)
    conv2=conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2=maxpool2d(conv2,k=2)
    conv3=conv2d(conv2,weights['wc3'],biases['bc3'])
    conv4=conv2d(conv3,weights['wc4'],biases['bc4'])
    conv5=conv2d(conv4,weights['wc5'],biases['bc5'])
    conv5=maxpool2d(conv5,k=2)
    fc1=tf.reshape(conv5,[-1,weights['wd1'].get_shape().as_list()[0]])
    fc1 =tf.add(tf.matmul(fc1, weights['wd1']),biases['bd1'])
    fc1=tf.nn.relu(fc1)
    fc1=tf.nn.dropout(fc1,dropout)
    fc2=tf.reshape(fc1,[-1,weights['wd2'].get_shape().as_list()[0]])
    fc2 =tf.add(tf.matmul(fc2, weights['wd2']),biases['bd2'])
    fc2=tf.nn.relu(fc2)
    fc2=tf.nn.dropout(fc2,dropout)
    out=tf.add(tf.matmul(fc2,weights['out']),biases['out'])
    return out

pred=Alexnet(x, weights, biases, dropout)
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
correct_pred=tf.equal(tf.arg_max(pred,1),tf.arg_max(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    step=1
    while step<=training_iters:
        batch_x,batch_y=mnist.train.next_batch(batch_size)
        sess.run(optimizer,feed_dict={
     x:batch_x,y:batch_y})
        if step%display_step==0:
            loss,acc=sess.run([cost,accuracy],feed_dict={
     x:batch_x,y:batch_y,keep_prob:1.})
            print('Iter:'+str(step)+"Loss:"+"{:.6f}".format(loss)+"Accuracy:"+"{:.6f}".format(acc))
        step+=1
        
    print("Optimizer Finished!")
    print("Testing Accuracy",sess.run(accuracy,feed_dict={
     x:mnist.test.images[:256],y:mnist.test.labels[:256],keep_prob:1.}))

三、结果

训练了1000次得到的精度为0.96484375,没有仔细调参,调一下精度应该会更高,结果如下
AlexNet识别mnist手写数据集_第1张图片

四、总结

AlexNet可以说是开山之作,各位同学有必要仔细研究好好研究,尝试自己修改参数实现其他网络并且测试其他的数据集如猫狗数据集,我的另一篇博文中是使用VGGNet识别猫狗数据集,链接如下
VGGnet识别猫狗数据集(猫狗大战)
感兴趣的同学可以copy一下代码研究一下,此外由于写的匆忙没有详细说明,有不懂的可以评论我都会一一解答。

你可能感兴趣的:(神经网络,tensorflow,深度学习,神经网络,卷积神经网络,机器学习)