Keras+DataGenerator

keras.utils.Sequence is an index-based batch loader that fit_generator can consume safely even with multiple worker processes, because every batch is fetched by index rather than pulled from a shared iterator. The two files below implement such a Sequence for image/label pairs listed in a text file (dataset.py), plus a ResNet50 transfer-learning script that trains with it (main.py).

dataset.py

# -*- coding: utf-8 -*-
from tensorflow import keras as k
from tensorflow.keras.preprocessing import image
import numpy as np
import math


class DataGenerator(k.utils.Sequence):
    def __init__(self, filenames_, labels_, image_size, batch_size, shuffle=True, augment=None):
        self.filenames = filenames_
        self.labels = labels_
        self.image_size = image_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment
        self._shuffle()

    def _shuffle(self):
        self.indexes = np.arange(len(self.filenames))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def _parse(filename, target_size):
        img = image.load_img(filename, target_size=target_size)
        img = image.img_to_array(img)
        img = k.applications.resnet50.preprocess_input(img)
        return img

    def on_epoch_end(self):
        self._shuffle()

    def __len__(self):
        return math.ceil(len(self.filenames) / float(self.batch_size))

    def __getitem__(self, idx):
        batch_indexes = self.indexes[idx * self.batch_size: (idx + 1) * self.batch_size]
        batch_x = np.array([self._parse(self.filenames[i], self.image_size) for i in batch_indexes])
        if self.augment is not None:
            # Apply the user-supplied augmentation function to each image in the batch.
            batch_x = np.array([self.augment(x) for x in batch_x])
        batch_y = np.array([self.labels[i] for i in batch_indexes])
        return batch_x, batch_y


def get_items(filepath_):
    filenames_ = []
    labels_ = []
    with open(filepath_) as f:
        for line in f:
            filename, label = line.split()
            filenames_.append(filename)
            # Cast to int: sparse_categorical_crossentropy expects integer class indices.
            labels_.append(int(label))
    return filenames_, labels_


if __name__ == '__main__':
    filenames, labels = get_items('data/test.txt')
    dg = DataGenerator(filenames, labels, (224, 224), 32)

    for x, y in dg:
        print(x.shape, y)
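
get_items expects a plain-text annotation file with one filename/label pair per line, separated by whitespace. A hypothetical data/test.txt (file names and labels are made up for illustration):

images/cat_001.jpg 0
images/dog_042.jpg 1
images/bird_007.jpg 2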

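Because __getitem__ feeds each preprocessed image through the optional augment callable, any NumPy-in/NumPy-out function can be plugged in. A minimal sketch, assuming a random horizontal flip is the only augmentation wanted (random_flip is a hypothetical helper, not part of dataset.py):

import numpy as np
from dataset import DataGenerator, get_items

def random_flip(img):
    # Flip left-right with probability 0.5; safe after per-channel mean subtraction.
    if np.random.rand() < 0.5:
        return img[:, ::-1, :]
    return img

filenames, labels = get_items('data/train.txt')
train_gen = DataGenerator(filenames, labels, (224, 224), 32, augment=random_flip)
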
main.py

# -*- coding: utf-8 -*-
from tensorflow import keras as k
from dataset import DataGenerator, get_items
import tensorflow as tf
import shutil
import os

# Data define
TRAIN_FILE = 'data/train.txt'
TEST_FILE = 'data/test.txt'
BATCH_SIZE = 32

# Image define
IMG_W = 224
IMG_H = 224
IMG_C = 3

# training define
NUM_CLASSES = 3
LR = 0.0001

# Directories
LOGDIR = './logs/'
CHECKPOINTS = './checkpoints/'

# GPU configuration: allow_growth stops TF from grabbing all GPU memory up front
gpu_option = tf.GPUOptions(
    allow_growth=True,
    # per_process_gpu_memory_fraction=0.99,
)
gpu_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
    gpu_options=gpu_option,
)


def mkdir(path, delete=True):
    # Create path; if it already exists and delete=True, wipe it and recreate it.
    if os.path.exists(path):
        if delete:
            shutil.rmtree(path)
            os.makedirs(path)
    else:
        os.makedirs(path)


def main():
    # GPU configuration: bind the TF1 session to the Keras backend
    sess = tf.Session(config=gpu_config)
    k.backend.set_session(sess)

    # Datasets
    train_filenames, train_labels = get_items(TRAIN_FILE)
    test_filenames, test_labels = get_items(TEST_FILE)

    # Constants
    image_size = (IMG_W, IMG_H)
    batch_size = BATCH_SIZE
    num_classes = NUM_CLASSES
    lr = LR
    logdir = LOGDIR
    ckdir = CHECKPOINTS

    train_gen = DataGenerator(train_filenames, train_labels, image_size, batch_size, shuffle=True)
    test_gen = DataGenerator(test_filenames, test_labels, image_size, batch_size, shuffle=False)

    # Model
    base_model = k.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(IMG_H, IMG_W, IMG_C))
    # Freeze the pretrained backbone so only the new classification head is trained.
    base_model.trainable = False

    x = base_model.output
    x = k.layers.Flatten()(x)
    x = k.layers.Dense(1024, activation='relu')(x)
    outputs = k.layers.Dense(num_classes, activation='softmax')(x)

    model = k.Model(inputs=base_model.inputs, outputs=outputs)
    model.summary()

    # Compile
    model.compile(optimizer=k.optimizers.Adam(lr),
                  loss=k.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])

    mkdir(logdir, delete=True)
    mkdir(ckdir, delete=True)

    # Callbacks
    callbacks = [
        k.callbacks.EarlyStopping(monitor="val_loss", patience=10),
        k.callbacks.TensorBoard(logdir),
        k.callbacks.ModelCheckpoint(ckdir + "weights.{epoch:02d}-{val_loss:.2f}.hdf5",
                                    monitor="val_loss", save_best_only=True, save_weights_only=False)
    ]

    # Train: epochs is set high on purpose; EarlyStopping halts when val_loss stops improving
    model.fit_generator(train_gen,
                        epochs=1000,
                        steps_per_epoch=len(train_gen),
                        validation_data=test_gen,
                        validation_steps=len(test_gen),
                        workers=8,
                        use_multiprocessing=True,
                        callbacks=callbacks)


if __name__ == '__main__':
    main()
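
Because ModelCheckpoint is called with save_weights_only=False, each checkpoint is a full model that k.models.load_model can restore for inference. A minimal sketch; the checkpoint filename below is hypothetical and depends on the epoch and val_loss actually reached during training:

from tensorflow import keras as k
from dataset import DataGenerator, get_items

filenames, labels = get_items('data/test.txt')
test_gen = DataGenerator(filenames, labels, (224, 224), 32, shuffle=False)

# Restore a saved checkpoint and predict class indices for the test set.
model = k.models.load_model('./checkpoints/weights.10-0.25.hdf5')
preds = model.predict_generator(test_gen, steps=len(test_gen))
print(preds.argmax(axis=-1))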
