PyQt5结合神经网络进行物体分类
实习时做了这样一个小的Demo,过程中参考了很多博客和书籍。神经网络的主要程序是借鉴https://blog.csdn.net/jesmine_gu/article/details/81155787,侵删。
主体界面如下图所示,可以实现在线训练、验证和测试等功能。
preWork.py
import os
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
from numpy import *
import cv2
def get_file_1(image_dir):
    """Collect every image path under ``image_dir`` and assign one integer
    label per class sub-directory, then shuffle paths and labels together.

    Parameters
    ----------
    image_dir : str
        Root directory containing one sub-directory per class.

    Returns
    -------
    (list[str], list[int])
        Shuffled image paths and their matching integer labels.
    """
    image_list = []
    label_list = []
    # Sort the class folders so the label<->folder mapping is deterministic
    # across runs (os.listdir order is platform-dependent), and skip stray
    # files (e.g. a previously written class-list .txt) so only real class
    # folders produce labels.
    class_dirs = [d for d in sorted(os.listdir(image_dir))
                  if os.path.isdir(image_dir + '/' + d)]
    for index, class_name in enumerate(class_dirs):
        class_path = image_dir + '/' + class_name
        for file_name in os.listdir(class_path):
            image_list.append(class_path + '/' + file_name)
            label_list.append(index)
    # Shuffle paths and labels with one shared permutation.
    # NOTE(fix): the original stacked both lists into a 2xN numpy array,
    # which silently converted the int labels to strings and needed a
    # second int() pass to undo it; the permutation keeps labels as ints.
    perm = np.random.permutation(len(image_list))
    all_image_list = [image_list[i] for i in perm]
    all_label_list = [label_list[i] for i in perm]
    return all_image_list, all_label_list
# 将image和label转为list格式数据,因为后面用到的一些tensorflow函数接收的是list格式
# 为了方便网络训练,输入数据进行batch处理
# image_W,image_H:图像的宽度和高度
# batch_size:一次训练的图片数量
# capacity:一个列队最大多少
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    """Build a TF1 queue-based input pipeline that turns lists of image
    paths and labels into shuffled, standardized training batches.

    image: list of image file paths; label: list of int labels.
    image_W, image_H: target width / height after resizing.
    batch_size: images per batch; capacity: maximum queue size.
    Returns (image_batch, label_batch):
        image_batch: 4-D float32 tensor [batch_size, image_W, image_H, 3]
        label_batch: 1-D int32 tensor [batch_size]
    """
    # Convert the Python lists into tensors of the types the queue expects.
    path_tensor = tf.cast(image, tf.string)
    label_tensor = tf.cast(label, tf.int32)
    # slice_input_producer pops one (path, label) pair at a time into a
    # filename queue (shuffled by default).
    input_queue = tf.train.slice_input_producer([path_tensor, label_tensor])
    one_label = input_queue[1]
    # Read the raw bytes of the queued file and decode them as JPEG
    # (a different image format would need a different decode_* op).
    raw_bytes = tf.read_file(input_queue[0])
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    # Resize to a fixed spatial size, then normalize each image to zero
    # mean / unit variance.
    resized = tf.image.resize_images(decoded, [image_W, image_H])
    standardized = tf.image.per_image_standardization(resized)
    # Assemble batches with 16 reader threads.
    image_batch, label_batch = tf.train.batch(
        [standardized, one_label],
        batch_size=batch_size,
        num_threads=16,
        capacity=capacity)
    # Flatten the label batch to shape [batch_size].
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)
    return image_batch, label_batch
'''
# 单张图片验证能否打开
def PreWork():
# 对预处理的数据进行可视化,查看预处理的结果
IMG_W = 256
IMG_H = 256
BATCH_SIZE = 3
CAPACITY = 64
train_dir = 'C:/Users/JS/Desktop/MFC/Detect'
image_list, label_list = get_file(train_dir)
image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
print(label_batch.shape)
lists = ('good', 'bad')
with tf.Session() as sess:
i = 0
coord = tf.train.Coordinator() # 创建一个线程协调器,用来管理之后在Session中启动的所有线程
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i < 1:
img, label = sess.run([image_batch, label_batch])
'''
1、range()返回的是range object,而np.arange()返回的是numpy.ndarray()
range(start, end, step),返回一个list对象,起始值为start,终止值为end,但不含终止值,步长为step。只能创建int型list。
arange(start, end, step),与range()类似,但是返回一个array对象。需要引入import numpy as np,并且arange可以使用float型数据。
2、range()不支持步长为小数,np.arange()支持步长为小数
3、两者都可用于迭代
range尽可用于迭代,而np.nrange作用远不止于此,它是一个序列,可被当做向量使用。
'''
for j in np.arange(BATCH_SIZE):
print('label: %d'%label[j])
plt.imshow(img[j, :, :, :])
title = lists[int(label[j])]
plt.title(title)
plt.show()
i += 1
except tf.errors.OutOfRangeError:
print('Done')
finally:
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
PreWork()
'''
if __name__ == '__main__':
    # Quick manual check: list the shuffled paths / labels for one dataset.
    image_dir = 'G:/PyProject/8/17flowers'
    paths, labels = get_file_1(image_dir)
    print(paths)
    print(labels)
DeepCNN.py
import tensorflow as tf
def weight_variable(shape, n):
    """Return a truncated-normal initial value of the given shape
    (mean 0, standard deviation ``n``)."""
    return tf.truncated_normal(shape, stddev=n, dtype=tf.float32)
def bias_variable(shape):
    """Return a constant 0.1 initial bias value of the given shape."""
    return tf.constant(0.1, shape=shape, dtype=tf.float32)
def conv2(x, W):
    """2-D convolution of ``x`` with kernel ``W``, stride 1, SAME padding."""
    return tf.nn.conv2d(x, W, [1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x, name):
    """3x3 max pooling with stride 2 and SAME padding.

    x must be a 4-D tensor [batch, height, width, channels]; the spatial
    resolution is roughly halved in each dimension.
    """
    return tf.nn.max_pool(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                          padding='SAME', name=name)
# 定义卷积网络
def deep_CNN(images, batch_size, n_classes, keep_prob=0.5):
    """Six-convolution CNN with two fully connected layers and a linear
    class projection.

    Parameters
    ----------
    images : Tensor
        Input batch, [batch_size, height, width, 3].
    batch_size : int
        Batch size, used to flatten the final feature maps.
    n_classes : int
        Number of output classes.
    keep_prob : float
        Dropout keep probability for the second FC layer. Defaults to 0.5
        (the value the original hard-coded); pass 1.0 at test/validation
        time to disable dropout.

    Returns
    -------
    Tensor
        Unnormalized class scores (logits), [batch_size, n_classes].
    """
    # --- block 1: two 3x3/64 conv layers ---
    with tf.variable_scope('conv1'):
        w_conv1 = tf.Variable(weight_variable([3, 3, 3, 64], 0.1), name='weights', dtype=tf.float32)
        b_conv1 = tf.Variable(bias_variable([64]), name='bias', dtype=tf.float32)
        h_conv1 = tf.nn.relu(conv2(images, w_conv1) + b_conv1, name='conv1')
        w_conv1_1 = tf.Variable(weight_variable([3, 3, 64, 64], 0.1), name='weights1', dtype=tf.float32)
        b_conv1_1 = tf.Variable(bias_variable([64]), name='bias1', dtype=tf.float32)
        h_conv1 = tf.nn.relu(conv2(h_conv1, w_conv1_1) + b_conv1_1, name='conv1_1')
    # Pool, then local response normalization (activity competition between
    # neighbouring channels, a mild regularizer).
    with tf.variable_scope('pooling1_lrn'):
        pool1 = max_pool_2x2(h_conv1, name='pooling1')
        norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
    # --- block 2: two 3x3/32 conv layers ---
    with tf.variable_scope('conv2'):
        w_conv2 = tf.Variable(weight_variable([3, 3, 64, 32], n=0.1), name='weights', dtype=tf.float32)
        b_conv2 = tf.Variable(bias_variable([32]), name='bias', dtype=tf.float32)
        h_conv2 = tf.nn.relu(conv2(norm1, w_conv2) + b_conv2, name='conv2')
        # NOTE(fix): the original called tf.truncated_normal(shape, 0.1),
        # which sets the *mean* to 0.1 and leaves stddev at its default 1.0
        # (signature: shape, mean, stddev). Every other layer clearly
        # intends stddev=0.1 weights and constant-0.1 biases, so the same
        # helpers are used here.
        w_conv2_2 = tf.Variable(weight_variable([3, 3, 32, 32], 0.1), dtype=tf.float32)
        b_conv2_2 = tf.Variable(bias_variable([32]), dtype=tf.float32)
        h_conv2 = tf.nn.relu(conv2(h_conv2, w_conv2_2) + b_conv2_2, name='conv2_2')
    with tf.variable_scope('pooling2_lrn'):
        pool2 = max_pool_2x2(h_conv2, name='pooling2')
        norm2 = tf.nn.lrn(pool2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
    # --- block 3: two 3x3/16 conv layers ---
    with tf.variable_scope('conv3'):
        w_conv3 = tf.Variable(weight_variable([3, 3, 32, 16], n=0.1), name='weights', dtype=tf.float32)
        b_conv3 = tf.Variable(bias_variable([16]), name='bias', dtype=tf.float32)
        h_conv3 = tf.nn.relu(conv2(norm2, w_conv3) + b_conv3, name='conv3')
        # Same mean/stddev fix as in block 2.
        w_conv3_3 = tf.Variable(weight_variable([3, 3, 16, 16], 0.1), dtype=tf.float32)
        b_conv3_3 = tf.Variable(bias_variable([16]), dtype=tf.float32)
        h_conv3 = tf.nn.relu(conv2(h_conv3, w_conv3_3) + b_conv3_3, name='conv3_3')
    with tf.variable_scope('pooling3_lrn'):
        pool3 = max_pool_2x2(h_conv3, name='pooling3')
        # NOTE(fix): this op was named 'norm2', colliding with block 2's.
        norm3 = tf.nn.lrn(pool3, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm3')
    # --- FC 1: flatten the feature maps to one row per image ---
    with tf.variable_scope('fc1') as scope:
        reshape = tf.reshape(norm3, [batch_size, -1])
        dim = reshape.get_shape()[1].value  # w' * h' * c'
        w_fc1 = tf.Variable(weight_variable([dim, 128], 0.005), name='weights', dtype=tf.float32)
        b_fc1 = tf.Variable(bias_variable([128]), name='bias', dtype=tf.float32)
        h_fc1 = tf.nn.relu(tf.matmul(reshape, w_fc1) + b_fc1, name=scope.name)
    # --- FC 2 with dropout ---
    with tf.variable_scope('fc2') as scope:
        w_fc2 = tf.Variable(weight_variable([128, 128], 0.005), name='weights', dtype=tf.float32)
        b_fc2 = tf.Variable(bias_variable([128]), name='bias', dtype=tf.float32)
        h_fc2 = tf.nn.relu(tf.matmul(h_fc1, w_fc2) + b_fc2, name=scope.name)
        # keep_prob is now a parameter, so inference callers can pass 1.0
        # instead of always dropping half the activations.
        h_fc2_dropout = tf.nn.dropout(h_fc2, keep_prob)
    # --- linear class projection (logits, not probabilities) ---
    with tf.variable_scope('softmax'):
        weights = tf.Variable(weight_variable([128, n_classes], 0.005), name='softmax_linear', dtype=tf.float32)
        bias = tf.Variable(bias_variable([n_classes]), name='bias', dtype=tf.float32)
        softmax_linear = tf.add(tf.matmul(h_fc2_dropout, weights), bias, name='softmax_linear')
    return softmax_linear
def losses(logits, labels):
    """Mean sparse softmax cross-entropy over the batch.

    logits: unnormalized scores, [batch, n_classes];
    labels: integer class ids, [batch].
    Also registers a scalar summary node for TensorBoard.
    """
    with tf.variable_scope('losses') as scope:
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy_per_example')
        # Average the per-example losses into one batch loss.
        loss = tf.reduce_mean(xent, name='loss')
        tf.summary.scalar(scope.name + '/loss', loss)
    return loss
# loss损失值优化
# 输入参数:loss。learning_rate,学习速率。
# 返回参数:train_op,训练op,这个参数要输入sess.run中让模型去训练。
def training(loss, learning_rate):
    """Create a gradient-descent train op that minimizes ``loss``.

    The returned op is passed to sess.run() to perform one update step.
    """
    with tf.name_scope('optimizer'):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        # Non-trainable counter, incremented once per optimization step.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op
# 评价/准确率计算
# 输入参数:logits,网络计算值。labels,标签,也就是真实值
# 返回参数:accuracy,当前step的平均准确率,也就是在这些batch中多少张图片被正确分类了。
def evalution(logits, labels):
    """Fraction of the batch whose top-1 prediction matches the label.

    Also registers a scalar summary node for TensorBoard.
    """
    with tf.variable_scope('accuracy') as scope:
        hits = tf.nn.in_top_k(logits, labels, 1)  # bool per example
        accuracy = tf.reduce_mean(tf.cast(hits, tf.float16))
        tf.summary.scalar(scope.name + '/accuracy', accuracy)
    return accuracy
Train.py
def train(directory, num_classes):
    """Train the CNN on the images under ``directory``.

    Parameters
    ----------
    directory : str
        Root folder with one sub-folder per class; also used as the
        log / checkpoint output directory.
    num_classes : int
        Number of classes.
    """
    import os
    import tensorflow as tf
    import numpy as np
    # NOTE(fix): preWork only defines get_file_1 / get_batch; importing the
    # non-existent get_file raised ImportError before anything ran.
    from preWork import get_batch, get_file_1
    from DeepCNN import deep_CNN, losses, training, evalution

    N_CLASS = num_classes
    IMG_H = 28
    IMG_W = 28
    BATCH_SIZE = 64
    CAPACITY = 200
    MAX_STEP = 10001
    # The optimizer is built once with this constant rate. NOTE(fix): the
    # original defined x / y placeholders and fed different "learning
    # rates" into them per phase, but the graph never used those
    # placeholders, so the feeds were dead code and are removed.
    learning_rate = 0.01

    logs_train_dir = directory
    train_paths, train_labels = get_file_1(directory)
    train_batch, train_label_batch = get_batch(
        train_paths, train_labels, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

    # Build the training graph: forward pass, loss, optimizer, accuracy.
    train_logits = deep_CNN(train_batch, BATCH_SIZE, N_CLASS)
    train_loss = losses(train_logits, train_label_batch)
    train_op = training(train_loss, learning_rate)
    train_acc = evalution(train_logits, train_label_batch)

    summary_op = tf.summary.merge_all()
    sess = tf.Session()
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    # Start the queue-runner threads feeding get_batch's pipeline.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break
            # NOTE(fix): one sess.run per step. The original ran
            # loss/accuracy and the train op in two separate calls, which
            # dequeued (and wasted) an extra batch every iteration.
            _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])
            # Every 100 steps: report progress and write a summary.
            if step % 100 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)
            # Save the latest network parameters.
            checkpoint_path = os.path.join(logs_train_dir, 'model')
            saver.save(sess, checkpoint_path)
    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
Test.py
import os
import tensorflow as tf
import numpy as np
# NOTE(fix): preWork defines get_file_1, not get_file; the original
# "from preWork import get_file" raised ImportError at module load.
from preWork import get_file_1, get_batch
from DeepCNN import training, evalution, deep_CNN
from PIL import Image
import matplotlib.pyplot as plt

# N_CLASS = 9
img_dir = 'G:/file/github/17flowers/1'
# log_dir = 'G:/PyProject/8/picture/train'
# lists = ['0', '1', '2', '3', '4', '5', '6', '7', '8']
def get_one_image(img_dir):
    """Pick one random image from ``img_dir``, display it, and return it
    resized to 28x28 as a numpy array."""
    file_names = os.listdir(img_dir)
    chosen = file_names[np.random.randint(0, len(file_names))]
    full_path = img_dir + '/' + chosen
    print(full_path)
    picture = Image.open(full_path)
    # Show the original picture before resizing (blocks until closed).
    plt.imshow(picture)
    plt.show()
    picture = picture.resize([28, 28])
    return np.array(picture)
def test(image_arr, lists, log_dir, N_CLASS):
    """Classify a single image array with the checkpointed model.

    Parameters
    ----------
    image_arr : ndarray
        28x28x3 image (as produced by get_one_image).
    lists : list[str]
        Class names indexed by integer label.
    log_dir : str
        Directory containing the trained checkpoint.
    N_CLASS : int
        Number of classes.

    Returns
    -------
    str
        The predicted class name from ``lists``.
    """
    with tf.Graph().as_default():
        # Build the inference graph directly from the numpy array.
        image = tf.cast(image_arr, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 28, 28, 3])
        p = deep_CNN(image, 1, N_CLASS)
        logits = tf.nn.softmax(p)
        # NOTE(fix): the original also created a placeholder x and fed
        # image_arr into it, but the graph above never referenced x, so the
        # feed was a no-op; the placeholder is removed.
        saver = tf.train.Saver()
        # NOTE(fix): the original session was never closed; 'with' releases
        # its resources.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(log_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            prediction = sess.run(logits)
            max_index = np.argmax(prediction)
            return lists[max_index]
if __name__ == '__main__':
    print('Start test')
    img = get_one_image(img_dir)
    # NOTE(fix): test() requires lists / log_dir / N_CLASS; the original
    # call test(img) raised TypeError. The values below mirror the
    # constants commented out at the top of this file -- adjust to your
    # own setup.
    lists = ['0', '1', '2', '3', '4', '5', '6', '7', '8']
    log_dir = 'G:/PyProject/8/picture/train'
    print(test(img, lists, log_dir, len(lists)))
Validation.py
import os
from preWork import get_file_1, get_batch
from DeepCNN import deep_CNN, losses, evalution
import tensorflow as tf
import numpy as np
BATCH_SIZE = 64
# N_CLASS = 17
IMG_W = 28
IMG_H = 28
CAPACITY = 60
# validation_dir = "G:/PyProject/20190715/17flowers/flowers"
# log_dir = 'G:/PyProject/20190715/17flowers'
def validation(validation_dir, log_dir, N_CLASS):
    """Measure top-1 accuracy of the checkpointed model on a dataset.

    Parameters
    ----------
    validation_dir : str
        Folder with one sub-folder per class.
    log_dir : str
        Directory containing the trained checkpoint.
    N_CLASS : int
        Number of classes.

    Returns
    -------
    float
        Mean accuracy over 100 batches.
    """
    BATCH_SIZE = 64
    IMG_W = 28
    IMG_H = 28
    CAPACITY = 60
    with tf.Graph().as_default():
        image, label = get_file_1(validation_dir)
        image_batch, label_batch = get_batch(image, label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
        p = deep_CNN(image_batch, BATCH_SIZE, N_CLASS)
        # NOTE(fix): the unused x/y/z placeholders and the unused softmax
        # node are removed; evalution() works directly on the logits.
        validation_acc = evalution(p, label_batch)
        # Load the trained parameters.
        saver = tf.train.Saver()
        sess = tf.Session()
        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # Start the queue-runner threads feeding the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # Average the batch accuracy over 100 batches.
            total_acc = 0.0
            for _ in range(100):
                total_acc += sess.run(validation_acc)
            return total_acc / 100
        finally:
            # NOTE(fix): the original leaked the queue threads and never
            # closed the session.
            coord.request_stop()
            coord.join(threads)
            sess.close()
if __name__ == '__main__':
    # Manual run with hard-coded dataset / checkpoint locations.
    val_dir = "G:/PyProject/20190715/17flowers/flowers"
    ckpt_dir = 'G:/PyProject/20190715/17flowers'
    num_classes = 17
    result = validation(val_dir, ckpt_dir, num_classes)
    print(result)
CallMainWindow.py
import os
import sys
from Main import Ui_MainWindow
from PyQt5.QtWidgets import *
from PyQt5.QtGui import *
from PyQt5.QtCore import *
# from Train import train
from Test import test
from PIL import Image
import tensorflow as tf
import numpy as np
import json
from Validation import validation
class MyMainWindow(QMainWindow, Ui_MainWindow):
    """Main window: wires the generated Qt UI to in-line training,
    single-image testing and validation of the CNN."""

    # Carries status / result text from the training loop to the UI label.
    _signal = pyqtSignal(str)

    def __init__(self):
        super(MyMainWindow, self).__init__()
        self.setupUi(self)
        self.directory = ''            # training-set root (one sub-folder per class)
        self.num_classes = ''          # class count, set when a folder / class file is chosen
        self.savepath = ''             # checkpoint / log output folder
        self.lineEdit.setPlaceholderText("10000")
        self.learningtimes = 10000     # number of training steps
        self.logfile = ''              # checkpoint folder used by test / validation
        self.classes = ''              # last predicted class name
        self.lists = []                # class names, index == integer label
        self.validation_accuracy = ''
        self.valiadation_path = ''
        self.pushButton.clicked.connect(self.OpenTrainPath)
        self.pushButton_2.clicked.connect(self.StartTrain)
        self.pushButton_5.clicked.connect(self.SavePath)
        self.pushButton_3.clicked.connect(self.OpenPicture)
        self.pushButton_4.clicked.connect(self.LoadTrainFile)
        self._signal.connect(self.ShowAccuracy)
        self.lineEdit.returnPressed.connect(self.GetTimes)
        self.pushButton_6.clicked.connect(self.GetClasses)
        self.pushButton_7.clicked.connect(self.Validation)

    def OpenTrainPath(self):
        """Pick the training folder, derive the class list from its
        sub-folders, and persist that list next to the folder as JSON."""
        self.directory = QFileDialog.getExistingDirectory(self, "选择文件夹")
        if len(self.directory) != 0:
            self.pushButton.setStyleSheet("color:red")
            class_names = os.listdir(self.directory)
            self.num_classes = len(class_names)
            # NOTE(fix): rebuild the list instead of appending to the old
            # one -- re-selecting a folder used to duplicate class names
            # and corrupt the label mapping.
            self.lists = list(class_names)
            # NOTE(fix): 'with' guarantees the file handle is closed.
            with open(self.directory + '.txt', 'w') as class_file:
                class_file.write(json.dumps(self.lists))

    def SavePath(self):
        """Pick the folder where checkpoints and logs are written."""
        self.savepath = QFileDialog.getExistingDirectory(self, "选择文件夹")
        if len(self.savepath) != 0:
            self.pushButton_5.setStyleSheet("color:red")

    def GetTimes(self):
        """Read the number of training steps from the line edit."""
        self.learningtimes = int(self.lineEdit.text())

    def StartTrain(self):
        """Start training if both the data and save folders are chosen."""
        if self.directory != '' and self.savepath != '':
            self.pushButton_2.setStyleSheet("color:red")
            self.label_7.setText("Start Training")
            self.train(directory=self.directory, num_classes=self.num_classes, save_path=self.savepath,
                       max_step=self.learningtimes)
        else:
            self.label_7.setText("请选择训练集和保存地址")

    def ShowAccuracy(self, text):
        """Slot for ``_signal``: show progress text on the status label."""
        # NOTE(fix): the parameter used to be named 'str', shadowing the
        # builtin.
        self.label_7.setText(text)

    def train(self, directory, num_classes, save_path, max_step):
        """Build the graph and train for ``max_step`` steps, emitting
        progress through ``_signal`` and checkpointing to ``save_path``."""
        import os
        import tensorflow as tf
        import numpy as np
        # NOTE(fix): preWork defines get_file_1, not get_file; the original
        # import of get_file raised ImportError.
        from preWork import get_batch, get_file_1
        from DeepCNN import deep_CNN, losses, training, evalution
        N_CLASS = num_classes
        IMG_H = 28
        IMG_W = 28
        BATCH_SIZE = 64
        CAPACITY = 200
        MAX_STEP = max_step
        # The optimizer is built once with this constant. NOTE(fix): the
        # original fed "learning rates" into placeholders the graph never
        # used; those placeholders and feeds were dead code.
        learning_rate = 0.01
        train_paths, train_labels = get_file_1(directory)
        train_batch, train_label_batch = get_batch(train_paths, train_labels, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
        # Training graph: forward pass, loss, optimizer, accuracy.
        train_logits = deep_CNN(train_batch, BATCH_SIZE, N_CLASS)
        train_loss = losses(train_logits, train_label_batch)
        train_op = training(train_loss, learning_rate)
        train_acc = evalution(train_logits, train_label_batch)
        summary_op = tf.summary.merge_all()
        sess = tf.Session()
        train_writer = tf.summary.FileWriter(save_path, sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        # Start the queue-runner threads feeding the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in np.arange(MAX_STEP):
                if coord.should_stop():
                    break
                # NOTE(fix): one sess.run per step; the original ran the
                # loss/accuracy and the train op separately, consuming two
                # batches per iteration.
                _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])
                if step % 100 == 0:
                    self._signal.emit('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                    summary_str = sess.run(summary_op)
                    train_writer.add_summary(summary_str, step)
                elif step == MAX_STEP - 1:
                    self._signal.emit("Training Finished")  # fix: typo "Fininshed"
                # Save the latest network parameters.
                saver.save(sess, os.path.join(save_path, 'model'))
                # Keep the UI responsive during this long loop.
                QApplication.processEvents()
        except tf.errors.OutOfRangeError:
            self._signal.emit('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()

    def LoadTrainFile(self):
        """Pick the folder that holds the trained checkpoint."""
        self.logfile = QFileDialog.getExistingDirectory(self, "选择文件夹")
        if len(self.logfile) != 0:
            self.pushButton_4.setStyleSheet("color:red")

    def OpenPicture(self):
        """Pick an image, display it, and classify it with the model."""
        fname, _ = QFileDialog.getOpenFileName(self, "选择图片", " ", "Image files(*.jpg *.bmp *.*)")
        if len(fname) != 0:
            self.pushButton_3.setStyleSheet("color:red")
            pixmap = QPixmap(fname).scaled(self.label_3.width(), self.label_3.height())
            self.label_3.setPixmap(pixmap)
            picture = Image.open(fname)
            image = np.array(picture.resize([28, 28]))
            if not self.lists:
                self.label_5.setText("请加载类别文件")
            else:
                self.classes = test(image_arr=image, lists=self.lists, log_dir=self.logfile, N_CLASS=self.num_classes)
                self.label_5.setText(self.classes)

    def GetClasses(self):
        """Load the class-name list (the JSON .txt written by
        OpenTrainPath)."""
        fname, _ = QFileDialog.getOpenFileName(self, "选择文件", " ", "TXT(*.txt)")
        if len(fname) != 0:
            self.pushButton_6.setStyleSheet("color:red")
            # NOTE(fix): 'with' closes the handle the original leaked.
            with open(fname, 'r') as class_file:
                self.lists = json.loads(class_file.read())
            self.num_classes = len(self.lists)

    def Validation(self):
        """Pick a validation folder and display the model's accuracy."""
        self.label_7.setText("正在验证,请稍等")
        fname = QFileDialog.getExistingDirectory(self, "选择图片")
        if len(fname) != 0:
            self.pushButton_7.setStyleSheet("color:red")
            self.num_classes = len(os.listdir(fname))
            self.validation_accuracy = validation(fname, self.logfile, self.num_classes)
            accuracy = '验证准确率为:' + str(self.validation_accuracy * 100) + '%'
            self.label_7.setText(accuracy)
if __name__ == "__main__":
    # Standard PyQt5 bootstrap: build the app, show the window, run the loop.
    app = QApplication(sys.argv)
    window = MyMainWindow()
    window.show()
    sys.exit(app.exec_())
Main.py
# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'Main.ui'
#
# Created by: PyQt5 UI code generator 5.11.3
#
# WARNING! All changes made in this file will be lost!
from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_MainWindow(object):
    """UI layout for the main window, generated by pyuic5 from Main.ui.

    NOTE(review): regenerate this class from the .ui file instead of
    hand-editing when the layout changes.
    """

    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(498, 440)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        # Preview label showing the picture being classified.
        self.label_3 = QtWidgets.QLabel(self.centralwidget)
        self.label_3.setGeometry(QtCore.QRect(190, 90, 256, 256))
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.label_3.sizePolicy().hasHeightForWidth())
        self.label_3.setSizePolicy(sizePolicy)
        self.label_3.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_3.setText("")
        self.label_3.setAlignment(QtCore.Qt.AlignCenter)
        self.label_3.setObjectName("label_3")
        self.label_3.setWordWrap(True)
        # Status label for training / validation progress text.
        self.label_7 = QtWidgets.QLabel(self.centralwidget)
        self.label_7.setGeometry(QtCore.QRect(30, 250, 151, 71))
        self.label_7.setText("")
        self.label_7.setAlignment(QtCore.Qt.AlignCenter)
        self.label_7.setObjectName("label_7")
        # Top grid: folder pickers and the "start training" button.
        self.layoutWidget = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget.setGeometry(QtCore.QRect(30, 31, 421, 72))
        self.layoutWidget.setObjectName("layoutWidget")
        self.gridLayout = QtWidgets.QGridLayout(self.layoutWidget)
        self.gridLayout.setContentsMargins(0, 0, 0, 0)
        self.gridLayout.setObjectName("gridLayout")
        self.label = QtWidgets.QLabel(self.layoutWidget)
        self.label.setAlignment(QtCore.Qt.AlignCenter)
        self.label.setObjectName("label")
        self.gridLayout.addWidget(self.label, 0, 0, 1, 1)
        self.label_9 = QtWidgets.QLabel(self.layoutWidget)
        self.label_9.setAlignment(QtCore.Qt.AlignCenter)
        self.label_9.setObjectName("label_9")
        self.gridLayout.addWidget(self.label_9, 0, 1, 1, 1)
        self.label_2 = QtWidgets.QLabel(self.layoutWidget)
        self.label_2.setAlignment(QtCore.Qt.AlignCenter)
        self.label_2.setObjectName("label_2")
        self.gridLayout.addWidget(self.label_2, 0, 2, 1, 1)
        self.pushButton = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton.setObjectName("pushButton")
        self.gridLayout.addWidget(self.pushButton, 1, 0, 1, 1)
        self.pushButton_5 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_5.setObjectName("pushButton_5")
        self.gridLayout.addWidget(self.pushButton_5, 1, 1, 1, 1)
        self.pushButton_3 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_3.setObjectName("pushButton_3")
        self.gridLayout.addWidget(self.pushButton_3, 1, 2, 1, 1)
        self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_2.setObjectName("pushButton_2")
        self.gridLayout.addWidget(self.pushButton_2, 2, 0, 1, 1)
        # Bottom grid: predicted-class readout.
        self.layoutWidget1 = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget1.setGeometry(QtCore.QRect(250, 360, 161, 16))
        self.layoutWidget1.setObjectName("layoutWidget1")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.layoutWidget1)
        self.gridLayout_2.setContentsMargins(0, 0, 0, 0)
        self.gridLayout_2.setObjectName("gridLayout_2")
        self.label_4 = QtWidgets.QLabel(self.layoutWidget1)
        self.label_4.setAlignment(QtCore.Qt.AlignCenter)
        self.label_4.setObjectName("label_4")
        self.gridLayout_2.addWidget(self.label_4, 0, 0, 1, 1)
        self.label_5 = QtWidgets.QLabel(self.layoutWidget1)
        self.label_5.setText("")
        self.label_5.setAlignment(QtCore.Qt.AlignCenter)
        self.label_5.setObjectName("label_5")
        self.gridLayout_2.addWidget(self.label_5, 0, 1, 1, 1)
        # Left form: step count, checkpoint loader, class file, validation.
        self.layoutWidget2 = QtWidgets.QWidget(self.centralwidget)
        self.layoutWidget2.setGeometry(QtCore.QRect(30, 120, 155, 109))
        self.layoutWidget2.setObjectName("layoutWidget2")
        self.formLayout = QtWidgets.QFormLayout(self.layoutWidget2)
        self.formLayout.setContentsMargins(0, 0, 0, 0)
        self.formLayout.setObjectName("formLayout")
        self.label_8 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_8.setObjectName("label_8")
        self.formLayout.setWidget(0, QtWidgets.QFormLayout.LabelRole, self.label_8)
        self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget2)
        self.lineEdit.setObjectName("lineEdit")
        self.formLayout.setWidget(0, QtWidgets.QFormLayout.FieldRole, self.lineEdit)
        self.label_6 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_6.setObjectName("label_6")
        self.formLayout.setWidget(1, QtWidgets.QFormLayout.LabelRole, self.label_6)
        self.pushButton_4 = QtWidgets.QPushButton(self.layoutWidget2)
        self.pushButton_4.setObjectName("pushButton_4")
        self.formLayout.setWidget(1, QtWidgets.QFormLayout.FieldRole, self.pushButton_4)
        self.label_10 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_10.setAlignment(QtCore.Qt.AlignCenter)
        self.label_10.setObjectName("label_10")
        self.formLayout.setWidget(2, QtWidgets.QFormLayout.LabelRole, self.label_10)
        self.pushButton_6 = QtWidgets.QPushButton(self.layoutWidget2)
        self.pushButton_6.setObjectName("pushButton_6")
        self.formLayout.setWidget(2, QtWidgets.QFormLayout.FieldRole, self.pushButton_6)
        self.label_11 = QtWidgets.QLabel(self.layoutWidget2)
        self.label_11.setObjectName("label_11")
        self.formLayout.setWidget(3, QtWidgets.QFormLayout.LabelRole, self.label_11)
        self.pushButton_7 = QtWidgets.QPushButton(self.layoutWidget2)
        self.pushButton_7.setObjectName("pushButton_7")
        self.formLayout.setWidget(3, QtWidgets.QFormLayout.FieldRole, self.pushButton_7)
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 498, 23))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.label.setText(_translate("MainWindow", "选择训练集文件夹"))
        self.label_9.setText(_translate("MainWindow", "选择训练数据保存位置"))
        self.label_2.setText(_translate("MainWindow", "选择待识别的图片"))
        self.pushButton.setText(_translate("MainWindow", "打开"))
        self.pushButton_5.setText(_translate("MainWindow", "打开"))
        self.pushButton_3.setText(_translate("MainWindow", "打开"))
        self.pushButton_2.setText(_translate("MainWindow", "开始训练"))
        self.label_4.setText(_translate("MainWindow", "类别"))
        self.label_8.setText(_translate("MainWindow", "训练次数"))
        self.label_6.setText(_translate("MainWindow", "加载训练数据"))
        self.pushButton_4.setText(_translate("MainWindow", "打开"))
        self.label_10.setText(_translate("MainWindow", "打开分类文件"))
        self.pushButton_6.setText(_translate("MainWindow", "打开"))
        self.label_11.setText(_translate("MainWindow", "打开验证集"))
        self.pushButton_7.setText(_translate("MainWindow", "打开"))