Plant Seedling Image Classification with TensorFlow (Kaggle: Plant Seedlings Classification)

This problem was given as the final exam for a course. Because of time and hardware constraints, I only built a simple network and ran predictions; I did not fine-tune the model for accuracy.


1. Data Augmentation

#!/bin/python3
import tensorflow as tf
import matplotlib.pyplot as plt
import os, random

def Left_Right(img):
    return tf.image.random_flip_left_right(img)

def Up_Down(img):
    return tf.image.random_flip_up_down(img)

def Transpose(img):
    return tf.image.transpose_image(img)

def Brightness(img, max_delta):
    return tf.image.random_brightness(img, max_delta)

def Contrast(img, lb, ub):
    return tf.image.random_contrast(img, lb, ub)

if __name__ == '__main__':
    path = '/mnt/windows_E/DeepLearningClass/Data'
    class_ = os.listdir(path)
    tf.set_random_seed(217)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(len(class_)):
            print(class_[i])
            imgs = os.listdir(path + '/' + class_[i])
            random.shuffle(imgs)
            # make sure the per-class output directory exists before writing augmented copies
            os.makedirs('./DataAugmentation/' + class_[i], exist_ok = True)
            for j in range(0, 15):
                print(j)
                img = tf.gfile.FastGFile(path + '/' + class_[i] + '/' + imgs[j], 'rb').read()
                img = tf.image.decode_png(img, 3)
                img = Left_Right(img)
                img = tf.image.encode_png(img)
                with tf.gfile.FastGFile('./DataAugmentation/' + class_[i] + '/Aug_LR_' + imgs[j], 'wb') as writer:
                    writer.write(img.eval())
            for j in range(15, 30):
                print(j)
                img = tf.gfile.FastGFile(path + '/' + class_[i] + '/' + imgs[j], 'rb').read()
                img = tf.image.decode_png(img, 3)
                img = Up_Down(img)
                img = tf.image.encode_png(img)
                with tf.gfile.FastGFile('./DataAugmentation/' + class_[i] + '/Aug_UD_' + imgs[j], 'wb') as writer:
                    writer.write(img.eval())
            for j in range(30, 45):
                print(j)
                img = tf.gfile.FastGFile(path + '/' + class_[i] + '/' + imgs[j], 'rb').read()
                img = tf.image.decode_png(img, 3)
                img = Transpose(img)
                img = tf.image.encode_png(img)
                with tf.gfile.FastGFile('./DataAugmentation/' + class_[i] + '/Aug_TR_' + imgs[j], 'wb') as writer:
                    writer.write(img.eval())
            for j in range(45, 60):
                print(j)
                img = tf.gfile.FastGFile(path + '/' + class_[i] + '/' + imgs[j], 'rb').read()
                img = tf.image.decode_png(img, 3)
                img = Brightness(img, 0.2)
                img = tf.image.encode_png(img)
                with tf.gfile.FastGFile('./DataAugmentation/' + class_[i] + '/Aug_BR_' + imgs[j], 'wb') as writer:
                    writer.write(img.eval())
            '''
            for j in range(48, 60):
                print(j)
                img = tf.gfile.FastGFile(path + '/' + class_[i] + '/' + imgs[j], 'rb').read()
                img = tf.image.decode_png(img, 3)
                img = Contrast(img, 0.1, 0.6)
                img = tf.image.encode_png(img)
                with tf.gfile.FastGFile('./DataAugmentation/' + class_[i] + '/Aug_CO_' + imgs[j], 'wb') as writer:
                    writer.write(img.eval())
            '''
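
The script above imports matplotlib but never uses it; before running a full augmentation pass, a quick visual check of one transform can catch obviously wrong output. A minimal sketch, assuming the same data root as above and picking an arbitrary class and image:

#!/bin/python3
# Hedged sketch: preview one original image next to its left-right flip.
import tensorflow as tf
import matplotlib.pyplot as plt
import os

path = '/mnt/windows_E/DeepLearningClass/Data'        # assumption: same data root as above
class_name = os.listdir(path)[0]                      # arbitrary class
img_file = os.listdir(path + '/' + class_name)[0]     # arbitrary image

with tf.Session() as sess:
    raw = tf.gfile.FastGFile(path + '/' + class_name + '/' + img_file, 'rb').read()
    img = tf.image.decode_png(raw, 3)
    flipped = tf.image.flip_left_right(img)           # deterministic flip, easier to eyeball
    orig_np, flip_np = sess.run([img, flipped])

plt.subplot(1, 2, 1); plt.imshow(orig_np); plt.title('original')
plt.subplot(1, 2, 2); plt.imshow(flip_np); plt.title('flipped')
plt.show()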

2. Building the TFRecords

#!/bin/python3
import tensorflow as tf
import os
import random

def _int64_feature(v):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [v]))

def _bytes_feature(v):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [v]))

def Write_TFRecords(path, output):
    class0 = os.listdir(path[0])
    class1 = os.listdir(path[1])
    class0.sort()
    class1.sort()
    photos = []
    label_ = {}
    for i in range(len(class0)):
        label_[class0[i]] = i
        pics = os.listdir(path[0] + class0[i] + '/')
        for j in range(len(pics)):
            photos.append(path[0] + '#' + class0[i] + '#' + pics[j])
    for i in range(len(class1)):
        pics = os.listdir(path[1] + class1[i] + '/')
        for j in range(len(pics)):
            photos.append(path[1] + '#' + class1[i] + '#' + pics[j])
    random.shuffle(photos)
    with open('Valid_Augmentation_120.txt') as f:
        con = f.read()
    valid = con.splitlines()
    photos = list(set(photos) - set(valid))
    with open('HALF.txt') as f:
        con = f.read()
    h = con.splitlines()
    photos = list(set(photos) - set(h))
    #half = len(photos) // 2
    fw = open('HALF2.txt', 'w')
    writer = tf.python_io.TFRecordWriter(output)
    #for i in range(half):
    for i in range(len(photos)):
        fw.write(photos[i] + '\n')
        p, label, pic = photos[i].split('#')
        print(i, p, label, pic)
        pic_ = p + label + '/' + pic
        image = tf.gfile.FastGFile(pic_, 'rb')
        img_str = image.read()
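        # the decode below only verifies that the PNG parses; height/width/channels are not written to the record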
        height, width, channels = tf.image.decode_png(img_str).eval().shape
        example = tf.train.Example(features = tf.train.Features(feature = \
            {'image': _bytes_feature(img_str), 'label': _int64_feature(label_[label]), \
            'name': _bytes_feature(bytes(pic, encoding = 'utf8'))}))
        writer.write(example.SerializeToString())
    writer.close()
    fw.close()

if __name__ == '__main__':
    path = ['/mnt/windows_E/DeepLearningClass/Data/', './DataAugmentation/']
    output = './Weed_InputData_Training_Augmentation.tfrecords02'
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        Write_TFRecords(path, output)
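
Once the file is written, it is easy to sanity-check it by iterating over the serialized records without building a graph. A minimal sketch, assuming the output file name used above:

#!/bin/python3
# Hedged sketch: count the records in the TFRecord written above and decode the first one.
import tensorflow as tf

record_file = './Weed_InputData_Training_Augmentation.tfrecords02'   # assumption: output path from above
count = 0
first = None
for record in tf.python_io.tf_record_iterator(record_file):
    if first is None:
        first = record
    count += 1
print('records:', count)

example = tf.train.Example()
example.ParseFromString(first)
print('label:', example.features.feature['label'].int64_list.value[0])
print('name :', example.features.feature['name'].bytes_list.value[0].decode('utf8'))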

3. Reading the TFRecords and preprocessing the data

#!/bin/python3
import tensorflow as tf
import matplotlib.pyplot as plt

def Read_Test_TFRecords(files, resize):
    files = tf.train.match_filenames_once(files)
    filename_queue = tf.train.string_input_producer(files, shuffle = False)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features = \
        {'image': tf.FixedLenFeature([], tf.string), 'img_name': tf.FixedLenFeature([], tf.string)})
    image = tf.image.decode_png(features['image'], 3)
    image = tf.image.convert_image_dtype(image, dtype = tf.float32)
    image = tf.image.resize_images(image, [resize, resize])
    name = features['img_name']

    return image, name

def Read_TFRecords(files):
    files = tf.train.match_filenames_once(files)
    filename_queue = tf.train.string_input_producer(files, shuffle = True)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features = \
        {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)})
    '''
    features = tf.parse_single_example(serialized_example, features = \
        {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([1], tf.int64), \
        'height': tf.FixedLenFeature([1], tf.int64), 'width': tf.FixedLenFeature([1], tf.int64), \
        'channels': tf.FixedLenFeature([1], tf.int64)})
    decode_png is called in RGB mode (the number 3).
    Originally I wanted to read height/width/channels from the TFRecord and reshape with them,
    but that kept raising errors.
    '''
    image = tf.image.decode_png(features['image'], 3)
    label = tf.cast(features['label'], tf.int32)

    return image, label

def Preprocess(image, label, resize):
    image = tf.image.convert_image_dtype(image, dtype = tf.float32)
    image = tf.image.resize_images(image, [resize, resize])
    label = tf.one_hot(label, 12, 1., 0., dtype = tf.float32)

    return image, label

if __name__ == '__main__':
    batch_size = 10
    min_after_dequeue = 1000
    capacity = min_after_dequeue + 3 * batch_size
    num_threads = 2
    image, label = Read_TFRecords('Weed_InputData_Valid*')
    image_p, label_p = Preprocess(image, label, 256)
    image_batch, label_batch = tf.train.batch([image_p, label_p], batch_size = batch_size)
    #image_batch, label_batch = tf.train.shuffle_batch([image_p, label_p], batch_size = batch_size, capacity = capacity, min_after_dequeue = min_after_dequeue, num_threads = num_threads)
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess = sess, coord = coord)
        for i in range(12):
            '''
            a, b = sess.run([image_batch, label_batch])
            plt.imshow(a[0])
            plt.show()
            '''
            a, b = sess.run([image_batch, label_batch])
            print(b)
        coord.request_stop()
        coord.join(threads)
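
The queue-runner pipeline above works, but the same records can also be read with the tf.data API available in later TensorFlow 1.x releases. A minimal sketch, assuming the same feature names, image size, and class count as above:

#!/bin/python3
# Hedged sketch: read the same TFRecords with tf.data instead of queue runners.
import tensorflow as tf

def _parse(serialized):
    features = tf.parse_single_example(serialized, features = \
        {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)})
    image = tf.image.decode_png(features['image'], 3)
    image = tf.image.convert_image_dtype(image, dtype = tf.float32)
    image = tf.image.resize_images(image, [256, 256])
    label = tf.one_hot(tf.cast(features['label'], tf.int32), 12, 1., 0., dtype = tf.float32)
    return image, label

files = tf.gfile.Glob('Weed_InputData_Valid*')        # same file pattern as above
dataset = tf.data.TFRecordDataset(files).map(_parse).shuffle(1000).batch(10)
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()

with tf.Session() as sess:
    a, b = sess.run([image_batch, label_batch])
    print(a.shape, b.shape)   # expected: (10, 256, 256, 3) (10, 12)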

4. Model training

#!/bin/python3
import time
import tensorflow as tf
from Read_Data import *

def Conv_layer(in_, name, kh, kw, n_out, dh, dw, padding):
    n_in = in_.get_shape()[-1].value
    w = tf.get_variable('w_' + name + str(n_out), shape = [kh, kw, n_in, n_out], dtype = tf.float32, initializer = tf.contrib.layers.xavier_initializer_conv2d())
    b = tf.get_variable('b_' + name + str(n_out), shape = [n_out], dtype = tf.float32, initializer = tf.constant_initializer(0.0))
    conv = tf.nn.conv2d(input = in_, filter = w, strides = [1, dh, dw, 1], padding = padding)
    out_ = tf.nn.relu(tf.nn.bias_add(conv, b))
    return out_

def Maxpooling_layer(in_, kh, kw, dh, dw, padding):
    return tf.nn.max_pool(in_, ksize = [1, kh, kw, 1], strides = [1, dh, dw, 1], padding = padding)

def Fullc_layer(in_, name, n_out):
    n_in = in_.get_shape()[-1].value
    w = tf.get_variable('w_' + name + str(n_out), shape=[n_in, n_out], dtype = tf.float32, initializer = tf.contrib.layers.xavier_initializer())
    b = tf.get_variable('b_' + name + str(n_out), shape = [n_out], dtype = tf.float32, initializer = tf.constant_initializer(0.1))
    return tf.nn.relu_layer(in_, w, b)

def Model(x_, keep_prob):

    conv1_1 = Conv_layer(x_, 'conv1_1', 5, 5, 32, 3, 3, 'SAME')
    pool1 = Maxpooling_layer(conv1_1, 3, 3, 2, 2, 'SAME')

    conv2_1 = Conv_layer(pool1, 'conv2_1', 5, 5, 64, 3, 3, 'SAME')
    pool2 = Maxpooling_layer(conv2_1, 3, 3, 2, 2, 'SAME')

    pool_shape = pool2.get_shape().as_list()
    nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
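    # With 256x256 inputs and 'SAME' padding: conv stride 3 -> 86x86, pool stride 2 -> 43x43,
    # conv stride 3 -> 15x15, pool stride 2 -> 8x8, so nodes = 8 * 8 * 64 = 4096.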
    reshaped = tf.reshape(pool2, [-1, nodes])

    fc1 = Fullc_layer(reshaped, 'fc1', 256)
    fc2 = Fullc_layer(fc1, 'fc2', 256)
    fc2_drop = tf.nn.dropout(fc2, keep_prob)
    fc3 = Fullc_layer(fc2_drop, 'fc3', 12)

    logits = fc3
    return logits

if __name__ == '__main__':
    IMAGE_SIZE = 256
    IMAGE_CHANNELS = 3
    BATCH_SIZE = 32
    MIN_AFTER_DEQUEUE = 500
    CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
    NUM_THREADS = 4

    keep_prob = tf.placeholder(tf.float32, name = "keep_prob")

    #image, label = Read_TFRecords('Weed_InputData_Training.tfrecords*')
    image, label = Read_TFRecords('Weed_InputData_Training_Augmentation*')
    #image, label = Read_TFRecords('Weed_InputData_Augmentation*')
    image_p, label_p = Preprocess(image, label, IMAGE_SIZE)
    image_batch, label_batch = tf.train.shuffle_batch([image_p, label_p], batch_size = BATCH_SIZE, capacity = CAPACITY, min_after_dequeue = MIN_AFTER_DEQUEUE, num_threads = NUM_THREADS)

    logits = Model(image_batch, keep_prob)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels = label_batch, logits = logits)
    loss = tf.reduce_mean(cross_entropy)
    train_step = tf.train.AdamOptimizer(1e-3).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label_batch, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        step = 0
        c_acc = 0
        start = time.time()
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess = sess, coord = coord)
        while True:
            sess.run(train_step, feed_dict = {keep_prob: 0.5})
            if step % 10 == 0:
                valid_loss, valid_accuracy = sess.run([loss, accuracy], feed_dict = {keep_prob: 1.0})
                print('Step {}, loss = {:.6f}, training accuracy = {:.6f}, total time = {:.3f}'.format(step, valid_loss, valid_accuracy, time.time() - start))
                if valid_accuracy == 1.0:
                    c_acc += 1
                    # save a checkpoint the 10th, 12th, 13th, 15th and 18th time a checked batch hits accuracy 1.0, then stop
                    if c_acc in (10, 12, 13, 15, 18):
                        saver.save(sess, './Model_Saver04_TrainingDataAugmentation/model_save.ckpt', global_step = c_acc)
                    elif c_acc > 18:
                        break
            step += 1
        coord.request_stop()
        coord.join(threads)
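
As a quick sanity check (not part of the original script), the number of trainable parameters can be printed once the graph above is built; with 256x256 inputs, the first fully connected layer (4096 x 256 weights) dominates the total. A minimal sketch, meant to run right after the graph is constructed:

# Hedged sketch: count trainable parameters of the graph defined above.
import numpy as np

total = 0
for v in tf.trainable_variables():
    print(v.name, v.get_shape().as_list())
    total += int(np.prod(v.get_shape().as_list()))
print('trainable parameters:', total)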

5. Model validation (120 images I held out myself as a validation set)

#!/bin/python3
import tensorflow as tf
import train_nn_TrainingAugmentation
import train_nn_Training
import train_nn
import Read_Data

IMAGE_SIZE = 256
BATCH_SIZE = 20
#img_valid, img_label = Read_Data.Read_TFRecords('Weed_InputData_Valid_120.tfre*')
img_valid, img_label = Read_Data.Read_TFRecords('Weed_InputData_Valid_Augmentation*')
img_p, label_p = Read_Data.Preprocess(img_valid, img_label, IMAGE_SIZE)
img_valid_batch, img_label_batch = tf.train.batch([img_p, label_p], batch_size = BATCH_SIZE)
keep_prob = tf.placeholder(tf.float32)
#logits = train_nn.Model(img_valid_batch, keep_prob)
#logits = train_nn_Training.Model(img_valid_batch, keep_prob)
logits = train_nn_TrainingAugmentation.Model(img_valid_batch, keep_prob)
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(img_label_batch, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess = sess, coord = coord)
    # BEST Model_Saver06_Final_Augmentation
    saver = tf.train.import_meta_graph('./Model_Saver06_Final_Augmentation/model_save.ckpt-13.meta')
    saver.restore(sess, './Model_Saver06_Final_Augmentation/model_save.ckpt-13')
    #saver = tf.train.import_meta_graph('./Model_Saver05_Final_Augmentation/model_save.ckpt-15.meta')
    #saver.restore(sess, './Model_Saver05_Final_Augmentation/model_save.ckpt-15')
    #saver = tf.train.import_meta_graph('./Model_Saver04_TrainingDataAugmentation/model_save.ckpt-18.meta')
    #saver.restore(sess, './Model_Saver04_TrainingDataAugmentation/model_save.ckpt-18')
    #saver = tf.train.import_meta_graph('./Model_Saver01/model_save.ckpt.meta')
    #saver.restore(sess, './Model_Saver01/model_save.ckpt')
    #saver = tf.train.import_meta_graph('./Model_Saver03_DataAugmentation/model_save.ckpt.meta')
    #saver.restore(sess, './Model_Saver03_DataAugmentation/model_save.ckpt')
    #saver = tf.train.import_meta_graph('./Model_Saver02_Training/model_save.ckpt-15.meta')
    #saver.restore(sess, './Model_Saver02_Training/model_save.ckpt-15')
    A = 0.0
    for i in range(6):
        ans = sess.run(accuracy, feed_dict = {keep_prob: 1.0})
        print(ans)
        A += ans
    print('Mean: ', A / 6)
    coord.request_stop()
    coord.join(threads)
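
The checkpoint step suffix ("-13") is hard-coded above; tf.train.latest_checkpoint can resolve the newest checkpoint in a save directory instead. A minimal sketch, using the directory name from the script above:

#!/bin/python3
# Hedged sketch: resolve the newest checkpoint in a save directory instead of
# hard-coding a step suffix; the returned path is what saver.restore() expects.
import tensorflow as tf

ckpt = tf.train.latest_checkpoint('./Model_Saver06_Final_Augmentation/')
print('latest checkpoint:', ckpt)   # None if the directory holds no checkpoint file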

6. Model prediction

  • Convert the test-set images into a TFRecord
#!/bin/python3
import tensorflow as tf
import os

def _int64_feature(v):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [v]))

def _bytes_feature(v):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [v]))

def Write_TFRecords(path, output):
    img = os.listdir(path)
    img.sort()
    writer = tf.python_io.TFRecordWriter(output)
    for i in range(len(img)):
        print(i, img[i])
        image = tf.gfile.FastGFile(path + img[i], 'rb')
        img_str = image.read()
        height, width, channels = tf.image.decode_png(img_str).eval().shape
        example = tf.train.Example(features = tf.train.Features(feature = \
        {'image': _bytes_feature(img_str), 'img_name': _bytes_feature(bytes(img[i], encoding = 'utf8'))}))
        writer.write(example.SerializeToString())
    writer.close()

if __name__ == '__main__':
    path = '/mnt/windows_E/DeepLearningClass/TestData/'
    output = './Weed_InputData_Final_Test.tfrecords'
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        Write_TFRecords(path, output)
  • Predict and write the results to a file
#!/bin/python3
import tensorflow as tf
import train_nn
import train_nn_TrainingAugmentation
import Read_Data

IMAGE_SIZE = 256
BATCH_SIZE = 6
with open('classes.txt') as f:
    con = f.read()
class_ = con.splitlines()
label = {}
for i in range(len(class_)):
    a, b = class_[i].split('#')
    label[int(a)] = b
print(label)
img_test, img_name = Read_Data.Read_Test_TFRecords('Weed_InputData_Final_Test*', IMAGE_SIZE)
img_test_batch, img_name_batch = tf.train.batch([img_test, img_name], batch_size = BATCH_SIZE)
keep_prob = tf.placeholder(tf.float32)
#logits = train_nn_TrainingAugmentation.Model(img_test_batch, keep_prob)
logits = train_nn.Model(img_test_batch, keep_prob)
pred = tf.argmax(tf.nn.softmax(logits), 1)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess = sess, coord = coord)
    # results04
    saver = tf.train.import_meta_graph('./Model_Saver03_DataAugmentation/model_save.ckpt.meta')
    saver.restore(sess, './Model_Saver03_DataAugmentation/model_save.ckpt')
    # results03
    #saver = tf.train.import_meta_graph('./Model_Saver01/model_save.ckpt.meta')
    #saver.restore(sess, './Model_Saver01/model_save.ckpt')
    # results02
    #saver = tf.train.import_meta_graph('./Model_Saver05_Final_Augmentation/model_save.ckpt-15.meta')
    #saver.restore(sess, './Model_Saver05_Final_Augmentation/model_save.ckpt-15')
    # results01
    #saver = tf.train.import_meta_graph('./Model_Saver06_Final_Augmentation/model_save.ckpt-13.meta')
    #saver.restore(sess, './Model_Saver06_Final_Augmentation/model_save.ckpt-13')
    with open('TestData_results04.csv', 'w') as fw:
        fw.write(',species\n')
        for i in range(79):
            print('----------------------------------------')
            #a = sess.run(img_name_batch)
            ans = sess.run(pred, feed_dict = {keep_prob: 1.0})
            #print(img_name_batch.eval())
            for _ in range(len(ans)):
                fw.write(',' + label[ans[_]] + '\n')
            #print(a)
            # each of the two eval() calls below pulls a fresh batch from the queue every time it runs
            #print(img_name_batch.eval())
            #print(img_name_batch.eval())
    coord.request_stop()
    coord.join(threads)
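
The inner loop above leaves the first CSV column empty, and the commented-out eval() calls show why: each separate eval()/run() pulls a fresh batch, so names fetched afterwards no longer line up with the predictions. A minimal sketch of an alternative loop body, assuming the tensors and objects defined in the script above (img_name_batch, pred, keep_prob, label, fw), that fetches names and predictions from the same batch:

# Hedged sketch: inside the `for i in range(79)` loop above, run names and
# predictions together so both come from the same dequeued batch.
names, ans = sess.run([img_name_batch, pred], feed_dict = {keep_prob: 1.0})
for n, a in zip(names, ans):
    fw.write(n.decode('utf8') + ',' + label[a] + '\n')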

  • See Share for more code
  • Under the Discussion tab of the Plant Seedlings Classification competition on Kaggle, quite a few participants have open-sourced their code, which is well worth studying
