迁移学习后续——中草药分类(inception-v3)

本文沿用前一篇《简易迁移学习》的思路,使用 inception-v3 对官方模板做了简单改写:主要内容都保留,只是做了一些简化,方便个人理解…

数据集
数据集包含54类中草药,图片从网上爬取,每类约200张。
迁移学习后续——中草药分类(inception-v3)_第1张图片
迁移学习后续——中草药分类(inception-v3)_第2张图片
inception-v3模型
在这里插入图片描述
特征向量
迁移学习后续——中草药分类(inception-v3)_第3张图片

全连接训练模型
迁移学习后续——中草药分类(inception-v3)_第4张图片

代码

from tensorflow.python import gfile
import pickle
from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np
import cv2
import os
from tensorflow import gfile
import matplotlib.pyplot as plt
import random

# Expose only the CUDA device numbered 3 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Root folder of the raw images; image_all_path() below walks it as
# one sub-folder per class.
input_dir = r'E:\datas\中草药/'
# Frozen Inception-v3 graph definition (.pb) that is imported below.
model_dir = r'./model/tensorflow_inception_graph.pb'
# Folder that receives the pickled feature batches.
output_folder = r'./data/bb'

###########################################################################
# Extract the features and save them.
# The existence of output_folder selects the branch: when it is missing the
# script extracts features, otherwise it goes to the train/predict branch.
if not os.path.exists(output_folder):
    os.mkdir(output_folder)

    def image_all_path(file):
        """Collect every image path below *file* with an integer class label.

        Args:
            file: Root directory that contains one sub-directory per class.

        Returns:
            A tuple ``(all_imgs, all_labels)`` of two equally long lists:
            the full path of each file found in a class directory, and the
            index of that class directory.

        Class indices follow the order returned by ``os.listdir`` (which is
        not guaranteed to be sorted), so they are only stable for a given
        filesystem state.
        """
        all_imgs = []
        all_labels = []

        # Only sub-directories are classes. A stray file in the root (for
        # example Thumbs.db) would otherwise make the inner os.listdir() raise.
        class_dirs = [name for name in os.listdir(file)
                      if os.path.isdir(os.path.join(file, name))]

        # Build the full paths and the matching labels.
        for label, class_name in enumerate(class_dirs):
            for img in os.listdir(os.path.join(file, class_name)):
                all_imgs.append(os.path.join(file, class_name, img))
                all_labels.append(label)
        return all_imgs, all_labels

    # Gather every image path with its class index and echo both lists.
    all_imgs, all_labels = image_all_path(input_dir)
    for collected in (all_imgs, all_labels):
        print(collected)

    # Build the computation graph by parsing the frozen Inception-v3 model.
    def load_pretrainder_inception_v3(model_dir):
        """Read the serialized graph at *model_dir* and import it into the default graph."""
        inception_graph = tf.GraphDef()                 # start from an empty graph definition
        with gfile.FastGFile(model_dir, 'rb') as pb_file:
            serialized = pb_file.read()
        inception_graph.ParseFromString(serialized)     # fill it from the file contents
        # name='' imports the nodes without a name prefix, so tensors can be
        # looked up by the names used later in this script.
        tf.import_graph_def(inception_graph, name='')

    load_pretrainder_inception_v3(model_dir)
    batch_size = 2000

    # Number of output files when the images are grouped batch_size per file.
    num_batch = int(np.ceil(len(all_imgs) / batch_size))

    with tf.Session() as sess:
        # Look up the bottleneck tensor of the imported model.
        second_to_last_tensor = sess.graph.get_tensor_by_name('pool_3/_reshape:0')
        # Tensor("pool_3/_reshape:0", shape=(1, 2048), dtype=float32)

        for i in range(num_batch):
            batch_imgs = all_imgs[i*batch_size:(i+1)*batch_size]
            batch_labels = all_labels[i*batch_size:(i+1)*batch_size]
            batch_features = []

            for all_img in batch_imgs:
                # Read the raw file bytes; the context manager closes the
                # handle so thousands of images do not leak open files.
                with gfile.GFile(all_img, 'rb') as img_file:
                    img_data = img_file.read()
                # Feed the encoded image bytes to the graph's JPEG input
                # and evaluate the bottleneck tensor for this image.
                feature = sess.run(second_to_last_tensor, feed_dict={'DecodeJpeg/contents:0': img_data})
                batch_features.append(feature)

            print(batch_features)
            print(len(batch_features))

            batch_features = np.vstack(batch_features)
            output_dir = os.path.join(output_folder, 'image_features_%d' % i)
            # pickle emits bytes, so the file is opened in binary mode.
            with gfile.GFile(output_dir, 'wb') as f:
                # Stored tuple: (full paths, feature matrix, labels).
                pickle.dump((batch_imgs, batch_features, batch_labels), f)
            # Report success only after the file has actually been written.
            print(output_dir, '保存完毕')
else:
    ############################################################################
    # Train/save the classifier and run predictions (based on a CIFAR-10 code template).
    # Same folder as output_folder above, spelled here with a trailing slash.
    data_dir = r'./data/bb/'
    print(os.listdir(data_dir))

    # Read back one saved feature batch.
    def load_data(filename):
        """Return the object unpickled from *filename*.

        The extraction branch of this script stores each batch as a
        ``(paths, features, labels)`` tuple.

        NOTE: pickle can execute arbitrary code while loading; only use this
        on files produced by this script.
        """
        with gfile.FastGFile(filename, 'rb') as pickled_file:
            return pickle.load(pickled_file, encoding='bytes')


    class ImageData:
        """Feature set with a fixed train/test split and mini-batch access to the training part."""

        def __init__(self, filenames, need_shuffle):
            """Load all feature batches, split them 80/20 and optionally shuffle the training part.

            Args:
                filenames: Paths of the pickled batches read through ``load_data``.
                need_shuffle: When true, the training part is shuffled now and
                    again every time ``next_batch`` runs past its end.
            """
            loaded = [load_data(name) for name in filenames]
            feature_blocks = [batch[1] for batch in loaded]
            label_blocks = [batch[2] for batch in loaded]

            self.data = np.vstack(feature_blocks)
            self.labels = np.hstack(label_blocks)

            for summary in (self.data.shape, self.data.shape[0], self.labels.shape):
                print(summary)

            self.indicator = 0
            self.need_shuffle = need_shuffle

            # Split before any shuffling so that only the training part is
            # ever permuted and the held-out part stays separate from it.
            split = train_test_split(self.data, self.labels, test_size=0.2, random_state=1)
            self.train_data, self.test_data, self.train_labels, self.test_labels = split

            self.num_exmaples = self.train_data.shape[0]

            if self.need_shuffle:
                self.shuffle_data()

        def shuffle_data(self):
            """Apply one random permutation to the training features and labels together."""
            permutation = np.random.permutation(self.num_exmaples)
            self.train_data, self.train_labels = (
                self.train_data[permutation],
                self.train_labels[permutation],
            )

        def next_batch(self, batch_size):
            """Return the next *batch_size* training rows as ``(features, labels)``.

            Raises:
                Exception: ``'no more'`` when the data is exhausted and shuffling
                    is disabled; ``'larger'`` when *batch_size* exceeds the
                    number of training rows.
            """
            start = self.indicator
            stop = start + batch_size

            if stop > self.num_exmaples:
                if not self.need_shuffle:
                    raise Exception('no more')
                # Start a new pass over freshly shuffled data.
                self.shuffle_data()
                self.indicator = start = 0
                stop = batch_size

            if stop > self.num_exmaples:
                raise Exception('larger')

            window = slice(start, stop)
            self.indicator = stop
            return self.train_data[window], self.train_labels[window]


    # Discover the feature batches that are actually present instead of
    # assuming exactly seven files: the extraction branch writes
    # ceil(number_of_images / batch_size) files, which depends on the dataset.
    feature_prefix = 'image_features_'
    feature_files = [
        name for name in os.listdir(data_dir)
        if name.startswith(feature_prefix) and name[len(feature_prefix):].isdigit()
    ]
    # Numeric order (0, 1, 2, ...) keeps the row order of the original
    # hard-coded list, so the seeded train/test split is reproducible.
    feature_files.sort(key=lambda name: int(name[len(feature_prefix):]))
    if not feature_files:
        raise FileNotFoundError('no image_features_<n> files found in %s' % data_dir)

    data_filenames = [os.path.join(data_dir, name) for name in feature_files]
    datas = ImageData(data_filenames, True)

    # Inputs: one 2048-value feature row per image and one integer class id per row.
    x = tf.placeholder(tf.float32, [None, 2048])
    y = tf.placeholder(tf.int64, [None])

    # Inception-v3 (described in the original note as a pre-trained model with
    # 47 convolutional layers) is not trained here. Only this fully connected
    # head is: a 1024-unit ReLU layer followed by a 54-unit output layer.
    fc1 = tf.layers.dense(x, 1024, activation=tf.nn.relu)
    h = tf.layers.dense(fc1, 54)

    # Mean softmax cross-entropy between the logits h and the integer labels y.
    cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y))

    # Predicted class id per row and the fraction of rows predicted correctly.
    predict = tf.argmax(tf.nn.softmax(h), 1)
    is_correct = tf.equal(predict, y)
    accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

    # Plain gradient descent on the cost with a fixed learning rate of 0.05.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(cost)

    model_path = './bottleneck1/'   # where the trained model is saved (this is the fully connected classifier)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # Train and save only when the model folder does not exist yet.
        # NOTE(review): the folder is created before training starts, so an
        # interrupted run leaves a folder that may contain no checkpoint and
        # the next run takes the evaluation branch - confirm this is intended.
        if not os.path.exists(model_path):
            os.mkdir(model_path)

            for step in range(1, 20001):
                batch_data, batch_labels = datas.next_batch(batch_size=100)
                cost_val, _, acc_val = sess.run([cost, optimizer, accuracy], feed_dict={
                    x: batch_data, y: batch_labels
                })
                # Log every 1000 steps; write a checkpoint every 2000 steps.
                if step % 1000 == 0:
                    print('[Train] step:%d,cost:%4.5f,acc:%4.5f' % (step, cost_val, acc_val))
                if step % 2000 == 0:
                    saver.save(sess, model_path, global_step=step)
            print('End of model save')

        # Otherwise restore the latest checkpoint and evaluate.
        else:
            # NOTE(review): this message says "training", but this branch only
            # restores the saved model and evaluates it.
            print('Start training model')
            saver.restore(sess, tf.train.latest_checkpoint(model_path))

            # Accuracy on the held-out part of the split made by ImageData.
            test_batch_data, test_batch_labels = datas.test_data, datas.test_labels
            test_acc_val = sess.run(accuracy, feed_dict={x: test_batch_data, y: test_batch_labels})
            print('test_acc_val:', test_acc_val)


############################################################################
# 随机种类(图片预测)
            model_dir = "./model/tensorflow_inception_graph.pb"

            # Display names of the classes; position i is the name shown for
            # predicted class id i.
            class_names = [
                'guizhencao', 'hongliao', 'Wormwood',
                'xunma', 'Hairyveinagrimony', 'Gardenia',
                'RadixIsatidis', 'gouweibacao', 'plantains',
                'tongquancao', 'qigucao', 'MonochoriaVaginalis(yashecao)',
                'commelina_communis', 'xiaoqieyi', 'shuiqincai',
                'Bupleurum(chaihu)', 'dandelions', 'Pinellia(banxia)',
                'zhajiangcao', 'Rabdosiaserra', 'feipeng',
                'honeysuckles', 'xiaoji', 'selfheals',
                'MorningGlory(qianniuhua)', 'malan', 'tianhukui',
                'ziyunying', 'EichhorniaCrassipes(fengyanlan)', 'heshouwu',
                'ChenopodiumAlbum', 'Ophiopogon(maidong)', 'huanghuacai',
                'mantuoluo', 'kucai', 'sedum_sarmentosum',
                'bosipopona', 'boheye', 'juaner',
                'lotusseed', 'palms', 'cangerzi',
                'Wahlenbergia(lanhuashen)', 'Angelica(baizhi)', 'ginsengs',
                'zeqi', 'mangnoliaofficinalis', 'perillas',
                'yichuanhong', 'jicai', 'Odoratum(yuzhu)',
                'denglongcao', 'Agastacherugosa', 'Moneygrass',
            ]
            n_class = 54


            def all_dir(data_dirs=r'E:\datas\中草药', num_samples=25):
                """Build random image paths to use as a quick prediction sample.

                Args:
                    data_dirs: Dataset root with one sub-directory per class.
                    num_samples: Number of paths to generate (classes may repeat).

                Returns:
                    A list of *num_samples* paths of the form
                    ``<data_dirs>/<class>/<class>_0NN.jpg`` with NN drawn from
                    10..50. The files are assumed to follow this naming scheme;
                    their existence is not checked here.

                Raises:
                    ValueError: If *data_dirs* contains no class directory.
                """
                # Only sub-directories are classes; stray files are ignored.
                data_folder = [name for name in os.listdir(data_dirs)
                               if os.path.isdir(os.path.join(data_dirs, name))]
                if not data_folder:
                    raise ValueError('no class folders found in %r' % (data_dirs,))

                # Draw class indices from the folders that really exist rather
                # than assuming exactly 54 of them.
                random_list = [random.randint(0, len(data_folder) - 1)
                               for _ in range(num_samples)]

                test_all_dir = []
                for i in random_list:
                    num = random.randint(10, 50)
                    # Three-digit, zero-padded image number, e.g. 17 -> "017".
                    file_name = '%s_%03d.jpg' % (data_folder[i], num)
                    test_all_dir.append(os.path.join(data_dirs, data_folder[i], file_name))
                return test_all_dir


            def load_google_model(model_dir):
                """Parse the frozen graph file at *model_dir* and import it into the default graph."""
                frozen_graph = tf.GraphDef()
                with gfile.FastGFile(model_dir, "rb") as pb_file:
                    payload = pb_file.read()
                frozen_graph.ParseFromString(payload)
                # An empty name imports the nodes without a prefix.
                tf.import_graph_def(frozen_graph, name="")


            def create_test_featrue(test_dir):
                """Load the sample images and compute their bottleneck features.

                Args:
                    test_dir: Iterable of image paths laid out as
                        ``<root>/<class name>/<file>``.

                Returns:
                    A tuple ``(test_data, test_feature, test_labels)``: the
                    images resized to 448x448 for display, the features
                    reshaped to ``(-1, 2048)``, and an array holding the class
                    name of each image.
                """
                test_data, test_feature, test_labels = [], [], []

                with tf.Session() as sess:
                    st = sess.graph.get_tensor_by_name("pool_3/_reshape:0")

                    for path in test_dir:
                        print(path)
                        # Decoded and resized copy, used only for plotting.
                        img = plt.imread(path)
                        img = cv2.resize(img, (448, 448))
                        test_data.append(img)

                        # The graph takes the encoded file bytes; close the
                        # handle as soon as they have been read.
                        with gfile.GFile(path, "rb") as img_file:
                            img_data = img_file.read()
                        feature = sess.run(st, feed_dict={"DecodeJpeg/contents:0": img_data})
                        test_feature.append(feature)

                        # The true class is the parent folder name. Cutting the
                        # file name at its first "_" truncated class names that
                        # contain an underscore (e.g. sedum_sarmentosum), and
                        # splitting on a backslash only worked for Windows paths.
                        test_labels.append(os.path.basename(os.path.dirname(path)))
                    print(test_labels)
                return test_data, np.reshape(test_feature, (-1, 2048)), np.array(test_labels)


            def show_img(test_data, pre_labels, test_labels):
                """Plot the sample images, each captioned with its predicted label.

                Args:
                    test_data: Images to display.
                    pre_labels: Predicted class name for each image.
                    test_labels: True class name for each image.

                The caption is black when the prediction equals the true label
                and red otherwise. Images are laid out five per row; the number
                of rows follows the number of images, so 25 images give the
                usual 5x5 grid and other counts no longer raise IndexError or
                drop images.
                """
                print(pre_labels)

                n_images = len(test_data)
                if n_images == 0:
                    # Nothing to plot.
                    return

                n_cols = 5
                n_rows = (n_images + n_cols - 1) // n_cols   # ceiling division
                # squeeze=False always returns a 2-D array of axes, even for one row.
                _, axs = plt.subplots(n_rows, n_cols, squeeze=False)
                for i, axi in enumerate(axs.flat):
                    if i >= n_images:
                        # Hide the unused cells of the last row.
                        axi.axis('off')
                        continue
                    axi.imshow(test_data[i])
                    print('真实标签:', test_labels[i], '\t预测标签:', pre_labels[i])
                    axi.set_xlabel(xlabel=pre_labels[i], color="black" if pre_labels[i] == test_labels[i] else "red")
                    axi.set(xticks=[], yticks=[])
                plt.show()


            # Import Inception-v3, pick random sample images and extract their features.
            load_google_model(model_dir)
            test_all_dir = all_dir()
            test_data, test_feature, test_labels = create_test_featrue(test_all_dir)

            # Classify the features with the restored classifier and show the
            # result, mapping each predicted class id to its display name.
            pred = sess.run(tf.argmax(h, 1), feed_dict={x: test_feature})
            predicted_names = [class_names[class_id] for class_id in pred]
            show_img(test_data, predicted_names, test_labels)

效果

Start training model

test_acc_val: 0.7917889

E:\datas\中草药\qigucao\qigucao_039.jpg
E:\datas\中草药\Ophiopogon(maidong)\Ophiopogon(maidong)_023.jpg
E:\datas\中草药\cangerzi\cangerzi_041.jpg
E:\datas\中草药\malan\malan_037.jpg
E:\datas\中草药\guizhencao\guizhencao_028.jpg
E:\datas\中草药\EichhorniaCrassipes(fengyanlan)\EichhorniaCrassipes(fengyanlan)_018.jpg
E:\datas\中草药\ginsengs\ginsengs_037.jpg
E:\datas\中草药\tianhukui\tianhukui_032.jpg
E:\datas\中草药\tianhukui\tianhukui_046.jpg
E:\datas\中草药\sedum_sarmentosum\sedum_sarmentosum_043.jpg
E:\datas\中草药\Pinellia(banxia)\Pinellia(banxia)_018.jpg
E:\datas\中草药\denglongcao\denglongcao_049.jpg
E:\datas\中草药\qigucao\qigucao_012.jpg
E:\datas\中草药\sedum_sarmentosum\sedum_sarmentosum_041.jpg
E:\datas\中草药\honeysuckles\honeysuckles_015.jpg
E:\datas\中草药\Wahlenbergia(lanhuashen)\Wahlenbergia(lanhuashen)_029.jpg
E:\datas\中草药\heshouwu\heshouwu_018.jpg
E:\datas\中草药\ginsengs\ginsengs_017.jpg
E:\datas\中草药\boheye\boheye_030.jpg
E:\datas\中草药\malan\malan_010.jpg
E:\datas\中草药\Wormwood\Wormwood_011.jpg
E:\datas\中草药\RadixIsatidis\RadixIsatidis_050.jpg
E:\datas\中草药\huanghuacai\huanghuacai_039.jpg
E:\datas\中草药\juaner\juaner_038.jpg
E:\datas\中草药\lotusseed\lotusseed_015.jpg
['qigucao', 'Ophiopogon(maidong)', 'cangerzi', 'malan', 'guizhencao', 'EichhorniaCrassipes(fengyanlan)', 'ginsengs', 'tianhukui', 'tianhukui', 'sedum', 'Pinellia(banxia)', 'denglongcao', 'qigucao', 'sedum', 'honeysuckles', 'Wahlenbergia(lanhuashen)', 'heshouwu', 'ginsengs', 'boheye', 'malan', 'Wormwood', 'RadixIsatidis', 'huanghuacai', 'juaner', 'lotusseed']
['qigucao', 'Ophiopogon(maidong)', 'cangerzi', 'malan', 'cangerzi', 'EichhorniaCrassipes(fengyanlan)', 'ginsengs', 'tianhukui', 'tianhukui', 'sedum_sarmentosum', 'Pinellia(banxia)', 'denglongcao', 'qigucao', 'sedum_sarmentosum', 'honeysuckles', 'Wahlenbergia(lanhuashen)', 'heshouwu', 'ginsengs', 'boheye', 'malan', 'shuiqincai', 'shuiqincai', 'huanghuacai', 'juaner', 'lotusseed']
真实标签: qigucao 	预测标签: qigucao
真实标签: Ophiopogon(maidong) 	预测标签: Ophiopogon(maidong)
真实标签: cangerzi 	预测标签: cangerzi
真实标签: malan 	预测标签: malan
真实标签: guizhencao 	预测标签: cangerzi
真实标签: EichhorniaCrassipes(fengyanlan) 	预测标签: EichhorniaCrassipes(fengyanlan)
真实标签: ginsengs 	预测标签: ginsengs
真实标签: tianhukui 	预测标签: tianhukui
真实标签: tianhukui 	预测标签: tianhukui
真实标签: sedum 	预测标签: sedum_sarmentosum
真实标签: Pinellia(banxia) 	预测标签: Pinellia(banxia)
真实标签: denglongcao 	预测标签: denglongcao
真实标签: qigucao 	预测标签: qigucao
真实标签: sedum 	预测标签: sedum_sarmentosum
真实标签: honeysuckles 	预测标签: honeysuckles
真实标签: Wahlenbergia(lanhuashen) 	预测标签: Wahlenbergia(lanhuashen)
真实标签: heshouwu 	预测标签: heshouwu
真实标签: ginsengs 	预测标签: ginsengs
真实标签: boheye 	预测标签: boheye
真实标签: malan 	预测标签: malan
真实标签: Wormwood 	预测标签: shuiqincai
真实标签: RadixIsatidis 	预测标签: shuiqincai
真实标签: huanghuacai 	预测标签: huanghuacai
真实标签: juaner 	预测标签: juaner
真实标签: lotusseed 	预测标签: lotusseed

迁移学习后续——中草药分类(inception-v3)_第5张图片
总结
1. 图片是从网上爬取的,虽然做过清洗,但噪声仍然较多;中草药图片既有青稞也有药材,我们选取的是青稞,也就是植物早期的样子,所以数据集并不完美。
2.调参经验不足

你可能感兴趣的:(inception-v3,迁移学习)