LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)

为了探究更多网络图像分类的效果,尝试LSTM网络处理,顺便谈一谈对循环神经网络的简单理解。最终效果:7M模型85%准确率,单层网络。对比之间做的CNN效果(7M模型,95%准确率,但存在过拟合问题),文章链接https://blog.csdn.net/qq_36187544/article/details/90669462(附源代码)

目录

项目源码百度云

循环神经网络粗浅理解

调参

tensorboard展示

源代码


项目源码百度云

注:图片都是经过预处理的,统一大小,不然会报错!图像处理文件路径可以参考上面的CNN网络链接

链接:https://pan.baidu.com/s/1h0pKo5-p-JDPtM-iUs84_Q 
提取码:j44p 
LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第1张图片LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第2张图片

models,logs 两个文件夹用于存放模型文件和日志文件,现均为空,带上文件夹让程序可以直接运行
data 数据文件夹,详细图参考上右图,分为7类,每类下有图片。为了防止数据外泄,只在lh1中放了一张图片,可以查看图片是何样
setting.py 配置文件
rnn_train.py 网络训练文件,主文件

循环神经网络粗浅理解

百度一搜各种LSTM,RNN详解,这里只简单说一下:

RNN说白了就是序列化,以28×28图片为例,生成28个CELL,最后对output[28]输出处理一下即可:

LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第3张图片

所以,对于RGB彩图,先看代码,再说下原理,网络框架部分代码:

def rnn_graph(x, rnn_size, out_size, width, height, channel):
    '''
    循环神经网络计算图
    :param x:输入数据
    :param rnn_size:
    :param out_size:
    :param width:
    :param height:
    :return:
    '''
    # 权重及偏置
    w = weight_variable([rnn_size, out_size])
    b = bias_variable([out_size])
    # LSTM
    # rnn_size这里指BasicLSTMCell的num_units,指输出的向量维度
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    # transpose的作用将(?,32,448,3)形状转为(32,?,448,3),?为batch-size,32为高,448为宽,3为通道数(彩图)
    # 准备划分为32个相同网络,输入序列为(448,3),这样速度较快,逻辑较为符合一般思维
    x = tf.transpose(x, [1,0,2,3])
    # reshape -1 代表自适应,这里按照图像每一列的长度为reshape后的列长度
    x = tf.reshape(x, [-1, channel*width])
    # split默任在第一维即0 dimension进行分割,分割成height份,这里实际指把所有图片向量按对应行号进行重组
    x = tf.split(x, height)
    # 这里RNN会有与输入层相同数量的输出层,我们只需要最后一个输出
    outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)
    y_conv = tf.add(tf.matmul(outputs[-1], w), b)
    return y_conv

(32,?,448,3)格式的数据传入网络目的:分为32个cell,每个序列对应448*3,即3色的横向条状序列!

如果格式转为以竖向条状序列更改可如下,这样做网络将很大:

# x = tf.transpose(x, [1,0,2,3])
# x = tf.reshape(x, [-1, channel*width])
# x = tf.split(x, height)
x = tf.transpose(x, [2,0,1,3])
x = tf.reshape(x, [-1, channel*height])
x = tf.split(x, width)

如果调整为3个cell,每个原色的图作为一个输入也是同理!


调参

1.batch-size,很重要,合适的batch-size才能收敛合适,https://blog.csdn.net/qq_36187544/article/details/90478051

2.学习率:

LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第4张图片

3.序列多少?RNN网络的核心思想之一是前后序列有关,所以考虑一张长方形图片分为横条和竖条效果是不是不一样?后发现基本一样。。。。。。那就采用小序列进行训练,这样可以加快训练速度

4.RNN中num_units参数,越大学习到的特征越多,准确率提升,相当于增宽神经网络

5.没有尝试加深网络,单层测试,准确率85%


tensorboard展示

数据流图:

LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第5张图片

损失和准确率:

LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第6张图片LSTM处理图像分类(RGB彩图,自训练长条图,百度云源码,循环神经网络)_第7张图片


源代码

rnn_train.py源代码:

import os
import tensorflow as tf
from time import time
import numpy as np
from LSTM.setting import batch_size, width, height, rnn_size, out_size, channel, learning_rate, num_epoch

'''
训练主函数
tensorboard --logdir=D:\python\LSTM\logs
'''

def weight_variable(shape, w_alpha=0.01):
    '''
    增加噪音,随机生成权重
    :param shape: 权重形状
    :param w_alpha:随机噪声
    :return:
    '''
    initial = w_alpha * tf.random_normal(shape)
    return tf.Variable(initial)
def bias_variable(shape, b_alpha=0.1):
    '''
    增加噪音,随机生成偏置项
    :param shape:权重形状
    :param b_alpha:随机噪声
    :return:
    '''
    initial = b_alpha * tf.random_normal(shape)
    return tf.Variable(initial)
def rnn_graph(x, rnn_size, out_size, width, height, channel):
    '''
    循环神经网络计算图
    :param x:输入数据
    :param rnn_size:
    :param out_size:
    :param width:
    :param height:
    :return:
    '''
    # 权重及偏置
    w = weight_variable([rnn_size, out_size])
    b = bias_variable([out_size])
    # LSTM
    # rnn_size这里指BasicLSTMCell的num_units,指输出的向量维度
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
    # transpose的作用将(?,32,448,3)形状转为(32,?,448,3),?为batch-size,32为高,448为宽,3为通道数(彩图)
    # 准备划分为32个相同网络,输入序列为(448,3),这样速度较快,逻辑较为符合一般思维
    x = tf.transpose(x, [1,0,2,3])
    # reshape -1 代表自适应,这里按照图像每一列的长度为reshape后的列长度
    x = tf.reshape(x, [-1, channel*width])
    # split默任在第一维即0 dimension进行分割,分割成height份,这里实际指把所有图片向量按对应行号进行重组
    x = tf.split(x, height)
    # 这里RNN会有与输入层相同数量的输出层,我们只需要最后一个输出
    outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)
    y_conv = tf.add(tf.matmul(outputs[-1], w), b)
    return y_conv

def accuracy_graph(y, y_conv):
    '''
    偏差计算图
    :param y:
    :param y_conv:
    :return:
    '''
    correct = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    return accuracy

def get_batch(image_list,label_list,img_width,img_height,batch_size,capacity,channel):
    '''
    #通过读取列表来载入批量图片及标签
    :param image_list: 图片路径list
    :param label_list: 标签list
    :param img_width: 图片宽度
    :param img_height: 图片高度
    :param batch_size:
    :param capacity:
    :return:
    '''
    image = tf.cast(image_list,tf.string)
    label = tf.cast(label_list,tf.int32)
    input_queue = tf.train.slice_input_producer([image,label],shuffle=True)
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])

    image = tf.image.decode_jpeg(image_contents,channels=channel)
    image = tf.cast(image,tf.float32)
    if channel==3:
        image -= [42.79902,42.79902,42.79902] # 减均值
    elif channel == 1:
        image -= 42.79902  # 减均值
    image.set_shape((img_height,img_width,channel))
    image_batch,label_batch = tf.train.batch([image,label],batch_size=batch_size,num_threads=64,capacity=capacity)
    label_batch = tf.reshape(label_batch,[batch_size])

    return image_batch,label_batch

def get_file(file_dir):
    '''
    通过文件路径获取图片路径及标签
    :param file_dir: 文件路径
    :return:
    '''
    images = []
    for root,sub_folders,files in os.walk(file_dir):
        for name in files:
            images.append(os.path.join(root,name))
    labels = []
    for label_name in images:
        letter = label_name.split("\\")[-2]
        if letter =="lh1":labels.append(0)
        elif letter =="lh2":labels.append(1)
        elif letter == "lh3":labels.append(2)
        elif letter == "lh4":labels.append(3)
        elif letter == "lh5":labels.append(4)
        elif letter == "lh6":labels.append(5)
        elif letter == "lh7":
            labels.append(6)

    print("check for get_file:",images[0],"label is ",labels[0])
    #shuffle
    temp = np.array([images,labels])
    temp = temp.transpose()
    np.random.shuffle(temp)
    image_list = list(temp[:,0])
    label_list = list(temp[:,1])
    label_list = [int(float(i)) for i in label_list]
    return image_list,label_list

#标签格式重构
def onehot(labels):
    n_sample = len(labels)
    n_class = 7  # max(labels) + 1
    onehot_labels = np.zeros((n_sample,n_class))
    onehot_labels[np.arange(n_sample),labels] = 1
    return onehot_labels

if __name__ == '__main__':
    startTime = time()
    # 按照图片大小申请占位符
    x = tf.placeholder(tf.float32, [None, height, width, channel])
    y = tf.placeholder(tf.float32)
    # rnn模型
    y_conv = rnn_graph(x, rnn_size, out_size, width, height, channel)
    # 独热编码转化
    y_conv_prediction = tf.argmax(y_conv, 1)
    y_real = tf.argmax(y, 1)
    # 优化计算图
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    # 偏差
    accuracy = accuracy_graph(y, y_conv)
    # 自训练图像
    xs, ys = get_file('./data/train1')  # 获取图像列表与标签列表
    image_batch, label_batch = get_batch(xs, ys, img_width=width, img_height=height, batch_size=batch_size, capacity=256,channel=channel)
    # 验证集
    xs_val, ys_val = get_file('./data/test1')  # 获取图像列表与标签列表
    image_val_batch, label_val_batch = get_batch(xs_val, ys_val, img_width=width, img_height=height,batch_size=455, capacity=256,channel=channel)
    # 启动会话.开始训练
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    # 启动线程
    coord = tf.train.Coordinator()  # 使用协调器管理线程
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    # 日志记录
    summary_writer = tf.summary.FileWriter('./logs/', graph=sess.graph, flush_secs=15)
    summary_writer2 = tf.summary.FileWriter('./logs/plot2/', flush_secs=15)
    tf.summary.scalar(name='loss_func', tensor=loss)
    tf.summary.scalar(name='accuracy', tensor=accuracy)
    merged_summary_op = tf.summary.merge_all()

    step = 0
    acc_rate = 0.98
    epoch_start_time = time()
    for i in range(num_epoch):
        batch_x, batch_y = sess.run([image_batch, label_batch])
        batch_y = onehot(batch_y)

        merged_summary,_,loss_show = sess.run([merged_summary_op,optimizer,loss], feed_dict={x: batch_x, y: batch_y})
        summary_writer.add_summary(merged_summary, global_step=i)

        if i % (int(7000//batch_size)) == 0:
            batch_x_test, batch_y_test = sess.run([image_val_batch, label_val_batch])
            batch_y_test = onehot(batch_y_test)
            batch_x_test = batch_x_test.reshape([-1, height, width, channel])
            merged_summary_val,acc,prediction_val_out,real_val_out,loss_show = sess.run([merged_summary_op,accuracy,y_conv_prediction,y_real,loss],feed_dict={x: batch_x_test, y: batch_y_test})
            summary_writer2.add_summary(merged_summary_val, global_step=i)

            # 输出每个类别正确率
            lh1_right, lh2_right, lh3_right, lh4_right, lh5_right, lh6_right, lh7_right = 0, 0, 0, 0, 0, 0, 0
            lh1_wrong, lh2_wrong, lh3_wrong, lh4_wrong, lh5_wrong, lh6_wrong, lh7_wrong = 0, 0, 0, 0, 0, 0, 0
            for ii in range(len(prediction_val_out)):
                if prediction_val_out[ii] == real_val_out[ii]:
                    if real_val_out[ii] == 0:
                        lh1_right += 1
                    elif real_val_out[ii] == 1:
                        lh2_right += 1
                    elif real_val_out[ii] == 2:
                        lh3_right += 1
                    elif real_val_out[ii] == 3:
                        lh4_right += 1
                    elif real_val_out[ii] == 4:
                        lh5_right += 1
                    elif real_val_out[ii] == 5:
                        lh6_right += 1
                    elif real_val_out[ii] == 6:
                        lh7_right += 1
                else:
                    if real_val_out[ii] == 0:
                        lh1_wrong += 1
                    elif real_val_out[ii] == 1:
                        lh2_wrong += 1
                    elif real_val_out[ii] == 2:
                        lh3_wrong += 1
                    elif real_val_out[ii] == 3:
                        lh4_wrong += 1
                    elif real_val_out[ii] == 4:
                        lh5_wrong += 1
                    elif real_val_out[ii] == 5:
                        lh6_wrong += 1
                    elif real_val_out[ii] == 6:
                        lh7_wrong += 1
            print(step, "correct rate :", ((lh1_right) / (lh1_right + lh1_wrong)), ((lh2_right) / (lh2_right + lh2_wrong)),
                  ((lh3_right) / (lh3_right + lh3_wrong)), ((lh4_right) / (lh4_right + lh4_wrong)),
                  ((lh5_right) / (lh5_right + lh5_wrong)), ((lh6_right) / (lh6_right + lh6_wrong)),
                  ((lh7_right) / (lh7_right + lh7_wrong)))
            print(step, "准确的估计准确率为",(((lh1_right) / (lh1_right + lh1_wrong))+((lh2_right) / (lh2_right + lh2_wrong))+
                  ((lh3_right) / (lh3_right + lh3_wrong))+((lh4_right) / (lh4_right + lh4_wrong))+
                  ((lh5_right) / (lh5_right + lh5_wrong))+((lh6_right) / (lh6_right + lh6_wrong))+
                  ((lh7_right) / (lh7_right + lh7_wrong)))/7)


            epoch_end_time = time()
            print("takes time:",(epoch_end_time-epoch_start_time), ' step:', step, ' accuracy:', acc," loss_fun:",loss_show)
            epoch_start_time = epoch_end_time
            # 偏差满足要求,保存模型
            if acc >= acc_rate:
                model_path = os.getcwd() + os.sep + '\models\\'+str(acc_rate) + "LSTM.model"
                saver.save(sess, model_path, global_step=step)
                break
            if step % 10 == 0 and step != 0:
                model_path = os.getcwd() + os.sep + '\models\\'  + str(acc_rate)+ "LSTM"+str(step)+".model"
                print(model_path)
                saver.save(sess, model_path, global_step=step)
            step += 1

    duration = time() - startTime
    print("total takes time:",duration)
    summary_writer.close()

    coord.request_stop()  # 通知线程关闭
    coord.join(threads)  # 等其他线程关闭这一函数才返回


 

你可能感兴趣的:(#,视觉相关网络)