这道题是一门课程的期末考试题。受时间和硬件条件的限制,这里只是简单地搭建网络、做预测,并没有对模型精度进行细调。
一、数据增强(Data Augmentation)
#!/bin/python3
import tensorflow as tf
import matplotlib.pyplot as plt
import os, random
def Left_Right(img):
    """Return `img` passed through `tf.image.random_flip_left_right`.

    The flip is random: TensorFlow decides on each evaluation whether the
    image is mirrored horizontally or returned unchanged.
    """
    flipped = tf.image.random_flip_left_right(img)
    return flipped
def Up_Down(img):
    """Return `img` passed through `tf.image.random_flip_up_down`.

    The flip is random: TensorFlow decides on each evaluation whether the
    image is mirrored vertically or returned unchanged.
    """
    flipped = tf.image.random_flip_up_down(img)
    return flipped
def Transpose(img):
    """Return `img` with its height and width axes swapped.

    Thin wrapper around `tf.image.transpose_image`; unlike the flip helpers
    in this script it involves no randomness.
    """
    transposed = tf.image.transpose_image(img)
    return transposed
def Brightness(img, max_delta):
    """Return `img` with a random brightness shift applied.

    Thin wrapper around `tf.image.random_brightness`; `max_delta` bounds the
    magnitude of the randomly chosen shift.
    """
    adjusted = tf.image.random_brightness(img, max_delta)
    return adjusted
def Contrast(img, lb, ub):
    """Return `img` with a random contrast factor applied.

    Thin wrapper around `tf.image.random_contrast`; `lb` and `ub` are passed
    as the lower and upper bounds of the randomly chosen factor.
    """
    adjusted = tf.image.random_contrast(img, lb, ub)
    return adjusted
# (output-file tag, first index, one-past-last index) into the shuffled image
# list of each class. Every augmentation gets its own disjoint slice.
# The original script also carried a disabled contrast pass
# (Contrast(img, 0.1, 0.6) on indices 48-60, tag 'CO'); it stays disabled.
_AUGMENTATION_SLICES = (
    ('LR', 0, 15),
    ('UD', 15, 30),
    ('TR', 30, 45),
    ('BR', 45, 60),
)


def _build_augmentation_ops(raw_png):
    """Build one `encode_png(transform(decode_png(raw_png)))` op per tag.

    The ops are created once and fed through the `raw_png` placeholder, so
    the graph does not grow with the number of processed images.
    """
    decoded = tf.image.decode_png(raw_png, 3)
    transformed = {
        'LR': Left_Right(decoded),
        'UD': Up_Down(decoded),
        'TR': Transpose(decoded),
        'BR': Brightness(decoded, 0.2),
    }
    return {tag: tf.image.encode_png(image) for tag, image in transformed.items()}


def _augment_class(sess, raw_png, encoded_ops, src_dir, dst_dir, imgs):
    """Write the augmented copies of one class directory.

    Slicing `imgs` (instead of indexing it) means a class with fewer than 60
    images is processed as far as it goes rather than raising IndexError.
    """
    for tag, start, stop in _AUGMENTATION_SLICES:
        for j, img_name in enumerate(imgs[start:stop], start):
            print(j)
            # Context managers close both file handles even if a step fails.
            with tf.gfile.FastGFile(src_dir + '/' + img_name, 'rb') as reader:
                png_bytes = reader.read()
            augmented = sess.run(encoded_ops[tag], feed_dict={raw_png: png_bytes})
            with tf.gfile.FastGFile(dst_dir + '/Aug_' + tag + '_' + img_name, 'wb') as writer:
                writer.write(augmented)


if __name__ == '__main__':
    path = '/mnt/windows_E/DeepLearningClass/Data'
    class_ = os.listdir(path)
    # The graph-level seed must be set before the random ops are created.
    # NOTE: random.shuffle below is not seeded, so which images are picked
    # still differs from run to run.
    tf.set_random_seed(217)
    raw_png = tf.placeholder(tf.string)
    encoded_ops = _build_augmentation_ops(raw_png)
    with tf.Session() as sess:
        for class_name in class_:
            print(class_name)
            src_dir = path + '/' + class_name
            dst_dir = './DataAugmentation/' + class_name
            # The output directory is created on demand instead of being a
            # manual precondition of the script.
            os.makedirs(dst_dir, exist_ok=True)
            imgs = os.listdir(src_dir)
            random.shuffle(imgs)
            _augment_class(sess, raw_png, encoded_ops, src_dir, dst_dir, imgs)
二、制作TFRecord
#!/bin/python3
import tensorflow as tf
import os
import random
def _int64_feature(v):
    """Wrap the single integer `v` in a `tf.train.Feature` (int64_list)."""
    int_list = tf.train.Int64List(value=[v])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(v):
    """Wrap the single bytes value `v` in a `tf.train.Feature` (bytes_list)."""
    byte_list = tf.train.BytesList(value=[v])
    return tf.train.Feature(bytes_list=byte_list)
def Write_TFRecords(path, output,
                    exclude_lists=('Valid_Augmentation_120.txt', 'HALF.txt'),
                    manifest='HALF2.txt'):
    """Serialize labelled PNG images from two directory trees into a TFRecord file.

    Args:
        path: Sequence of two root directories, each ending in '/'. Every
            sub-directory of a root is a class. Integer labels are assigned
            from the sorted class names of `path[0]`.
        output: Destination TFRecord file.
        exclude_lists: Text files with one `root#class#file` entry per line;
            images listed there are skipped.
        manifest: Text file that receives one `root#class#file` line per
            written image, in write order.

    Raises:
        KeyError: If a class directory under `path[1]` has no counterpart
            under `path[0]` (it would have no label).

    A default TensorFlow session must be active, because every image is
    decoded once so that a corrupt PNG fails here rather than at training time.
    """
    class0 = sorted(os.listdir(path[0]))
    class1 = sorted(os.listdir(path[1]))
    label_ = {class_name: index for index, class_name in enumerate(class0)}

    # Entries are kept as (root, class, file) tuples, so names never have to
    # be split back out of the '#'-joined manifest string.
    photos = []
    for root, classes in ((path[0], class0), (path[1], class1)):
        for class_name in classes:
            for pic in os.listdir(root + class_name + '/'):
                photos.append((root, class_name, pic))

    excluded = set()
    for list_file in exclude_lists:
        with open(list_file) as f:
            excluded.update(f.read().splitlines())
    photos = [entry for entry in photos if '#'.join(entry) not in excluded]

    # Shuffle after filtering: the original shuffled first and then rebuilt
    # the list from a set, which threw the shuffled order away.
    random.shuffle(photos)

    # One decode op fed through a placeholder; creating a new decode op per
    # image would grow the graph by one constant per file.
    raw_png = tf.placeholder(tf.string)
    decoded = tf.image.decode_png(raw_png)

    with open(manifest, 'w') as fw, tf.python_io.TFRecordWriter(output) as writer:
        for i, (p, label, pic) in enumerate(photos):
            fw.write('#'.join((p, label, pic)) + '\n')
            print(i, p, label, pic)
            with tf.gfile.FastGFile(p + label + '/' + pic, 'rb') as image:
                img_str = image.read()
            # Result is discarded; the call only validates the PNG.
            decoded.eval(feed_dict={raw_png: img_str})
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': _bytes_feature(img_str),
                'label': _int64_feature(label_[label]),
                'name': _bytes_feature(bytes(pic, encoding='utf8')),
            }))
            writer.write(example.SerializeToString())
if __name__ == '__main__':
    # Two roots are combined: the original data set and the augmented copies.
    data_roots = ['/mnt/windows_E/DeepLearningClass/Data/', './DataAugmentation/']
    record_file = './Weed_InputData_Training_Augmentation.tfrecords02'
    # Write_TFRecords evaluates tensors, so it runs inside a default session.
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        Write_TFRecords(data_roots, record_file)
三、读取TFRecord、数据预处理
#!/bin/python3
import tensorflow as tf
import matplotlib.pyplot as plt
def Read_Test_TFRecords(files, resize):
    """Build the queue-based input ops for unlabeled test records.

    Args:
        files: File pattern passed to `tf.train.match_filenames_once`.
        resize: Side length; images are resized to `resize` x `resize`.

    Returns:
        (image, name): the float32 image tensor after resizing and the
        'img_name' string feature of the same record. Files are read in
        order (no filename shuffling).
    """
    matched = tf.train.match_filenames_once(files)
    queue = tf.train.string_input_producer(matched, shuffle=False)
    _, record = tf.TFRecordReader().read(queue)
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'img_name': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(record, features=schema)
    # Decode as 3-channel RGB, convert to float32, then resize to a square.
    pixels = tf.image.decode_png(parsed['image'], 3)
    pixels = tf.image.convert_image_dtype(pixels, dtype=tf.float32)
    pixels = tf.image.resize_images(pixels, [resize, resize])
    return pixels, parsed['img_name']
def Read_TFRecords(files):
    """Build the queue-based input ops for labelled records.

    Args:
        files: File pattern passed to `tf.train.match_filenames_once`.

    Returns:
        (image, label): the decoded 3-channel image tensor (not yet
        converted or resized) and the record's label cast to int32.
        The filename queue is shuffled.
    """
    matched = tf.train.match_filenames_once(files)
    queue = tf.train.string_input_producer(matched, shuffle=True)
    _, record = tf.TFRecordReader().read(queue)
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=schema)
    # History: an earlier version also stored 'height', 'width' and 'channels'
    # features (each FixedLenFeature([1], tf.int64)) and tried to reshape with
    # them, but that kept failing. decode_png is therefore called in RGB mode
    # (the literal 3) instead.
    pixels = tf.image.decode_png(parsed['image'], 3)
    class_id = tf.cast(parsed['label'], tf.int32)
    return pixels, class_id
def Preprocess(image, label, resize, num_classes=12):
    """Convert an image to a float32 square and one-hot encode its label.

    Args:
        image: Decoded image tensor.
        label: Integer class-id tensor.
        resize: Side length; the image is resized to `resize` x `resize`.
        num_classes: Depth of the one-hot vector. Defaults to 12, the value
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        (image, label): the resized float32 image and the float32 one-hot label.
    """
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    image = tf.image.resize_images(image, [resize, resize])
    label = tf.one_hot(label, num_classes, 1., 0., dtype=tf.float32)
    return image, label
if __name__ == '__main__':
    # Smoke test: read twelve batches of the validation records and print
    # their one-hot labels.
    BATCH = 10
    MIN_AFTER_DEQUEUE = 1000
    # CAPACITY and THREADS are only needed by the shuffle_batch alternative
    # mentioned below.
    CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH
    THREADS = 2

    raw_image, raw_label = Read_TFRecords('Weed_InputData_Valid*')
    image_p, label_p = Preprocess(raw_image, raw_label, 256)
    # Alternative: tf.train.shuffle_batch([image_p, label_p], batch_size=BATCH,
    # capacity=CAPACITY, min_after_dequeue=MIN_AFTER_DEQUEUE, num_threads=THREADS)
    image_batch, label_batch = tf.train.batch([image_p, label_p], batch_size=BATCH)

    with tf.Session() as sess:
        init_ops = [tf.global_variables_initializer(), tf.local_variables_initializer()]
        sess.run(init_ops)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for _ in range(12):
            # To inspect images instead, show the first one of a fetched batch
            # with plt.imshow(images[0]); plt.show().
            images, labels = sess.run([image_batch, label_batch])
            print(labels)
        coord.request_stop()
        coord.join(threads)
四、模型训练
#!/bin/python3
import time
import tensorflow as tf
from Read_Data import *
def Conv_layer(in_, name, kh, kw, n_out, dh, dw, padding):
    """Create a 2-D convolution followed by bias add and ReLU.

    Args:
        in_: Input tensor; its last dimension is used as the input channel count.
        name: Layer name; variables are named 'w_<name><n_out>' and 'b_<name><n_out>'.
        kh, kw: Kernel height and width.
        n_out: Number of output channels.
        dh, dw: Strides along height and width.
        padding: Padding mode passed to `tf.nn.conv2d`.

    Returns:
        The activated output tensor.
    """
    channels_in = in_.get_shape()[-1].value
    suffix = name + str(n_out)
    kernel = tf.get_variable(
        'w_' + suffix,
        shape=[kh, kw, channels_in, n_out],
        dtype=tf.float32,
        initializer=tf.contrib.layers.xavier_initializer_conv2d())
    bias = tf.get_variable(
        'b_' + suffix,
        shape=[n_out],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0))
    convolved = tf.nn.conv2d(input=in_, filter=kernel, strides=[1, dh, dw, 1], padding=padding)
    pre_activation = tf.nn.bias_add(convolved, bias)
    return tf.nn.relu(pre_activation)
def Maxpooling_layer(in_, kh, kw, dh, dw, padding):
    """Apply max pooling with a `kh` x `kw` window and strides `dh`, `dw`."""
    window = [1, kh, kw, 1]
    strides = [1, dh, dw, 1]
    return tf.nn.max_pool(in_, ksize=window, strides=strides, padding=padding)
def Fullc_layer(in_, name, n_out):
    """Create a fully connected layer with ReLU activation.

    Args:
        in_: Input tensor; its last dimension is used as the input width.
        name: Layer name; variables are named 'w_<name><n_out>' and 'b_<name><n_out>'.
        n_out: Number of output units.

    Returns:
        relu(in_ @ w + b). The ReLU is always applied, including when this
        layer is used as the last layer of a network.
    """
    width_in = in_.get_shape()[-1].value
    suffix = name + str(n_out)
    weights = tf.get_variable(
        'w_' + suffix,
        shape=[width_in, n_out],
        dtype=tf.float32,
        initializer=tf.contrib.layers.xavier_initializer())
    bias = tf.get_variable(
        'b_' + suffix,
        shape=[n_out],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.1))
    activated = tf.nn.relu_layer(in_, weights, bias)
    return activated
def Model(x_, keep_prob, num_classes=12, relu_logits=True):
    """Build the CNN: two conv/pool stages, two hidden FC layers, dropout, output layer.

    Args:
        x_: Input image batch with a statically known height, width and
            channel count (needed to size the first fully connected layer).
        keep_prob: Dropout keep probability applied after the second FC layer.
        num_classes: Width of the output layer. Defaults to 12, the value
            that was previously hard-coded.
        relu_logits: When True (default, the original behavior) the output
            layer goes through `Fullc_layer` and is therefore ReLU-activated,
            which clamps negative logits to zero before the softmax loss.
            Pass False for linear logits when training a new model. The
            default is kept so checkpoints trained with the ReLU behave
            exactly as before.

    Returns:
        The logits tensor with `num_classes` columns.
    """
    conv1_1 = Conv_layer(x_, 'conv1_1', 5, 5, 32, 3, 3, 'SAME')
    pool1 = Maxpooling_layer(conv1_1, 3, 3, 2, 2, 'SAME')
    conv2_1 = Conv_layer(pool1, 'conv2_1', 5, 5, 64, 3, 3, 'SAME')
    pool2 = Maxpooling_layer(conv2_1, 3, 3, 2, 2, 'SAME')

    # Flatten height x width x channels into one feature vector per example.
    _, height, width, channels = pool2.get_shape().as_list()
    reshaped = tf.reshape(pool2, [-1, height * width * channels])

    fc1 = Fullc_layer(reshaped, 'fc1', 256)
    fc2 = Fullc_layer(fc1, 'fc2', 256)
    fc2_drop = tf.nn.dropout(fc2, keep_prob)

    if relu_logits:
        return Fullc_layer(fc2_drop, 'fc3', num_classes)

    # Linear output layer. Variable names, shapes and initializers mirror
    # what Fullc_layer creates for 'fc3', so checkpoints stay interchangeable
    # between the two modes.
    n_in = fc2_drop.get_shape()[-1].value
    w = tf.get_variable(
        'w_fc3' + str(num_classes),
        shape=[n_in, num_classes],
        dtype=tf.float32,
        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable(
        'b_fc3' + str(num_classes),
        shape=[num_classes],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.1))
    return tf.nn.xw_plus_b(fc2_drop, w, b)
if __name__ == '__main__':
    IMAGE_SIZE = 256
    IMAGE_CHANNELS = 3
    BATCH_SIZE = 32
    MIN_AFTER_DEQUEUE = 500
    CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
    NUM_THREADS = 4
    CHECKPOINT_PATH = './Model_Saver04_TrainingDataAugmentation/model_save.ckpt'
    # A checkpoint is written when the count of perfect-accuracy evaluations
    # reaches one of these values; training stops once it exceeds the last one.
    SAVE_AT_PERFECT_COUNTS = (10, 12, 13, 15, 18)
    STOP_AFTER_PERFECT_COUNT = 18

    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    # Other record sets used in earlier runs:
    #   'Weed_InputData_Training.tfrecords*', 'Weed_InputData_Augmentation*'
    image, label = Read_TFRecords('Weed_InputData_Training_Augmentation*')
    image_p, label_p = Preprocess(image, label, IMAGE_SIZE)
    image_batch, label_batch = tf.train.shuffle_batch(
        [image_p, label_p],
        batch_size=BATCH_SIZE,
        capacity=CAPACITY,
        min_after_dequeue=MIN_AFTER_DEQUEUE,
        num_threads=NUM_THREADS)

    logits = Model(image_batch, keep_prob)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_batch, logits=logits)
    loss = tf.reduce_mean(cross_entropy)
    train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label_batch, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    with tf.Session() as sess:
        step = 0
        c_acc = 0  # number of evaluations that reached accuracy 1.0
        start = time.time()
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # The loop has no step limit: it ends only via the break below or
            # by interrupting the process, hence the try/finally.
            while True:
                sess.run(train_step, feed_dict={keep_prob: 0.5})
                if step % 10 == 0:
                    # Evaluated on a fresh training batch with dropout disabled.
                    batch_loss, batch_accuracy = sess.run([loss, accuracy], feed_dict={keep_prob: 1.0})
                    print('Step {}, loss = {:.6f}, training accuracy = {:.6f}, total time = {:.3f}'.format(step, batch_loss, batch_accuracy, time.time() - start))
                    # NOTE(review): the original indentation was lost; the
                    # save/stop chain is assumed to run only when the counter
                    # has just been incremented — confirm against the source.
                    if batch_accuracy == 1.0:
                        c_acc += 1
                        if c_acc in SAVE_AT_PERFECT_COUNTS:
                            saver.save(sess, CHECKPOINT_PATH, global_step=c_acc)
                        elif c_acc > STOP_AFTER_PERFECT_COUNT:
                            break
                step += 1
        finally:
            # Stop the input threads even when training is interrupted.
            coord.request_stop()
            coord.join(threads)
五、模型验证(自行划分的120张图片作为验证集)
#!/bin/python3
import tensorflow as tf
import train_nn_TrainingAugmentation
import train_nn_Training
import train_nn
import Read_Data
IMAGE_SIZE = 256
BATCH_SIZE = 20
# 6 batches x 20 images covers the 120-image validation split once.
NUM_BATCHES = 6
# Best model so far. Other checkpoints evaluated earlier:
#   ./Model_Saver05_Final_Augmentation/model_save.ckpt-15
#   ./Model_Saver04_TrainingDataAugmentation/model_save.ckpt-18
#   ./Model_Saver03_DataAugmentation/model_save.ckpt
#   ./Model_Saver02_Training/model_save.ckpt-15
#   ./Model_Saver01/model_save.ckpt
CHECKPOINT = './Model_Saver06_Final_Augmentation/model_save.ckpt-13'

# Alternative record set: 'Weed_InputData_Valid_120.tfre*'
img_valid, img_label = Read_Data.Read_TFRecords('Weed_InputData_Valid_Augmentation*')
img_p, label_p = Read_Data.Preprocess(img_valid, img_label, IMAGE_SIZE)
img_valid_batch, img_label_batch = tf.train.batch([img_p, label_p], batch_size=BATCH_SIZE)

keep_prob = tf.placeholder(tf.float32)
# Alternative model definitions: train_nn.Model, train_nn_Training.Model
logits = train_nn_TrainingAugmentation.Model(img_valid_batch, keep_prob)
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(img_label_batch, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver = tf.train.Saver()

with tf.Session() as sess:
    # The local initializer is still required for the filename matching in
    # the input pipeline; the model variables are overwritten by the restore.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # The model is already built in code above, so the checkpoint values are
    # restored straight into it. Calling import_meta_graph here as well would
    # load a second copy of the saved graph into the same default graph.
    saver.restore(sess, CHECKPOINT)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        A = 0.0
        for _ in range(NUM_BATCHES):
            ans = sess.run(accuracy, feed_dict={keep_prob: 1.0})
            print(ans)
            A += ans
        print('Mean: ', A / NUM_BATCHES)
    finally:
        # Stop the input threads even if evaluation fails.
        coord.request_stop()
        coord.join(threads)
六、模型预测
#!/bin/python3
import tensorflow as tf
import os
def _int64_feature(v):
    """Wrap the single integer `v` in a `tf.train.Feature` (int64_list)."""
    int_list = tf.train.Int64List(value=[v])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(v):
    """Wrap the single bytes value `v` in a `tf.train.Feature` (bytes_list)."""
    byte_list = tf.train.BytesList(value=[v])
    return tf.train.Feature(bytes_list=byte_list)
def Write_TFRecords(path, output):
    """Serialize every file in `path` into a TFRecord file of unlabeled test images.

    Args:
        path: Directory containing the PNG files; must end in '/'.
        output: Destination TFRecord file.

    Each record stores the raw PNG bytes under 'image' and the file name
    (UTF-8 bytes) under 'img_name'. Files are written in sorted name order.

    A default TensorFlow session must be active, because every image is
    decoded once so that a corrupt PNG fails here rather than at prediction time.
    """
    img = sorted(os.listdir(path))

    # One decode op fed through a placeholder; creating a new decode op per
    # image would grow the graph by one constant per file.
    raw_png = tf.placeholder(tf.string)
    decoded = tf.image.decode_png(raw_png)

    # Context managers close the writer and each image file even on error.
    with tf.python_io.TFRecordWriter(output) as writer:
        for i, img_name in enumerate(img):
            print(i, img_name)
            with tf.gfile.FastGFile(path + img_name, 'rb') as image:
                img_str = image.read()
            # Result is discarded; the call only validates the PNG.
            decoded.eval(feed_dict={raw_png: img_str})
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': _bytes_feature(img_str),
                'img_name': _bytes_feature(bytes(img_name, encoding='utf8')),
            }))
            writer.write(example.SerializeToString())
if __name__ == '__main__':
    test_dir = '/mnt/windows_E/DeepLearningClass/TestData/'
    record_file = './Weed_InputData_Final_Test.tfrecords'
    # Write_TFRecords evaluates tensors, so it runs inside a default session.
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        Write_TFRecords(test_dir, record_file)
#!/bin/python3
import tensorflow as tf
import train_nn
import train_nn_TrainingAugmentation
import Read_Data
IMAGE_SIZE = 256
BATCH_SIZE = 6
# The input queue repeats forever, so the number of batches must be fixed.
# NOTE(review): presumably 79 x 6 = 474 test images — confirm against the data.
NUM_BATCHES = 79
# Checkpoint / output pairs used so far:
#   results04: ./Model_Saver03_DataAugmentation/model_save.ckpt
#   results03: ./Model_Saver01/model_save.ckpt
#   results02: ./Model_Saver05_Final_Augmentation/model_save.ckpt-15
#   results01: ./Model_Saver06_Final_Augmentation/model_save.ckpt-13
CHECKPOINT = './Model_Saver03_DataAugmentation/model_save.ckpt'
RESULTS_CSV = 'TestData_results04.csv'

# classes.txt holds one '<id>#<species>' entry per line.
label = {}
with open('classes.txt') as f:
    for line in f.read().splitlines():
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        a, b = line.split('#')
        label[int(a)] = b
print(label)

img_test, img_name = Read_Data.Read_Test_TFRecords('Weed_InputData_Final_Test*', IMAGE_SIZE)
img_test_batch, img_name_batch = tf.train.batch([img_test, img_name], batch_size=BATCH_SIZE)
keep_prob = tf.placeholder(tf.float32)
# Alternative model definition: train_nn_TrainingAugmentation.Model
logits = train_nn.Model(img_test_batch, keep_prob)
pred = tf.argmax(tf.nn.softmax(logits), 1)
saver = tf.train.Saver()

with tf.Session() as sess:
    # The local initializer is still required for the filename matching in
    # the input pipeline; the model variables are overwritten by the restore.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # The model is already built in code above, so the checkpoint values are
    # restored straight into it. Calling import_meta_graph here as well would
    # load a second copy of the saved graph into the same default graph.
    saver.restore(sess, CHECKPOINT)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        with open(RESULTS_CSV, 'w') as fw:
            fw.write(',species\n')
            for _ in range(NUM_BATCHES):
                print('----------------------------------------')
                # Every run()/eval() call dequeues a new batch, so predictions
                # and file names must be fetched in the SAME call to stay
                # aligned. The original never fetched the names and left the
                # first CSV column empty.
                ans, names = sess.run([pred, img_name_batch], feed_dict={keep_prob: 1.0})
                for name, class_id in zip(names, ans):
                    fw.write(name.decode('utf8') + ',' + label[class_id] + '\n')
    finally:
        # Stop the input threads even if prediction fails.
        coord.request_stop()
        coord.join(threads)