TensorFlow 实现 LeNet-5

LeNet-5 结构及 pytorch、tensorflow、keras(tf)、paddle 实现 mnist 手写数字识别

环境

python3.6, tensorflow-gpu 1.12.0, numpy, opencv-python (cv2)

代码

# -*- coding: utf-8 -*- 
# @Time : 2020/1/18 14:59 
# @Author : Zhao HL
# @File : lenet5_tf.py 
import sys,cv2,os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
tf.logging.set_verbosity(tf.logging.ERROR)

# region parameters
# region paths
Data_path = "./data/"
TestData_path = Data_path + 'pic/'  # folder of images classified by inference()
Model_path = 'model/'
Model_file_tf = Model_path + "lenet5_tf.ckpt"
Model_file_keras = Model_path + "lenet5_keras.h5"
Model_file_torch = Model_path + "lenet5_torch.pth"
Model_file_paddle = Model_path + "lenet5_paddle.model"
# endregion

# region image parameter
Img_size = 28
Img_chs = 1
Label_size = 1
Labels_classes = 10
# endregion

# region net parameter
Conv1_kernel_size = 5
Conv1_chs = 6
Conv2_kernel_size = 5
Conv2_chs = 16
Conv3_kernel_size = 5
Conv3_chs = 120
# conv3 reduces the 5x5 feature map to 1x1, so the flattened vector has Conv3_chs values.
Flatten_size = Conv3_chs
Fc1_size = 84
Fc2_size = Labels_classes
# endregion

# region hyperparameter
Learning_rate = 1e-3
Batch_size = 64
Buffer_size = 256
Infer_size = 1
Epochs = 6
# input_data.read_data_sets() holds out 5000 of the 60000 MNIST training images as
# its validation split (validation_size=5000 by default), so the train split has
# 55000 samples and the validation split 5000.
Train_num = 55000
Train_batch_num = Train_num // Batch_size
Val_num = 5000
Val_batch_num = Val_num // Batch_size
# endregion

# endregion


class Lenet5:
    """LeNet-5 classifier built as a TensorFlow 1.x static graph.

    Attributes set in ``__init__``:
        img_src: float32 placeholder of flattened images, shape [None, Img_size * Img_size]
        label:   float32 placeholder of one-hot labels, shape [None, Labels_classes]
        image:   ``img_src`` reshaped to NHWC, shape [-1, Img_size, Img_size, Img_chs]
        predict: softmax class probabilities, shape [None, Fc2_size]
    """

    def __init__(self, structShow=False):
        """Create the input placeholders and build the network in the default graph.

        :param structShow: if True, print the name and static shape of each layer output
        """
        self.structShow = structShow
        self.img_src = tf.placeholder(tf.float32, [None, Img_size * Img_size])
        self.label = tf.placeholder(tf.float32, [None, Labels_classes])
        self.image = tf.reshape(self.img_src, [-1, Img_size, Img_size, Img_chs])
        self.predict = self.get_lenet5()

    def get_w(self, shape):
        """Return a weight variable initialised from a truncated normal (stddev 0.1)."""
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='w')

    def get_b(self, shape):
        """Return a zero-initialised bias variable."""
        return tf.Variable(tf.zeros(shape), name='b')

    def get_lenet5(self):
        """Build conv-pool-conv-pool-conv-fc-fc on ``self.image``.

        The shape comments below assume Img_size == 28.

        :return: softmax output tensor of shape [None, Fc2_size]
        """
        with tf.name_scope('conv1'):
            # 5x5 conv with SAME padding keeps 28x28 -> [N, 28, 28, Conv1_chs]
            conv1_w = self.get_w([Conv1_kernel_size, Conv1_kernel_size, Img_chs, Conv1_chs])
            conv1_b = self.get_b([Conv1_chs])
            conv1 = tf.nn.conv2d(self.image, conv1_w, strides=[1, 1, 1, 1], padding='SAME')
            relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_b))

        with tf.name_scope('pool1'):
            # 2x2 max-pool, stride 2 -> [N, 14, 14, Conv1_chs]
            pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool1')

        with tf.name_scope('conv2'):
            # 5x5 conv, VALID -> [N, 10, 10, Conv2_chs]
            conv2_w = self.get_w([Conv2_kernel_size, Conv2_kernel_size, Conv1_chs, Conv2_chs])
            conv2_b = self.get_b([Conv2_chs])
            conv2 = tf.nn.conv2d(pool1, conv2_w, strides=[1, 1, 1, 1], padding='VALID')
            relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_b))

        with tf.name_scope('pool2'):
            # 2x2 max-pool, stride 2 -> [N, 5, 5, Conv2_chs]
            pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='pool2')

        with tf.name_scope('conv3'):
            # 5x5 conv, VALID on a 5x5 map -> [N, 1, 1, Conv3_chs]
            conv3_w = self.get_w([Conv3_kernel_size, Conv3_kernel_size, Conv2_chs, Conv3_chs])
            conv3_b = self.get_b([Conv3_chs])
            conv3 = tf.nn.conv2d(pool2, conv3_w, strides=[1, 1, 1, 1], padding='VALID')
            relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_b))

        with tf.name_scope('fc1'):
            # The 1x1xConv3_chs map flattens to Flatten_size features.
            relu3_reshape = tf.reshape(relu3, [-1, Flatten_size])
            fc1_w = self.get_w([Flatten_size, Fc1_size])
            fc1_b = self.get_b([Fc1_size])
            fc1 = tf.matmul(relu3_reshape, fc1_w)
            relu4 = tf.nn.relu(fc1 + fc1_b)

        with tf.name_scope('fc2'):
            fc2_w = self.get_w([Fc1_size, Fc2_size])
            fc2_b = self.get_b([Fc2_size])
            fc2 = tf.matmul(relu4, fc2_w)
            output = tf.nn.softmax(fc2 + fc2_b)

        if self.structShow:
            # Print each tensor's own name together with its own shape.
            for tensor in (relu1, pool1, relu2, pool2, relu3, relu4, output):
                print(tensor.name, tensor.shape)
        return output

def _run_batches(sess, dataset, batch_num, fetches, feeds, prefix):
    """Run ``fetches`` on ``batch_num`` mini-batches of ``dataset`` and show progress.

    :param sess: active tf.Session
    :param dataset: split exposing ``next_batch(Batch_size)`` -> (images, labels)
    :param batch_num: number of mini-batches to run
    :param fetches: list whose first two entries are the accuracy and loss tensors;
        any further entries (e.g. the optimizer step) are run but their results ignored
    :param feeds: (image placeholder, label placeholder) pair each batch is fed into
    :param prefix: progress-bar prefix, e.g. 'train:' or 'val:'
    :return: (mean accuracy, mean loss) over the batches
    """
    image_src, label = feeds
    sum_acc, sum_loss = 0.0, 0.0
    for batch in range(batch_num):
        images, labels = dataset.next_batch(Batch_size)
        acc, batch_loss = sess.run(fetches, feed_dict={image_src: images, label: labels})[:2]
        process_show(batch + 1, batch_num, acc, batch_loss, prefix=prefix)
        sum_acc += acc
        sum_loss += batch_loss
    return sum_acc / batch_num, sum_loss / batch_num


def train():
    """Train LeNet-5 on MNIST, keeping the checkpoint with the lowest validation loss.

    Data comes from ``input_data.read_data_sets(Data_path, one_hot=True)``. If a
    checkpoint exists in the directory of ``Model_file_tf``, training resumes from it.
    After each epoch, the mean train/validation accuracy and loss are printed. The model
    is saved to ``Model_file_tf`` whenever the mean validation loss improves.
    """
    mnist_data = input_data.read_data_sets(Data_path, one_hot=True)
    net = Lenet5(structShow=True)
    image_src, label, predict = net.img_src, net.label, net.predict

    # Cross-entropy on the softmax output; the epsilon keeps log() away from 0.
    loss = tf.reduce_mean(-tf.reduce_sum(label * tf.log(predict + 1e-10), axis=1))
    run_step = tf.train.AdamOptimizer(Learning_rate).minimize(loss)

    correct = tf.equal(tf.argmax(predict, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    # Size each pass by the splits actually loaded, so one epoch covers each split once.
    train_batch_num = mnist_data.train.num_examples // Batch_size
    val_batch_num = mnist_data.validation.num_examples // Batch_size

    # Saver.save refuses to write into a missing directory, and get_checkpoint_state
    # expects the checkpoint *directory*, not the checkpoint file prefix.
    model_dir = os.path.dirname(Model_file_tf) or '.'
    os.makedirs(model_dir, exist_ok=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        best_loss = float("inf")
        best_loss_epoch = 0
        feeds = (image_src, label)

        for epoch in range(Epochs):
            print('Epoch %d/%d:' % (epoch + 1, Epochs))
            train_acc, train_loss = _run_batches(
                sess, mnist_data.train, train_batch_num, [accuracy, loss, run_step], feeds, 'train:')
            val_acc, val_loss = _run_batches(
                sess, mnist_data.validation, val_batch_num, [accuracy, loss], feeds, 'val:')
            print('average summary:\ntrain acc %.4f, loss %.4f ; val acc %.4f, loss %.4f'
                  % (train_acc, train_loss, val_acc, val_loss))

            if val_loss < best_loss:
                print('val_loss improve from %.4f to %.4f, model save to %s ! \n' % (
                    best_loss, val_loss, Model_file_tf))
                best_loss = val_loss
                best_loss_epoch = epoch + 1
                saver.save(sess=sess, save_path=Model_file_tf)
            else:
                print('val_loss do not improve from %.4f \n' % (best_loss))
        print('best loss %.4f at epoch %d \n' % (best_loss, best_loss_epoch))

def inference(infer_path=TestData_path, model_path=Model_file_tf):
    '''
    Restore a trained LeNet-5 and print the predicted digit for each image in a folder.

    :param infer_path: directory containing the images to classify
    :param model_path: checkpoint prefix to restore (as written by train())
    :return: None; one prediction line is printed per readable image
    '''
    tf.reset_default_graph()
    net = Lenet5()

    image, predict = net.image, net.predict
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)
        print('get model from', model_path)

        # sorted() gives a reproducible order; os.listdir order is arbitrary.
        for image_name in sorted(os.listdir(infer_path)):
            # os.path.join also works when infer_path lacks a trailing separator.
            image_file = os.path.join(infer_path, image_name)
            if not os.path.isfile(image_file):
                continue
            try:
                img = load_image(image_file)
            except (ValueError, cv2.error) as err:
                # Skip non-image files instead of aborting the whole run.
                print("{} skipped: {}".format(image_name, err))
                continue
            # The reshaped `image` tensor is fed directly with the NHWC batch from load_image.
            probs = sess.run(predict, feed_dict={image: img})
            pre = np.argmax(probs[0])
            print("{} predict result {}".format(image_name, pre))




def process_show(num, nums, train_acc, train_loss, prefix='', suffix=''):
    """Redraw a one-line progress bar for batch ``num`` of ``nums``.

    The line starts with a carriage return, so it overwrites the previous bar. A newline
    is printed once ``num`` reaches ``nums``.

    :param num: 1-based index of the current batch
    :param nums: total number of batches
    :param train_acc: accuracy of the current batch (train or validation pass)
    :param train_loss: loss of the current batch (train or validation pass)
    :param prefix: text before the bar, e.g. 'train:' or 'val:'
    :param suffix: text after the percentage
    """
    # Round the percentage itself: round(rate, 2) * 100 can land just below an
    # integer (0.29 * 100 == 28.999...), which int() then truncates.
    percent = int(round(100 * num / nums)) if nums else 100
    filled = percent // 2  # the bar is 50 characters wide
    bar = '\r%s batch %3d/%d: accuracy %.4f, loss %.4f [%s%s]%.1f%% %s; ' % (
        prefix, num, nums, train_acc, train_loss, '#' * filled, '_' * (50 - filled), percent, suffix)
    sys.stdout.write(bar)
    sys.stdout.flush()
    if num >= nums:
        print()

def load_image(file):
    """Read an image as a grayscale, resized, [0, 1]-scaled batch for the network.

    :param file: path of the image file
    :return: float32 array of shape (Infer_size, Img_size, Img_size, Img_chs),
        pixel values divided by 255
    :raises ValueError: if OpenCV cannot read the file as an image
    """
    img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError("cannot read image file: {}".format(file))
    img = cv2.resize(img, (Img_size, Img_size))
    return img.reshape(Infer_size, Img_size, Img_size, Img_chs).astype(np.float32) / 255.0

if __name__ == '__main__':
    # Train first (saving the best checkpoint to Model_file_tf), then classify
    # the images in TestData_path with that checkpoint.
    train()
    inference()

 

你可能感兴趣的:(TensorFlow,DL-Code,lenet5,tensorflow)