PyQt5结合神经网络进行物体分类

PyQt5结合神经网络进行物体分类
实习时做了这样一个小的Demo,过程中参考了很多博客和书籍。神经网络的主要程序是借鉴https://blog.csdn.net/jesmine_gu/article/details/81155787,侵删。

主体界面如下图所示,可以实现在线训练、验证和测试的功能。
PyQt5结合神经网络进行物体分类_第1张图片
preWork.py

import os
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
from numpy import *
import cv2


def get_file_1(image_dir):
    """Collect every image under image_dir and assign integer class labels.

    image_dir is expected to contain one sub-folder per class; every file in
    a sub-folder gets that folder's index (enumeration order of os.listdir)
    as its label.

    Returns (image_paths, labels): two equal-length lists shuffled in unison.
    """
    image_list = []
    label_list = []
    # one sub-folder per class; the folder's position is the class label
    for index, class_folder in enumerate(os.listdir(image_dir)):
        class_dir = image_dir + '/' + class_folder
        for file_name in os.listdir(class_dir):
            image_list.append(class_dir + '/' + file_name)  # full path of one image
            label_list.append(index)

    # Shuffle paths and labels with one shared permutation so pairs stay
    # aligned. (The original round-tripped through a 2xN numpy *string*
    # array, which forced the int labels to strings and then needed the
    # manual int() reconversion noted in the original comments.)
    order = np.random.permutation(len(image_list))
    all_image_list = [image_list[i] for i in order]
    all_label_list = [int(label_list[i]) for i in order]
    # shuffled image paths and their matching labels
    return all_image_list, all_label_list

# Turn the path/label lists into batched image/label tensors.
# Lists are used because the tf queue ops below expect list input.
# image_W, image_H: target width and height of each image
# batch_size: number of images per training batch
# capacity: maximum number of elements held in the queue
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    # image: list of image file paths; label: list of int class labels
    # step1: convert the python lists to tensors and build the input queue
    # tf.cast() performs the dtype conversion
    image = tf.cast(image, tf.string)   # variable-length byte strings: each element is one file path
                                        # converts the path names into string tensors
    label = tf.cast(label, tf.int32)
    # tf.train.slice_input_producer is a tensor generator: each call takes
    # one element from the tensor list (in order, or shuffled - the default)
    # and places it into the filename queue
    ''''''
    # input_queue = tf.train.slice_input_producer([image, label], shuffle=False) # one path paired with one label
    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0]) # tf.read_file() reads the raw bytes; input_queue[0] is the file path


    # step2: decode the image; all files must share the same format
    '''不同的图片格式对应的不同的decode函数'''
    image = tf.image.decode_jpeg(image_contents, channels=3)  # each image format needs its own decode function
    # from here on `image` is a single decoded picture
    # step3: preprocessing - resize/crop/standardize so the model generalizes better
    ##image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H) # crops or zero-pads the image to the target size
    image = tf.image.resize_images(image, [image_W, image_H])
    ##image = cv2.resize(image, (image_W, image_H))
    # standardize the resized image (per-image zero mean, unit variance)
    image = tf.image.per_image_standardization(image)

    # step4: assemble batches
    # image_batch: 4D tensor [batch_size, width, height, 3] dtype = tf.float32
    # label_batch: 1D tensor [batch_size] dtype = tf.float32
    image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=16, capacity=capacity)

    # reshape the labels into a flat vector of length batch_size
    label_batch = tf.reshape(label_batch, [batch_size])  # tf.reshape(tensor, shape): shape must be a list, hence the []
    image_batch = tf.cast(image_batch, tf.float32) # NOTE(review): plain dtype cast; the original comment claimed "display grayscale"
    # image_batch: [batch_size, height, weight, channels]
    return image_batch, label_batch


# NOTE(review): dead code preserved from the original post - a manual
# visualization harness for get_batch(), wrapped in a module-level
# triple-quoted string. The *nested* triple quotes inside it terminate and
# reopen the outer string, so this section is fragile as pasted; prefer
# deleting it (version control keeps the history) over re-enabling as-is.
'''
# 单张图片验证能否打开
def PreWork():
    # 对预处理的数据进行可视化,查看预处理的结果
    IMG_W = 256
    IMG_H = 256
    BATCH_SIZE = 3
    CAPACITY = 64

    train_dir = 'C:/Users/JS/Desktop/MFC/Detect'
    image_list, label_list = get_file(train_dir)
    image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
    print(label_batch.shape)

    lists = ('good', 'bad')

    with tf.Session() as sess:
        i = 0
        coord = tf.train.Coordinator()   # 创建一个线程协调器,用来管理之后在Session中启动的所有线程
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop() and i < 1:
                img, label = sess.run([image_batch, label_batch])

                '''
                1、range()返回的是range object,而np.arange()返回的是numpy.ndarray()
                range(start, end, step),返回一个list对象,起始值为start,终止值为end,但不含终止值,步长为step。只能创建int型list。
                arange(start, end, step),与range()类似,但是返回一个array对象。需要引入import numpy as np,并且arange可以使用float型数据。

                2、range()不支持步长为小数,np.arange()支持步长为小数

                3、两者都可用于迭代
                range尽可用于迭代,而np.nrange作用远不止于此,它是一个序列,可被当做向量使用。
                '''
                for j in np.arange(BATCH_SIZE):
                    print('label: %d'%label[j])
                    plt.imshow(img[j, :, :, :])
                    title = lists[int(label[j])]
                    plt.title(title)
                    plt.show()
                i += 1
        except tf.errors.OutOfRangeError:
            print('Done')
        finally:
            coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    PreWork()
'''


if __name__ == '__main__':
    # Quick manual check: print the shuffled paths/labels for a sample dataset.
    image_dir = 'G:/PyProject/8/17flowers'
    image_list, label_list = get_file_1(image_dir)
    print(image_list)
    print(label_list)
    # print(temp)

DeepCNN.py

import tensorflow as tf

def weight_variable(shape, n):
    """Return a truncated-normal initial value for a weight tensor.

    shape: [height, width, in_channels, out_channels]; n: standard deviation.
    """
    # tf.truncated_normal(shape, mean, stddev): samples a normal distribution,
    # re-drawing anything more than 2 stddevs from the mean
    return tf.truncated_normal(shape, stddev=n, dtype=tf.float32)

def bias_variable(shape):
    """Return a constant 0.1 initial value for a bias tensor of the given shape."""
    return tf.constant(0.1, shape=shape, dtype=tf.float32)

def conv2(x, W):
    """2-D convolution of x with kernel W, stride 1, SAME padding."""
    # strides layout is [1, stride_h, stride_w, 1]; first and last must be 1
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, unit_stride, padding='SAME')

def max_pool_2x2(x, name):
    """3x3 max pooling with stride 2 and SAME padding.

    x: 4-D tensor [batch, height, width, channels] (a conv output).
    name: name for the pooling op in the graph.
    """
    # ksize / strides share the [batch, height, width, channels] layout;
    # with SAME padding the spatial output size is ceil(input / stride)
    return tf.nn.max_pool(x,
                          ksize=[1, 3, 3, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME',
                          name=name)


# Build the convolutional network.
def deep_CNN(images, batch_size, n_classes):
    """Build the CNN graph: three conv stages (two conv layers, max-pool,
    LRN each), two fully connected layers, dropout, and a linear output layer.

    images: 4-D image batch; assumed [batch_size, H, W, 3] -- the first conv
        kernel has 3 input channels.
    batch_size: images per batch (used to flatten the feature maps for fc1).
    n_classes: number of output classes.
    Returns the un-normalized class scores (logits), [batch_size, n_classes].
    """

    # stage 1
    with tf.variable_scope('conv1') as scope:        # variables below are named conv1/xxx
        # conv layer 1
        w_conv1 = tf.Variable(weight_variable([3, 3, 3, 64], 0.1), name='weights', dtype=tf.float32)
        b_conv1 = tf.Variable(bias_variable([64]), name='bias', dtype=tf.float32)
        h_conv1 = tf.nn.relu(conv2(images, w_conv1) + b_conv1, name='conv1')
        # conv layer 2
        w_conv1_1 = tf.Variable(weight_variable([3, 3, 64, 64], 0.1), name='weights1', dtype=tf.float32)
        b_conv1_1 = tf.Variable(bias_variable([64]), name='bias1', dtype=tf.float32)
        h_conv1 = tf.nn.relu(conv2(h_conv1, w_conv1_1) + b_conv1_1, name='conv1_1')

    # stage-1 pooling, then local response normalization to aid generalization
    with tf.variable_scope('pooling1_lrn') as scope:
        pool1 = max_pool_2x2(h_conv1, name='pooling1')
        # LRN: boosts the strongest local activations and damps the rest
        norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')


    # stage 2
    with tf.variable_scope('conv2') as scope:
        # conv layer 3
        w_conv2 = tf.Variable(weight_variable([3, 3, 64, 32], n=0.1), name='weights', dtype=tf.float32)
        b_conv2 = tf.Variable(bias_variable([32]), name='bias', dtype=tf.float32)
        h_conv2 = tf.nn.relu(conv2(norm1, w_conv2) + b_conv2, name='conv2')
        # conv layer 4
        # BUG FIX: the original called tf.truncated_normal([3, 3, 32, 32], 0.1),
        # which passes 0.1 as the *mean* (stddev stays 1.0). Use the same
        # helpers as every other layer so stddev is 0.1 and biases are 0.1.
        w_conv2_2 = tf.Variable(weight_variable([3, 3, 32, 32], 0.1), dtype=tf.float32)
        b_conv2_2 = tf.Variable(bias_variable([32]), dtype=tf.float32)
        h_conv2 = tf.nn.relu(conv2(h_conv2, w_conv2_2) + b_conv2_2, name='conv2_2')

    # stage-2 pooling + LRN
    with tf.variable_scope('pooling2_lrn') as scope:
        pool2 = max_pool_2x2(h_conv2, name='pooling2')
        norm2 = tf.nn.lrn(pool2, depth_radius=4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')


    # stage 3
    with tf.variable_scope('conv3') as scope:
        # conv layer 5
        w_conv3 = tf.Variable(weight_variable([3, 3, 32, 16], n=0.1), name='weights', dtype=tf.float32)
        b_conv3 = tf.Variable(bias_variable([16]), name='bias', dtype=tf.float32)
        h_conv3 = tf.nn.relu(conv2(norm2, w_conv3) + b_conv3, name='conv3')
        # conv layer 6
        # BUG FIX: same mean-vs-stddev mix-up as in conv2_2 above
        w_conv3_3 = tf.Variable(weight_variable([3, 3, 16, 16], 0.1), dtype=tf.float32)
        b_conv3_3 = tf.Variable(bias_variable([16]), dtype=tf.float32)
        h_conv3 = tf.nn.relu(conv2(h_conv3, w_conv3_3) + b_conv3_3, name='conv3_3')
    # stage-3 pooling + LRN
    with tf.variable_scope('pooling3_lrn') as scope:
        pool3 = max_pool_2x2(h_conv3, name='pooling3')
        # BUG FIX: op was named 'norm2', clashing with the stage-2 LRN name
        norm3 = tf.nn.lrn(pool3, depth_radius=4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm3')

    # fc1: flatten the feature maps and project to 128 features per image
    with tf.variable_scope('fc1') as scope:
        reshape = tf.reshape(norm3, [batch_size, -1])    # [batch_size, w'*h'*c']
        dim = reshape.get_shape()[1].value               # dim = w' * h' * c'
        w_fc1 = tf.Variable(weight_variable([dim, 128], 0.005), name='weights', dtype=tf.float32)
        b_fc1 = tf.Variable(bias_variable([128]), name='bias', dtype=tf.float32)
        h_fc1 = tf.nn.relu(tf.matmul(reshape, w_fc1) + b_fc1, name=scope.name)  # matmul, not convolution

    # fc2: 128 -> 128
    with tf.variable_scope('fc2') as scope:
        w_fc2 = tf.Variable(weight_variable([128, 128], 0.005), name='weights', dtype=tf.float32)
        b_fc2 = tf.Variable(bias_variable([128]), name='bias', dtype=tf.float32)
        h_fc2 = tf.nn.relu(tf.matmul(h_fc1, w_fc2) + b_fc2, name=scope.name)


    # NOTE(review): dropout with a hard-coded keep_prob of 0.5 is applied
    # unconditionally, including at inference time - confirm this is intended
    # (a keep_prob placeholder would allow disabling it for eval/test).
    h_fc2_dropout = tf.nn.dropout(h_fc2, 0.5)

    # linear output layer (softmax itself is applied by the loss / the caller)
    with tf.variable_scope('softmax') as scope:
        weights = tf.Variable(weight_variable([128, n_classes], 0.005), name='softmax_linear', dtype=tf.float32)
        bias = tf.Variable(bias_variable([n_classes]), name='bias', dtype=tf.float32)
        softmax_linear = tf.add(tf.matmul(h_fc2_dropout, weights), bias, name='softmax_linear')
    return softmax_linear


def losses(logits, labels):
    """Mean sparse softmax cross-entropy over the batch.

    logits: un-normalized class scores, [batch, n_classes].
    labels: integer class ids, [batch].
    Returns a scalar loss tensor.
    """
    with tf.variable_scope('losses') as scope:
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy_per_example')
        # average the per-example losses into one scalar
        loss = tf.reduce_mean(xent, name='loss')
        # export a scalar summary so the loss is visible in TensorBoard
        tf.summary.scalar(scope.name + '/loss', loss)
    return loss

# Build the optimization op for the loss.
# loss: scalar loss tensor; learning_rate: SGD learning rate.
# Returns train_op, which the caller passes to sess.run() to take one step.
def training(loss, learning_rate):
    """Create a gradient-descent step that minimizes `loss`."""
    with tf.name_scope('optimizer'):
        # step counter, incremented automatically by minimize()
        step_counter = tf.Variable(0, name='global_step', trainable=False)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=step_counter)
    return train_op


# Accuracy evaluation.
# logits: network outputs; labels: ground-truth integer class ids.
# Returns the mean top-1 accuracy over the batch.
def evalution(logits, labels):
    """Fraction of the batch whose top-1 prediction matches the label."""
    with tf.variable_scope('accuracy') as scope:
        hits = tf.nn.in_top_k(logits, labels, 1)  # bool tensor, one entry per example
        accuracy = tf.reduce_mean(tf.cast(hits, tf.float16))
        tf.summary.scalar(scope.name + '/accuracy', accuracy)
    return accuracy


Train.py



def train(directory, num_classes):
    """Train the CNN on the class-per-subfolder dataset under `directory`.

    directory: dataset root (also used as the log/checkpoint directory here).
    num_classes: number of class sub-folders.
    """
    import os
    import tensorflow as tf
    import numpy as np
    from preWork import get_file, get_batch, get_file_1
    from DeepCNN import deep_CNN, losses, training, evalution

    # hyper-parameters
    N_CLASS = num_classes
    IMG_H = 28
    IMG_W = 28
    BATCH_SIZE = 64
    CAPACITY = 200
    MAX_STEP = 10001
    learning_rate = 0.01

    # NOTE(review): these placeholders are never wired into the graph built
    # below, so the feed_dict in the training loop has no effect - the
    # effective learning rate stays at 0.01 throughout. Confirm before
    # relying on the "learning-rate schedule" further down.
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)

    # build the input batches
    train_dir = directory
    # validation_dir = directory
    logs_train_dir = directory
    train, train_label = get_file_1(train_dir)  # shuffled image paths and labels
    train_batch, train_label_batch = get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
    print(train_batch)

    # training graph
    train_logits = deep_CNN(train_batch, BATCH_SIZE, N_CLASS)
    train_loss = losses(train_logits, train_label_batch)
    train_op = training(train_loss, learning_rate)
    train_acc = evalution(train_logits, train_label_batch)

    # merged summaries for TensorBoard logging
    summary_op = tf.summary.merge_all()

    sess = tf.Session()
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    # saver used to persist the trained model
    saver = tf.train.Saver()
    # initialize every variable in the graph
    sess.run(tf.global_variables_initializer())
    # queue supervision
    coord = tf.train.Coordinator()  # thread coordinator for all threads started in this Session
    thread = tf.train.start_queue_runners(sess=sess, coord=coord)  # used with the file queue built in get_batch
    # tf.train.slice_input_producer([image, label])

    # batch training loop
    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break
            # run the training nodes
            # _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])

            '''改变学习率'''
            ''''''
            if step < 1000:
                tra_loss, tra_acc = sess.run([train_loss, train_acc])
                _ = sess.run(train_op, feed_dict={x: tra_loss, y: 0.1})
            elif step < 3000:
                tra_loss, tra_acc = sess.run([train_loss, train_acc])
                _ = sess.run(train_op, feed_dict={x: tra_loss, y: 0.01})
            else:
                tra_loss, tra_acc = sess.run([train_loss, train_acc])
                _ = sess.run(train_op, feed_dict={x: tra_loss, y: 0.001})
            ''''''
            '''
            # 改变学习率,失败
            if step <= 3000:
                replace_dict = {learning_rate: 0.01}
                # _ = sess.run(train_op, feed_dict=replace_dict)
                _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc], feed_dict=replace_dict)
            if step > 3000:
                replace_dict = {learning_rate: 0.001}
                # _ = sess.run(train_op, feed_dict=replace_dict)
                _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc], feed_dict=replace_dict)
            '''

            # every 100 steps: print loss/acc and write a summary record
            if step % 100 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)

            # NOTE(review): this save runs on *every* step, not only the last
            # one as the original comment claimed - costly for long runs
            checkpoint_path = os.path.join(logs_train_dir, 'model')
            saver.save(sess, checkpoint_path)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')

    finally:
        coord.request_stop()
    coord.join(thread)
    sess.close()

Test.py

import os
import tensorflow as tf
import numpy as np
from preWork import get_file, get_batch
from DeepCNN import training, evalution, deep_CNN
from PIL import Image
import matplotlib.pyplot as plt

# N_CLASS = 9
img_dir = 'G:/file/github/17flowers/1'
# log_dir = 'G:/PyProject/8/picture/train'
# lists = ['0', '1', '2', '3', '4', '5', '6', '7', '8']

def get_one_image(img_dir):
    """Pick a random image file from img_dir, display it, and return it
    resized to 28x28 as a numpy array."""
    file_names = os.listdir(img_dir)
    chosen = file_names[np.random.randint(0, len(file_names))]
    full_path = img_dir + '/' + chosen
    print(full_path)
    picture = Image.open(full_path)
    # show the original picture before resizing
    plt.imshow(picture)
    plt.show()
    picture = picture.resize([28, 28])
    return np.array(picture)


def test(image_arr, lists, log_dir, N_CLASS):
    """Classify one image with the trained model.

    image_arr: 28x28x3 image as a numpy array.
    lists: class names, indexed by label.
    log_dir: directory holding the trained checkpoint.
    N_CLASS: number of classes the model was trained with.
    Returns the predicted class name from `lists`.
    """
    with tf.Graph().as_default():
        # standardize and shape the input exactly as during training
        image = tf.cast(image_arr, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 28, 28, 3])
        p = deep_CNN(image, 1, N_CLASS)   # un-normalized scores
        logits = tf.nn.softmax(p)
        saver = tf.train.Saver()
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        # restore the trained weights if a checkpoint exists
        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # BUG FIX: the original created a placeholder `x` and fed it, but
        # `logits` is built from the constant `image` tensor above, so the
        # feed had no effect; the dead placeholder and feed_dict are removed.
        prediction = sess.run(logits)
        max_index = np.argmax(prediction)
        return lists[max_index]


if __name__ == '__main__':
    print('Start test')
    img = get_one_image(img_dir)
    # BUG FIX: test() takes (image_arr, lists, log_dir, N_CLASS); the original
    # called test(img) and raised a TypeError. Use the example values that
    # were left commented out at the top of this file.
    lists = ['0', '1', '2', '3', '4', '5', '6', '7', '8']
    log_dir = 'G:/PyProject/8/picture/train'
    print(test(img, lists, log_dir, len(lists)))



Validation.py

import os
from preWork import get_file_1, get_batch
from DeepCNN import deep_CNN, losses, evalution
import tensorflow as tf
import numpy as np

# Default hyper-parameters for the validation pipeline. validation() below
# re-declares identical local copies, so these module-level values exist
# mainly for interactive use.
BATCH_SIZE = 64
# N_CLASS = 17
IMG_W = 28
IMG_H = 28
CAPACITY = 60


# validation_dir = "G:/PyProject/20190715/17flowers/flowers"
# log_dir = 'G:/PyProject/20190715/17flowers'

def validation(validation_dir, log_dir, N_CLASS):
    """Evaluate the trained model on the dataset under validation_dir.

    validation_dir: class-per-subfolder dataset root.
    log_dir: directory holding the trained checkpoint.
    N_CLASS: number of classes the model was trained with.
    Returns the mean top-1 accuracy over 100 batches.
    """
    BATCH_SIZE = 64
    IMG_W = 28
    IMG_H = 28
    CAPACITY = 60
    with tf.Graph().as_default():
        # build the evaluation input pipeline and graph
        image, label = get_file_1(validation_dir)
        image_batch, label_batch = get_batch(image, label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
        p = deep_CNN(image_batch, BATCH_SIZE, N_CLASS)
        # NOTE: evalution() works on raw scores; the original also built an
        # unused softmax node and three unused placeholders - removed.
        validation_acc = evalution(p, label_batch)

        # load the trained weights
        saver = tf.train.Saver()
        sess = tf.Session()
        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # start the queue-runner threads feeding image_batch/label_batch
        coord = tf.train.Coordinator()
        thread = tf.train.start_queue_runners(sess=sess, coord=coord)

        # average the batch accuracy over 100 batches
        total_acc = 0
        for _ in range(100):
            total_acc += sess.run(validation_acc)

        # BUG FIX: the original leaked the queue threads and the session;
        # stop the coordinator and close the session before returning
        coord.request_stop()
        coord.join(thread)
        sess.close()

    return total_acc / 100


if __name__ == '__main__':
    # Smoke-test: run the validation pipeline on hard-coded example paths.
    validation_dir = "G:/PyProject/20190715/17flowers/flowers"
    log_dir = 'G:/PyProject/20190715/17flowers'
    N_class = 17
    acc = validation(validation_dir, log_dir, N_class)
    print(acc)

CallMainWindow.py

import os
import sys
from Main import Ui_MainWindow
from PyQt5.QtWidgets import  *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
# from Train import train
from Test import test
from PIL import Image
import tensorflow as tf
import numpy as np
import json
from Validation import validation

class MyMainWindow(QMainWindow, Ui_MainWindow):
    """Main window: wires the buttons of the generated Ui_MainWindow to the
    training, single-image testing, and validation pipelines."""

    # status-string signal used to push progress text to the UI label
    _signal = pyqtSignal(str)
    # _signal_times = pyqtSignal(int)
    def __init__(self):
        super(MyMainWindow, self).__init__()
        self.setupUi(self)

        # persistent UI state
        self.directory = ''   # training-set root folder
        self.num_classes = ''   # class count (int once a folder / class file is chosen)
        self.savepath = ''   # checkpoint/log output folder
        self.lineEdit.setPlaceholderText("10000")
        self.learningtimes = 10000   # number of training steps
        self.logfile = ''   # folder holding a trained checkpoint
        self.classes = ''   # last predicted class name
        self.lists = []   # class names; list index == integer label
        self.validation_accuracy = ''
        self.valiadation_path = ''   # NOTE(review): typo for "validation_path"; unused below

        # wire widgets to their handlers
        self.pushButton.clicked.connect(self.OpenTrainPath)
        self.pushButton_2.clicked.connect(self.StartTrain)
        self.pushButton_5.clicked.connect(self.SavePath)
        self.pushButton_3.clicked.connect(self.OpenPicture)
        self.pushButton_4.clicked.connect(self.LoadTrainFile)
        self._signal.connect(self.ShowAccuracy)
        # self._signal_times.connect(self.getLearningTimes)
        self.lineEdit.returnPressed.connect(self.GetTimes)
        # self.lineEdit_2.returnPressed.connect(self.GetNumClasses)
        self.pushButton_6.clicked.connect(self.GetClasses)
        self.pushButton_7.clicked.connect(self.Validation)


    def OpenTrainPath(self):
        """Choose the training-set root folder; record its sub-folder names
        as the class list and dump them to '<folder>.txt' as JSON."""
        self.directory = QFileDialog.getExistingDirectory(self, "选择文件夹")
        # print(self.directory)
        if len(self.directory) != 0:
            self.pushButton.setStyleSheet("color:red")
            self.num_classes = len(os.listdir(self.directory))
            for classes_list in os.listdir(self.directory):
                self.lists.append(classes_list)
            # print(self.lists)
            c_list = json.dumps(self.lists)
            list_save_path = self.directory + '.txt'
            a = open(list_save_path, 'w')
            a.write(c_list)
            a.close()
        # self.num_classes, _ = enumerate(os.listdir(self.directory))

    def SavePath(self):
        """Choose the folder where checkpoints and logs are written."""
        self.savepath = QFileDialog.getExistingDirectory(self, "选择文件夹")
        if len(self.savepath) != 0:
            self.pushButton_5.setStyleSheet("color:red")

    def GetTimes(self):
        """Read the requested number of training steps from the line edit."""
        # self.label_7.setText(self.lineEdit.text())
        self.learningtimes = int(self.lineEdit.text())

    def StartTrain(self):
        """Kick off training if a dataset folder and save path were chosen."""
        if self.directory != '' and self.savepath != '':
            self.pushButton_2.setStyleSheet("color:red")
            self.label_7.setText("Start Training")
            # print(self.directory)
            # print(self.num_classes)
            # print(self.learningtimes)
            self.train(directory=self.directory, num_classes=self.num_classes, save_path=self.savepath,
                       max_step=self.learningtimes)
        else:
            self.label_7.setText("请选择训练集和保存地址")

    def ShowAccuracy(self, str):
        """Slot for _signal: show a progress/status string on the UI.
        NOTE(review): the parameter name shadows the builtin `str`."""
        self.label_7.setText(str)

    def train(self, directory, num_classes, save_path, max_step):
        """Build and run the training loop (mirrors Train.py); progress is
        pushed to the UI through self._signal every 100 steps.

        directory: dataset root; num_classes: number of classes;
        save_path: checkpoint/log folder; max_step: training steps.
        """
        import os
        import tensorflow as tf
        import numpy as np
        from preWork import get_file, get_batch, get_file_1
        from DeepCNN import deep_CNN, losses, training, evalution


        # hyper-parameters
        N_CLASS = num_classes
        IMG_H = 28
        IMG_W = 28
        BATCH_SIZE = 64
        CAPACITY = 200
        MAX_STEP = max_step
        learning_rate = 0.01

        # NOTE(review): these placeholders are never wired into the graph
        # below, so the feed_dict in the loop has no effect - the learning
        # rate stays at 0.01. Confirm before relying on the schedule below.
        x = tf.placeholder(tf.float32)
        y = tf.placeholder(tf.float32)

        # build the input batches
        train_dir = directory
        # validation_dir = directory
        logs_train_dir = save_path
        train, train_label = get_file_1(train_dir)  # shuffled image paths and labels
        train_batch, train_label_batch = get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
        print(train_batch)

        # training graph
        train_logits = deep_CNN(train_batch, BATCH_SIZE, N_CLASS)
        train_loss = losses(train_logits, train_label_batch)
        train_op = training(train_loss, learning_rate)
        train_acc = evalution(train_logits, train_label_batch)

        # merged summaries for TensorBoard logging
        summary_op = tf.summary.merge_all()

        sess = tf.Session()
        train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
        # saver used to persist the trained model
        saver = tf.train.Saver()
        # initialize every variable in the graph
        sess.run(tf.global_variables_initializer())
        # queue supervision
        coord = tf.train.Coordinator()  # thread coordinator for all threads started in this Session
        thread = tf.train.start_queue_runners(sess=sess, coord=coord)  # used with the file queue built in get_batch
        # tf.train.slice_input_producer([image, label])
        # self._signal.emit("Starting Training, please wait!")
        # batch training loop
        try:
            for step in np.arange(MAX_STEP):
                if coord.should_stop():
                    break
                # run the training nodes
                # _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])

                '''改变学习率'''
                ''''''
                if step < 1000:
                    tra_loss, tra_acc = sess.run([train_loss, train_acc])
                    _ = sess.run(train_op, feed_dict={x: tra_loss, y: 0.1})
                elif step < 3000:
                    tra_loss, tra_acc = sess.run([train_loss, train_acc])
                    _ = sess.run(train_op, feed_dict={x: tra_loss, y: 0.01})
                else:
                    tra_loss, tra_acc = sess.run([train_loss, train_acc])
                    _ = sess.run(train_op, feed_dict={x: tra_loss, y: 0.001})
                ''''''
                '''
                # 改变学习率,失败
                if step <= 3000:
                    replace_dict = {learning_rate: 0.01}
                    # _ = sess.run(train_op, feed_dict=replace_dict)
                    _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc], feed_dict=replace_dict)
                if step > 3000:
                    replace_dict = {learning_rate: 0.001}
                    # _ = sess.run(train_op, feed_dict=replace_dict)
                    _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc], feed_dict=replace_dict)
                '''

                # every 100 steps: push loss/acc to the UI and write a summary
                if step % 100 == 0:
                    # print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                    self._signal.emit('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))

                    summary_str = sess.run(summary_op)
                    train_writer.add_summary(summary_str, step)
                elif step == MAX_STEP-1:
                    self._signal.emit("Training Fininshed")


                # NOTE(review): this save runs on *every* step, not only the
                # last one as the original comment claimed - costly
                checkpoint_path = os.path.join(logs_train_dir, 'model')
                saver.save(sess, checkpoint_path)
                QApplication.processEvents()   # keeps the UI responsive while this long-running loop executes

        except tf.errors.OutOfRangeError:
            self._signal.emit('Done training -- epoch limit reached')

        finally:
            coord.request_stop()
        coord.join(thread)
        sess.close()

    def LoadTrainFile(self):
        """Choose the folder holding a trained checkpoint (used by testing)."""
        self.logfile = QFileDialog.getExistingDirectory(self, "选择文件夹")
        if len(self.logfile) != 0:
            self.pushButton_4.setStyleSheet("color:red")
    '''
    def GetNumClasses(self):
        self.num_classes = int(self.lineEdit_2.text())
    '''
    def OpenPicture(self):
        """Choose an image, preview it in the UI, and classify it with the
        trained model (requires a loaded class list)."""
        fname, _ = QFileDialog.getOpenFileName(self, "选择图片", " ", "Image files(*.jpg *.bmp *.*)")
        if len(fname) != 0:
            self.pushButton_3.setStyleSheet("color:red")
            img = QPixmap(fname).scaled(self.label_3.width(), self.label_3.height())
            self.label_3.setPixmap(img)
            image = Image.open(fname)
            # image = tf.cast(image, tf.float32)
            image = np.array(image.resize([28, 28]))
            if self.lists == []:
                self.label_5.setText("请加载类别文件")
            else:
                self.classes = test(image_arr=image, lists=self.lists, log_dir=self.logfile, N_CLASS=self.num_classes)
                self.label_5.setText(self.classes)

    def GetClasses(self):
        """Load the JSON class-name list written by OpenTrainPath."""
        fname, _ = QFileDialog.getOpenFileName(self, "选择文件", " ", "TXT(*.txt)")
        if len(fname) != 0:
            self.pushButton_6.setStyleSheet("color:red")
            temp = open(fname, 'r')
            self.lists = json.loads(temp.read())
            self.num_classes = len(self.lists)

    def Validation(self):
        """Choose a validation folder and report the model's accuracy on it."""
        self.label_7.setText("正在验证,请稍等")
        fname = QFileDialog.getExistingDirectory(self, "选择图片")
        # print(fname)
        if len(fname) != 0:
            self.pushButton_7.setStyleSheet("color:red")
            self.num_classes = len(os.listdir(fname))
            # print(self.num_classes)
            # print(self.savepath)
            self.validation_accuracy = validation(fname, self.logfile, self.num_classes)
            accuracy = '验证准确率为:' + str(self.validation_accuracy * 100) + '%'
            self.label_7.setText(accuracy)



if __name__ == "__main__":
    # Application entry point: start the Qt event loop and show the window.
    app = QApplication(sys.argv)
    win = MyMainWindow()
    win.show()
    sys.exit(app.exec_())

Main.py

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'Main.ui'
#
# Created by: PyQt5 UI code generator 5.11.3
#
# WARNING! All changes made in this file will be lost!

from PyQt5 import QtCore, QtGui, QtWidgets

class Ui_MainWindow(object):
    """Auto-generated UI layout class (pyuic5 output from 'Main.ui').

    Do NOT edit by hand in a real project — regenerate from the .ui file
    instead ("All changes made in this file will be lost!").  A host
    QMainWindow mixes this in and calls setupUi() once at startup.
    """

    def setupUi(self, MainWindow):
        """Create and lay out all widgets on *MainWindow*.

        Widget roles (texts assigned in retranslateUi):
          label_3        - 256x256 preview area for the picture to classify
          label_7        - validation-accuracy / status message
          gridLayout     - top row of three open buttons + train button
          gridLayout_2   - predicted-class readout (label_4 / label_5)
          formLayout     - left column: epoch count, load/classes/validation
        """
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(498, 440)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        # --- picture preview area (fixed 256x256, centered content) ---
        self.label_3 = QtWidgets.QLabel(self.centralwidget)
        self.label_3.setGeometry(QtCore.QRect(190, 90, 256, 256))
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.label_3.sizePolicy().hasHeightForWidth())
        self.label_3.setSizePolicy(sizePolicy)
        self.label_3.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_3.setText("")
        self.label_3.setAlignment(QtCore.Qt.AlignCenter)
        self.label_3.setObjectName("label_3")
        self.label_3.setWordWrap(True)
        # --- validation result / status label ---
        self.label_7 = QtWidgets.QLabel(self.centralwidget)
        self.label_7.setGeometry(QtCore.QRect(30, 250, 151, 71))
        self.label_7.setText("")
        self.label_7.setAlignment(QtCore.Qt.AlignCenter)
        self.label_7.setObjectName("label_7")
        # --- top grid: column headers + the three "open" buttons and train button ---
        self.layoutWidget = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget.setGeometry(QtCore.QRect(30, 31, 421, 72))
        self.layoutWidget.setObjectName("layoutWidget")
        self.gridLayout = QtWidgets.QGridLayout(self.layoutWidget)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setObjectName("gridLayout")
        self.label = QtWidgets.QLabel(self.layoutWidget)
        self.label.setAlignment(QtCore.Qt.AlignCenter)
        self.label.setObjectName("label")
        self.gridLayout.addWidget(self.label, 0, 0, 1, 1)
        self.label_9 = QtWidgets.QLabel(self.layoutWidget)
        self.label_9.setAlignment(QtCore.Qt.AlignCenter)
        self.label_9.setObjectName("label_9")
        self.gridLayout.addWidget(self.label_9, 0, 1, 1, 1)
        self.label_2 = QtWidgets.QLabel(self.layoutWidget)
        self.label_2.setAlignment(QtCore.Qt.AlignCenter)
        self.label_2.setObjectName("label_2")
        self.gridLayout.addWidget(self.label_2, 0, 2, 1, 1)
        self.pushButton = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton.setObjectName("pushButton")
        self.gridLayout.addWidget(self.pushButton, 1, 0, 1, 1)
        self.pushButton_5 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_5.setObjectName("pushButton_5")
        self.gridLayout.addWidget(self.pushButton_5, 1, 1, 1, 1)
        self.pushButton_3 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_3.setObjectName("pushButton_3")
        self.gridLayout.addWidget(self.pushButton_3, 1, 2, 1, 1)
        self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_2.setObjectName("pushButton_2")
        self.gridLayout.addWidget(self.pushButton_2, 2, 0, 1, 1)
        # --- classification result readout ("类别" + predicted class) ---
        self.layoutWidget1 = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget1.setGeometry(QtCore.QRect(250, 360, 161, 16))
        self.layoutWidget1.setObjectName("layoutWidget1")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.layoutWidget1)
        self.gridLayout_2.setContentsMargins(0, 0, 0, 0)
        self.gridLayout_2.setObjectName("gridLayout_2")
        self.label_4 = QtWidgets.QLabel(self.layoutWidget1)
        self.label_4.setAlignment(QtCore.Qt.AlignCenter)
        self.label_4.setObjectName("label_4")
        self.gridLayout_2.addWidget(self.label_4, 0, 0, 1, 1)
        self.label_5 = QtWidgets.QLabel(self.layoutWidget1)
        self.label_5.setText("")
        self.label_5.setAlignment(QtCore.Qt.AlignCenter)
        self.label_5.setObjectName("label_5")
        self.gridLayout_2.addWidget(self.label_5, 0, 1, 1, 1)
        # --- left form: epoch count entry + load / class-file / validation buttons ---
        self.layoutWidget2 = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget2.setGeometry(QtCore.QRect(30, 120, 155, 109))
        self.layoutWidget2.setObjectName("layoutWidget2")
        self.formLayout = QtWidgets.QFormLayout(self.layoutWidget2)
        self.formLayout.setContentsMargins(0, 0, 0, 0)
        self.formLayout.setObjectName("formLayout")
        self.label_8 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_8.setObjectName("label_8")
        self.formLayout.setWidget(0, QtWidgets.QFormLayout.LabelRole, self.label_8)
        self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget2)
        self.lineEdit.setObjectName("lineEdit")
        self.formLayout.setWidget(0, QtWidgets.QFormLayout.FieldRole, self.lineEdit)
        self.label_6 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_6.setObjectName("label_6")
        self.formLayout.setWidget(1, QtWidgets.QFormLayout.LabelRole, self.label_6)
        self.pushButton_4 = QtWidgets.QPushButton(self.layoutWidget2)
        self.pushButton_4.setObjectName("pushButton_4")
        self.formLayout.setWidget(1, QtWidgets.QFormLayout.FieldRole, self.pushButton_4)
        self.label_10 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_10.setAlignment(QtCore.Qt.AlignCenter)
        self.label_10.setObjectName("label_10")
        self.formLayout.setWidget(2, QtWidgets.QFormLayout.LabelRole, self.label_10)
        self.pushButton_6 = QtWidgets.QPushButton(self.layoutWidget2)
        self.pushButton_6.setObjectName("pushButton_6")
        self.formLayout.setWidget(2, QtWidgets.QFormLayout.FieldRole, self.pushButton_6)
        self.label_11 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_11.setObjectName("label_11")
        self.formLayout.setWidget(3, QtWidgets.QFormLayout.LabelRole, self.label_11)
        self.pushButton_7 = QtWidgets.QPushButton(self.layoutWidget2)
        self.pushButton_7.setObjectName("pushButton_7")
        self.formLayout.setWidget(3, QtWidgets.QFormLayout.FieldRole, self.pushButton_7)
        MainWindow.setCentralWidget(self.centralwidget)
        # --- standard menu bar and status bar (unused but generated) ---
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 498, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Assign all user-visible (translatable) strings to the widgets."""
        # NOTE: these Chinese strings are runtime UI text and must match
        # the .ui file; they are set via Qt's translation machinery.
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.label.setText(_translate("MainWindow", "选择训练集文件夹"))
        self.label_9.setText(_translate("MainWindow", "选择训练数据保存位置"))
        self.label_2.setText(_translate("MainWindow", "选择待识别的图片"))
        self.pushButton.setText(_translate("MainWindow", "打开"))
        self.pushButton_5.setText(_translate("MainWindow", "打开"))
        self.pushButton_3.setText(_translate("MainWindow", "打开"))
        self.pushButton_2.setText(_translate("MainWindow", "开始训练"))
        self.label_4.setText(_translate("MainWindow", "类别"))
        self.label_8.setText(_translate("MainWindow", "训练次数"))
        self.label_6.setText(_translate("MainWindow", "加载训练数据"))
        self.pushButton_4.setText(_translate("MainWindow", "打开"))
        self.label_10.setText(_translate("MainWindow", "打开分类文件"))
        self.pushButton_6.setText(_translate("MainWindow", "打开"))
        self.label_11.setText(_translate("MainWindow", "打开验证集"))
        self.pushButton_7.setText(_translate("MainWindow", "打开"))


你可能感兴趣的:(1,OpenCV)