Implementing the LeNet-5 Network with TensorFlow 2.0
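The script below imports a helper module picture_read that is not included in this post. The following is a hypothetical, minimal sketch of what picture_read.picture_document_deal might look like, assuming each folder holds grayscale images of a single class that are resized to pic_size; the real helper in the project may differ (for example, it could use OpenCV instead of Pillow).

# picture_read.py -- hypothetical sketch, not part of the original post
import os
import numpy as np
from PIL import Image

def picture_document_deal(folder, pic_size, label):
    """Read every image in `folder`, convert to grayscale, resize to `pic_size`,
    and return (data, labels) where data has shape (N, H, W)."""
    pics = []
    for name in sorted(os.listdir(folder)):
        img = Image.open(os.path.join(folder, name)).convert('L')   # grayscale
        img = img.resize(pic_size)                                   # e.g. (28, 28)
        pics.append(np.array(img))
    data = np.array(pics)                   # (N, 28, 28)
    labels = np.full(len(pics), label)      # one label for the whole folder
    return data, labels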

# Build and test a small CNN model (a LeNet-5 style network)
import picture_read
import numpy as np
import tensorflow as tf
import time

def data_processing(train_document1, train_document2, test_document1, test_document2,    # (training folder 1, training folder 2, test folder 1, test folder 2)
                    pic_size, label1, label2):    # (target image size, label 1, label 2)

    train_data_1, train_label_1 = picture_read.picture_document_deal(train_document1, pic_size, label1)   # sxs -> label 0
    train_data_2, train_label_2 = picture_read.picture_document_deal(train_document2, pic_size, label2)   # xsx -> label 1
    test_data_1, test_label_1 = picture_read.picture_document_deal(test_document1, pic_size, label1)   # sxs -> label 0
    test_data_2, test_label_2 = picture_read.picture_document_deal(test_document2, pic_size, label2)   # xsx -> label 1

    # numpy: vstack stacks the two arrays vertically (adding rows), hstack stacks them horizontally (adding columns)
    train_data = np.vstack((train_data_1,train_data_2))
    train_labels = np.hstack((train_label_1,train_label_2))
    test_data = np.vstack((test_data_1,test_data_2))
    test_labels = np.hstack((test_label_1,test_label_2))
    print('train_data = ', train_data.shape)
    # print('train_labels = ', train_labels, train_labels.shape)

    # Add a channel dimension and normalize every pixel value to [0, 1]
    training_pics = train_data.reshape(train_data.shape[0], train_data.shape[1], train_data.shape[2], 1)
    test_pics = test_data.reshape(test_data.shape[0], test_data.shape[1], test_data.shape[2], 1)
    training_pics = training_pics / 255.0
    test_pics = test_pics / 255.0
    print('training_pics = ', training_pics.shape)
    # print('test_pics = ',test_pics,test_pics.shape)

    print(training_pics.shape[1], training_pics.shape[2], training_pics.shape[3])
    print("training_pics = ", training_pics.shape)
    print("training_labels = ", train_labels.shape)

    return training_pics,test_pics,train_labels,test_labels


def model_structure(shape, classification):   # (shape of the training data, number of output classes)
    print('shape = ',shape)
    # First approach
    # Define the model by stacking layers in order with the Sequential API
    model_def = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(6, (3, 3), activation='relu', input_shape=(shape[1], shape[2], shape[3])),  # C1: 6 feature maps
        tf.keras.layers.MaxPooling2D(2, 2),                           # S2: 2x2 max pooling
        tf.keras.layers.ReLU(),                                       # redundant here: the Conv2D above already applies ReLU
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu'),        # C3: 16 feature maps
        tf.keras.layers.MaxPooling2D(2, 2),                           # S4: 2x2 max pooling
        tf.keras.layers.ReLU(),                                       # redundant, as above
        tf.keras.layers.Flatten(),                                    # flatten to a 1-D feature vector
        tf.keras.layers.Dense(120, activation='relu'),                # F5: fully connected, 120 units
        tf.keras.layers.Dense(84, activation='relu'),                 # F6: fully connected, 84 units
        tf.keras.layers.Dense(classification, activation='softmax')   # output: one probability per class
    ])

    # # Second approach: the same network written with the functional API
    # inputs = tf.keras.layers.Input(shape = shape[1:])
    # x = tf.keras.layers.Conv2D(filters=6, kernel_size=3, strides=1,
    #                                activation='relu', padding='same')(inputs)
    # x = tf.keras.layers.MaxPool2D(pool_size=2, strides=1, padding='same')(x)
    # x = tf.keras.layers.ReLU()(x)
    # x = tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1,
    #                            activation='relu', padding='same')(x)
    # x = tf.keras.layers.MaxPool2D(pool_size=2, strides=1, padding='same')(x)
    # x = tf.keras.layers.ReLU()(x)
    # x = tf.keras.layers.Flatten()(x)
    # x = tf.keras.layers.Dense(120,activation='relu')(x)
    # x = tf.keras.layers.Dense(84,activation='relu')(x)
    # outputs = tf.keras.layers.Dense(classification,activation='softmax')(x)   # one probability per class
    # model_def = tf.keras.Model(inputs = inputs, outputs=outputs)


    # Compile the model: choose the optimizer, loss function and metrics
    model_def.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    return model_def
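
# A quick way to inspect the resulting architecture (optional, not in the
# original post): build the model once with the shape that data_processing()
# returns for 28x28 grayscale images and print the layer summary.
# model_structure((1, 28, 28, 1), classification=2).summary()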

def model_train_save(model, model_file, train_pics, train_labels, test_pics, test_labels):   # (model, file to save it to, training data/labels, test data/labels)

    # Train the model; the per-epoch metrics are kept in `history`,
    # and 20% of the training data is held out for validation
    history = model.fit(x=train_pics, y=train_labels, batch_size=32,
                        epochs=20, validation_split=0.2)  # alternatively: validation_data=(test_pics, test_labels)

    # Evaluate loss and accuracy on the test set
    test_loss = model.evaluate(test_pics, test_labels)
    # print('test_loss = ', test_loss)

    # Save the whole model (architecture, weights and optimizer state)
    model.save(model_file)

    # Delete the in-memory model to show that it can be restored from the file
    del model

def model_predict_labels(model_file, test_pic, test_labels):  # use when the test set has ground-truth labels
    # Restore the saved model and predict with it
    # print("model restored")
    model = tf.keras.models.load_model(model_file)
    # Evaluate the restored model on the test set
    loss, acc = model.evaluate(test_pic, test_labels)

    # Predict class probabilities and take the arg max
    predict = model.predict(test_pic)
    # print('predict = ',predict)
    predict = np.array(tf.argmax(predict, 1))
    print('predict = ', predict)   # tf.argmax with axis=1 returns the index of the largest value in each row, i.e. the predicted class for each sample
    print('test_labels = ', test_labels)

    print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

    return predict

def model_predict_nolabels(model_file, test_pic):    # use when the test set has no labels
    # Restore the saved model
    # print("model restored")
    model = tf.keras.models.load_model(model_file)

    # Predict class probabilities and take the arg max
    predict = model.predict(test_pic)
    # print('predict = ',predict)
    predict = np.array(tf.argmax(predict, 1))
    print('predict = ', predict)   # tf.argmax with axis=1 returns the index of the largest value in each row, i.e. the predicted class for each sample

    return predict

if __name__ == '__main__':

    start_time = time.time()

    # 1. Load and preprocess the data
    im_size = (28, 28)   # target size the images are resized to
    train_document_1 = 'D:\\python_work\\model_project\\dataset\\train\\sxs'
    train_document_2 = 'D:\\python_work\\model_project\\dataset\\train\\xsx'
    test_document_1 = 'D:\\python_work\\model_project\\dataset\\test\\sxs'
    test_document_2 = 'D:\\python_work\\model_project\\dataset\\test\\xsx'

    training_images, test_images, train_label, test_label = data_processing(train_document_1, train_document_2,
                                                                            test_document_1, test_document_2,   # (training folder 1, training folder 2, test folder 1, test folder 2)
                                                                            im_size, 0, 1)
    # 2. Build the model
    model = model_structure(training_images.shape, classification=2)

    # 3. Train the network for image classification and save it
    model_train_save(model, 'my_model.h5', training_images, train_label, test_images, test_label)


    # 4. Reload the model and predict (with ground-truth labels for comparison)
    predict_labels = model_predict_labels('my_model.h5',test_images,test_label)

    # # 5. Reload the model and predict (no ground-truth labels)
    # predict_labels = model_predict_nolabels('my_model.h5',test_images)


    end_time = time.time()
    print('Total running time (s):', end_time - start_time)
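
For completeness, here is a short hypothetical sketch (not part of the original post) showing how the saved my_model.h5 file could be used to classify a single new image. The file name new_pic.png and the use of Pillow for loading are illustrative assumptions; the preprocessing simply mirrors data_processing above (grayscale, 28x28, scaled to [0, 1]).

# single_predict.py -- hypothetical usage sketch
import numpy as np
import tensorflow as tf
from PIL import Image

model = tf.keras.models.load_model('my_model.h5')

# Preprocess exactly like the training data: grayscale, 28x28, values in [0, 1],
# then add the batch and channel dimensions -> shape (1, 28, 28, 1)
img = Image.open('new_pic.png').convert('L').resize((28, 28))   # 'new_pic.png' is a placeholder name
x = np.array(img).reshape(1, 28, 28, 1) / 255.0

probs = model.predict(x)                                        # class probabilities, shape (1, 2)
print('predicted class =', int(np.argmax(probs, axis=1)[0]))    # 0 -> sxs, 1 -> xsx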
