Following the previous post on simple transfer learning, this article applies Inception-v3 and lightly rewrites the official template: the main content is kept, but simplified so it is easier to understand.
Dataset
The dataset covers 54 classes of Chinese herbal medicine; the images were crawled from the web, with roughly 200 images per class.
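The extraction code below assumes one sub-folder per class under the dataset root (e.g. E:\datas\中草药\qigucao\qigucao_039.jpg). A minimal sketch, not part of the original, to verify that layout and the per-class counts before running the script:
import os

data_root = r'E:\datas\中草药'   # same dataset root used by the extraction script below
for name in sorted(os.listdir(data_root)):
    class_dir = os.path.join(data_root, name)
    if os.path.isdir(class_dir):
        print(name, len(os.listdir(class_dir)))   # expect 54 class folders with ~200 images each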
The Inception-v3 model
Feature vectors
import os
import pickle
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow import gfile

os.environ["CUDA_VISIBLE_DEVICES"] = "3"
input_dir = r'E:\datas\中草药'
model_dir = r'./model/tensorflow_inception_graph.pb'
output_folder = r'./data/bb'
###########################################################################
# Extract bottleneck features and save them
if not os.path.exists(output_folder):
    os.mkdir(output_folder)
def image_all_path(file):
    all_imgs = []
    all_labels = []
    # collect the full path of every image and a numeric label per class folder
    for i, j in enumerate(os.listdir(file)):
        for img in os.listdir(os.path.join(file, j)):
            all_imgs.append(os.path.join(file, j, img))
            all_labels.append(i)
    return all_imgs, all_labels
all_imgs,all_labels = image_all_path(input_dir)
print(all_imgs)
print(all_labels)
# Build a graph by parsing the frozen Inception-v3 model
def load_pretrained_inception_v3(model_dir):
    with gfile.FastGFile(model_dir, 'rb') as f:
        graph_def = tf.GraphDef()                      # construct an empty GraphDef
        graph_def.ParseFromString(f.read())            # parse the serialized graph
        _ = tf.import_graph_def(graph_def, name='')    # import it into the default graph
load_pretrained_inception_v3(model_dir)
batch_size = 2000
# number of feature files, with batch_size images per file
num_batch = int(np.ceil(len(all_imgs) / batch_size))
with tf.Session() as sess:
    # get the bottleneck tensor of the pre-trained model
    second_to_last_tensor = sess.graph.get_tensor_by_name('pool_3/_reshape:0')
    # Tensor("pool_3/_reshape:0", shape=(1, 2048), dtype=float32)
    for i in range(num_batch):
        batch_imgs = all_imgs[i * batch_size:(i + 1) * batch_size]
        batch_labels = all_labels[i * batch_size:(i + 1) * batch_size]
        batch_features = []
        for all_img in batch_imgs:
            img_data = gfile.FastGFile(all_img, 'rb').read()  # read the raw JPEG bytes
            # feed the JPEG bytes to the decode tensor and fetch the 2048-d bottleneck feature
            batch_feature = sess.run(second_to_last_tensor,
                                     feed_dict={'DecodeJpeg/contents:0': img_data})
            batch_features.append(batch_feature)
        print(batch_features)
        print(len(batch_features))
        batch_features = np.vstack(batch_features)
        output_dir = os.path.join(output_folder, 'image_features_%d' % i)
        with gfile.GFile(output_dir, 'wb') as f:  # open the output file in binary mode
            # save (full paths, feature matrix, labels) together
            pickle.dump((batch_imgs, batch_features, batch_labels), f)
        print(output_dir, 'saved')
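As a quick sanity check, not part of the original script, one of the saved files can be reloaded to confirm it holds the (image paths, 2048-d feature matrix, labels) tuple that the training section below expects; image_features_0 is simply the first file produced by the naming pattern above:
with gfile.GFile(os.path.join(output_folder, 'image_features_0'), 'rb') as fr:
    imgs, feats, labs = pickle.load(fr, encoding='bytes')
print(len(imgs), feats.shape, len(labs))  # up to batch_size paths, an (N, 2048) matrix, N labels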
############################################################################
# Train the classifier, save it, and evaluate (reusing the cifar_10 code framework)
data_dir = r'./data/bb/'
print(os.listdir(data_dir))
# load a saved feature file
def load_data(filename):
    with gfile.FastGFile(filename, 'rb') as fr:
        data = pickle.load(fr, encoding='bytes')
    return data
class ImageData:
    def __init__(self, filenames, need_shuffle):
        all_data = []
        all_labels = []
        for filename in filenames:
            data = load_data(filename)
            all_data.append(data[1])
            all_labels.append(data[2])
        self.data = np.vstack(all_data)
        self.labels = np.hstack(all_labels)
        print(self.data.shape)
        print(self.data.shape[0])
        print(self.labels.shape)
        self.indicator = 0
        self.need_shuffle = need_shuffle
        # Split before shuffling so that only the training set is shuffled;
        # otherwise test samples could leak into training and bias the evaluation.
        self.train_data, self.test_data, self.train_labels, self.test_labels = train_test_split(
            self.data, self.labels, test_size=0.2, random_state=1)
        self.num_examples = self.train_data.shape[0]
        if self.need_shuffle:
            self.shuffle_data()

    def shuffle_data(self):
        order = np.random.permutation(self.num_examples)
        self.train_data = self.train_data[order]
        self.train_labels = self.train_labels[order]

    def next_batch(self, batch_size):
        end_indicator = self.indicator + batch_size
        if end_indicator > self.num_examples:
            if self.need_shuffle:
                self.shuffle_data()
                self.indicator = 0
                end_indicator = batch_size
            else:
                raise Exception('no more')
        if end_indicator > self.num_examples:
            raise Exception('larger')
        batch_data = self.train_data[self.indicator:end_indicator]
        batch_labels = self.train_labels[self.indicator:end_indicator]
        self.indicator = end_indicator
        return batch_data, batch_labels
data_filenames = [os.path.join(data_dir, "image_features_{num}".format(num=i)) for i in range(0, 7)]
datas = ImageData(data_filenames, True)
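A quick check of the loader, not in the original: pull one small batch and confirm its shape before building the classifier (this advances the shuffled training pointer by a few samples, which does not matter here):
sample_data, sample_labels = datas.next_batch(4)
print(sample_data.shape, sample_labels)  # expected: (4, 2048) and four integer class ids in [0, 53]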
x = tf.placeholder(tf.float32, [None, 2048])
y = tf.placeholder(tf.int64, [None])
# Inception-v3's 47 convolutional layers are already trained; we only train the newly added fully-connected layers
fc1 = tf.layers.dense(x, 1024, activation=tf.nn.relu)
h = tf.layers.dense(fc1, 54)
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=h, labels=y))
predict = tf.argmax(tf.nn.softmax(h), 1)
is_correct = tf.equal(predict, y)
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(cost)
model_path = './bottleneck1/'  # directory for the trained classifier (only the new fully-connected layers)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    # If no checkpoint directory exists yet, train the classifier and save it
    if not os.path.exists(model_path):
        os.mkdir(model_path)
        for step in range(1, 20001):
            batch_data, batch_labels = datas.next_batch(batch_size=100)
            cost_val, _, acc_val = sess.run([cost, optimizer, accuracy], feed_dict={
                x: batch_data, y: batch_labels
            })
            if step % 1000 == 0:
                print('[Train] step:%d,cost:%4.5f,acc:%4.5f' % (step, cost_val, acc_val))
            if step % 2000 == 0:
                saver.save(sess, model_path, global_step=step)
        print('End of model save')
    # Otherwise restore the latest checkpoint and evaluate on the held-out test split
    else:
        print('Start training model')
        saver.restore(sess, tf.train.latest_checkpoint(model_path))
        test_batch_data, test_batch_labels = datas.test_data, datas.test_labels
        test_acc_val = sess.run(accuracy, feed_dict={x: test_batch_data, y: test_batch_labels})
        print('test_acc_val:', test_acc_val)
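        # Optional sketch, not in the original: per-class accuracy on the same test split,
        # using the `predict` op defined above (assumes the checkpoint was restored successfully).
        test_pred = sess.run(predict, feed_dict={x: test_batch_data})
        for c in np.unique(test_batch_labels):
            mask = test_batch_labels == c
            print('class %d accuracy: %.3f' % (c, np.mean(test_pred[mask] == test_batch_labels[mask])))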
############################################################################
# Predict randomly sampled images
model_dir = "./model/tensorflow_inception_graph.pb"
class_names = ['guizhencao', 'hongliao', 'Wormwood', 'xunma', 'Hairyveinagrimony', 'Gardenia','RadixIsatidis', 'gouweibacao', 'plantains', 'tongquancao', 'qigucao','MonochoriaVaginalis(yashecao)', 'commelina_communis', 'xiaoqieyi', 'shuiqincai',
'Bupleurum(chaihu)', 'dandelions', 'Pinellia(banxia)', 'zhajiangcao', 'Rabdosiaserra',
'feipeng', 'honeysuckles', 'xiaoji', 'selfheals', 'MorningGlory(qianniuhua)', 'malan','tianhukui', 'ziyunying', 'EichhorniaCrassipes(fengyanlan)', 'heshouwu', 'ChenopodiumAlbum',
'Ophiopogon(maidong)', 'huanghuacai', 'mantuoluo', 'kucai', 'sedum_sarmentosum',
'bosipopona', 'boheye', 'juaner', 'lotusseed', 'palms', 'cangerzi',
'Wahlenbergia(lanhuashen)', 'Angelica(baizhi)', 'ginsengs', 'zeqi', 'mangnoliaofficinalis',
'perillas', 'yichuanhong', 'jicai', 'Odoratum(yuzhu)', 'denglongcao', 'Agastacherugosa','Moneygrass']
n_class = 54
def all_dir():
    # randomly sample 25 classes and pick one image (index 010-050) from each sampled class folder
    data_dirs = r'E:\datas\中草药'
    random_list = []
    for i in range(25):
        data = random.randint(0, 53)
        random_list.append(data)
    data_folder = os.listdir(data_dirs)
    test_all_dir = []
    for i in random_list:
        num = random.randint(10, 50)
        num = str(0) + str(num)
        data_dir = os.path.join(data_dirs, data_folder[i], data_folder[i] + '_' + num + '.jpg')
        test_all_dir.append(data_dir)
    return test_all_dir
def load_google_model(model_dir):
    with gfile.FastGFile(model_dir, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
def create_test_feature(test_dir):
    with tf.Session() as sess:
        st = sess.graph.get_tensor_by_name("pool_3/_reshape:0")
        test_data, test_feature, test_labels = [], [], []
        for i in test_dir:
            print(i)
            img = plt.imread(i)
            img = cv2.resize(img, (448, 448))
            # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            test_data.append(img)
            img_data = gfile.FastGFile(i, "rb").read()
            feature = sess.run(st, feed_dict={"DecodeJpeg/contents:0": img_data})
            test_feature.append(feature)
            # take the class name from the file name; note that splitting on '_' truncates
            # names that themselves contain an underscore (e.g. sedum_sarmentosum becomes sedum)
            test_labels.append(i.split("\\")[-1].split("_")[0])
        print(test_labels)
        return test_data, np.reshape(test_feature, (-1, 2048)), np.array(test_labels)
def show_img(test_data, pre_labels, test_labels):
    print(pre_labels)
    _, axs = plt.subplots(5, 5)
    for i, axi in enumerate(axs.flat):
        axi.imshow(test_data[i])
        print('true label:', test_labels[i], '\tpredicted label:', pre_labels[i])
        axi.set_xlabel(xlabel=pre_labels[i], color="black" if pre_labels[i] == test_labels[i] else "red")
        axi.set(xticks=[], yticks=[])
    plt.show()
load_google_model(model_dir)
test_all_dir = all_dir()
test_data, test_feature, test_labels = create_test_feature(test_all_dir)
with tf.Session() as sess:
    # restore the trained fully-connected layers (reuses the saver and model_path defined above)
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    pred = sess.run(tf.argmax(h, 1), {x: test_feature})
show_img(test_data, [class_names[i] for i in pred], test_labels)
Results
Start training model
test_acc_val: 0.7917889
E:\datas\中草药\qigucao\qigucao_039.jpg
E:\datas\中草药\Ophiopogon(maidong)\Ophiopogon(maidong)_023.jpg
E:\datas\中草药\cangerzi\cangerzi_041.jpg
E:\datas\中草药\malan\malan_037.jpg
E:\datas\中草药\guizhencao\guizhencao_028.jpg
E:\datas\中草药\EichhorniaCrassipes(fengyanlan)\EichhorniaCrassipes(fengyanlan)_018.jpg
E:\datas\中草药\ginsengs\ginsengs_037.jpg
E:\datas\中草药\tianhukui\tianhukui_032.jpg
E:\datas\中草药\tianhukui\tianhukui_046.jpg
E:\datas\中草药\sedum_sarmentosum\sedum_sarmentosum_043.jpg
E:\datas\中草药\Pinellia(banxia)\Pinellia(banxia)_018.jpg
E:\datas\中草药\denglongcao\denglongcao_049.jpg
E:\datas\中草药\qigucao\qigucao_012.jpg
E:\datas\中草药\sedum_sarmentosum\sedum_sarmentosum_041.jpg
E:\datas\中草药\honeysuckles\honeysuckles_015.jpg
E:\datas\中草药\Wahlenbergia(lanhuashen)\Wahlenbergia(lanhuashen)_029.jpg
E:\datas\中草药\heshouwu\heshouwu_018.jpg
E:\datas\中草药\ginsengs\ginsengs_017.jpg
E:\datas\中草药\boheye\boheye_030.jpg
E:\datas\中草药\malan\malan_010.jpg
E:\datas\中草药\Wormwood\Wormwood_011.jpg
E:\datas\中草药\RadixIsatidis\RadixIsatidis_050.jpg
E:\datas\中草药\huanghuacai\huanghuacai_039.jpg
E:\datas\中草药\juaner\juaner_038.jpg
E:\datas\中草药\lotusseed\lotusseed_015.jpg
['qigucao', 'Ophiopogon(maidong)', 'cangerzi', 'malan', 'guizhencao', 'EichhorniaCrassipes(fengyanlan)', 'ginsengs', 'tianhukui', 'tianhukui', 'sedum', 'Pinellia(banxia)', 'denglongcao', 'qigucao', 'sedum', 'honeysuckles', 'Wahlenbergia(lanhuashen)', 'heshouwu', 'ginsengs', 'boheye', 'malan', 'Wormwood', 'RadixIsatidis', 'huanghuacai', 'juaner', 'lotusseed']
['qigucao', 'Ophiopogon(maidong)', 'cangerzi', 'malan', 'cangerzi', 'EichhorniaCrassipes(fengyanlan)', 'ginsengs', 'tianhukui', 'tianhukui', 'sedum_sarmentosum', 'Pinellia(banxia)', 'denglongcao', 'qigucao', 'sedum_sarmentosum', 'honeysuckles', 'Wahlenbergia(lanhuashen)', 'heshouwu', 'ginsengs', 'boheye', 'malan', 'shuiqincai', 'shuiqincai', 'huanghuacai', 'juaner', 'lotusseed']
true label: qigucao 	predicted label: qigucao
true label: Ophiopogon(maidong) 	predicted label: Ophiopogon(maidong)
true label: cangerzi 	predicted label: cangerzi
true label: malan 	predicted label: malan
true label: guizhencao 	predicted label: cangerzi
true label: EichhorniaCrassipes(fengyanlan) 	predicted label: EichhorniaCrassipes(fengyanlan)
true label: ginsengs 	predicted label: ginsengs
true label: tianhukui 	predicted label: tianhukui
true label: tianhukui 	predicted label: tianhukui
true label: sedum 	predicted label: sedum_sarmentosum
true label: Pinellia(banxia) 	predicted label: Pinellia(banxia)
true label: denglongcao 	predicted label: denglongcao
true label: qigucao 	predicted label: qigucao
true label: sedum 	predicted label: sedum_sarmentosum
true label: honeysuckles 	predicted label: honeysuckles
true label: Wahlenbergia(lanhuashen) 	predicted label: Wahlenbergia(lanhuashen)
true label: heshouwu 	predicted label: heshouwu
true label: ginsengs 	predicted label: ginsengs
true label: boheye 	predicted label: boheye
true label: malan 	predicted label: malan
true label: Wormwood 	predicted label: shuiqincai
true label: RadixIsatidis 	predicted label: shuiqincai
true label: huanghuacai 	predicted label: huanghuacai
true label: juaner 	predicted label: juaner
true label: lotusseed 	predicted label: lotusseed
Summary
1. The images were crawled from the web and cleaned, but there is still a lot of noise. Chinese herbal medicines appear both as the growing green plant and as the processed medicinal material; we chose the green plant, i.e. what the plant looks like at an early stage, so the dataset is far from perfect.
2. Limited experience with hyperparameter tuning.