Keras迁移学习

配置文件

config.json

{
  "name": "political",
  "root":"../experiments",
  "train_data_filename": "data/train.txt",
  "test_data_filename": "data/test.txt",
  "train_batch_size": 32,
  "test_batch_size": 32,
  "learning_rate": 0.001,
  "img_w": 224,
  "img_h": 224,
  "epochs": 1000,
  "workers": 8
}

config.py

#! -*- coding: utf-8 -*-
from bunch import Bunch
import shutil
import json
import os


def mkdir(dirname, delete):
    """Ensure that directory `dirname` exists.

    If `delete` is truthy and the directory already exists, it is wiped and
    recreated empty; otherwise an existing directory is left untouched.

    Args:
        dirname: Path of the directory to create.
        delete: Whether to remove an existing directory first.
    """
    if delete and os.path.exists(dirname):
        # Wipe the old contents so a fresh run starts from an empty directory.
        shutil.rmtree(dirname)
    # exist_ok=True makes this a no-op when the directory survived (delete=False),
    # collapsing the original's duplicated makedirs branches.
    os.makedirs(dirname, exist_ok=True)
    print('* Create %s succeed.' % dirname)


def read_json_file(filename):
    """Parse a JSON config file into an attribute-accessible Bunch.

    Args:
        filename: Path to the JSON file.

    Returns:
        A Bunch whose attributes mirror the JSON object's keys.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's default
    # locale encoding (the config may contain non-ASCII values).
    with open(filename, encoding='utf-8') as f:
        config_json = json.load(f)
    # Wrap outside the `with` — the file handle is no longer needed here.
    return Bunch(config_json)


def get_config(filename, delete=True):
    """Load the JSON config and prepare the experiment's output directories.

    Adds `logdir` and `ckdir` attributes (under <root>/<name>/) to the parsed
    config and creates both directories, optionally wiping existing ones.

    Args:
        filename: Path to the JSON configuration file.
        delete: Passed to `mkdir`; wipe pre-existing directories when True.

    Returns:
        The augmented config object.
    """
    config = read_json_file(filename)

    experiment_dir = os.path.join(config.root, config.name)
    config.logdir = os.path.join(experiment_dir, "logs/")
    config.ckdir = os.path.join(experiment_dir, "checkpoints/")

    for dirname in (config.logdir, config.ckdir):
        mkdir(dirname, delete)

    return config

数据读取方式

dataset.py

#! -*- coding: utf-8 -*-
from tensorflow import keras as k
from PIL import Image
import numpy as np
import random


class DataGenerator(k.utils.Sequence):
    """Keras Sequence yielding (image batch, label batch) pairs.

    Each line of `filename` is expected to be "<image_path> <integer_label>".
    Images are resized to (img_w, img_h) and returned as float32 arrays;
    labels are returned as an int array suitable for sparse categorical loss.
    """

    def __init__(self, filename, batch_size, img_w, img_h, train=True):
        self.filename = filename
        self.batch_size = batch_size
        self.img_w = img_w
        self.img_h = img_h
        self.train = train  # shuffle only in training mode
        self._init_data()

    def _init_data(self):
        # Read all "<path> <label>" pairs up front. Use a context manager so
        # the file handle is closed deterministically (the original
        # `for line in open(...)` leaked it).
        self.items = []
        with open(self.filename) as f:
            for line in f:
                name, label = line.strip('\n').split()
                self.items.append((name, label))
        if self.train:
            random.shuffle(self.items)

    def _parse(self, filename):
        # Load one image; PIL's resize takes (width, height).
        image = Image.open(filename)
        image = image.resize((self.img_w, self.img_h))
        image = np.asarray(image, dtype='float32')
        return image

    def __len__(self):
        # Batches per epoch; the last batch may be smaller than batch_size.
        # Builtin `int` instead of `np.int`: the alias was deprecated in
        # NumPy 1.20 and removed in 1.24 (would raise AttributeError).
        return int(np.ceil(len(self.items) / float(self.batch_size)))

    def __getitem__(self, idx):
        item_batch = self.items[idx * self.batch_size: (idx + 1) * self.batch_size]
        name_batch, label_batch = zip(*item_batch)
        x_batch = np.array([self._parse(filename) for filename in name_batch])
        # astype(int): same fix as __len__ — np.int no longer exists.
        y_batch = np.array(label_batch).astype(int)
        return x_batch, y_batch

训练

#! -*- coding: utf-8 -*-
from tensorflow import keras as k
from config import get_config
from dataset import DataGenerator
from models.resnet import myresnet50
import tensorflow as tf


# TF1-style session configuration: grow GPU memory on demand rather than
# grabbing it all up front, and fall back to CPU placement when an op has no
# GPU kernel. The intermediate GPUOptions gets its own name instead of
# shadowing `gpu_config` (the original rebound the same variable twice).
_gpu_options = tf.GPUOptions(
    allow_growth=True,
    # per_process_gpu_memory_fraction=0.99,
)
gpu_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
    gpu_options=_gpu_options,
)


def main():
    """Fine-tune an ImageNet-pretrained ResNet50 on the configured dataset.

    Reads hyperparameters from config.json, wires up data generators,
    compiles the model, and trains with early stopping, TensorBoard logging,
    and best-model checkpointing.

    Returns:
        The Keras History object produced by training.
    """
    # Project configuration (also creates the log/checkpoint directories).
    config = get_config('config.json', delete=True)

    # Bind the GPU-configured TF session to the Keras backend.
    sess = tf.Session(config=gpu_config)
    k.backend.set_session(sess)

    # Data generators for training and evaluation.
    train_gen = DataGenerator(config.train_data_filename,
                              config.train_batch_size,
                              config.img_w,
                              config.img_h,
                              train=True)

    test_gen = DataGenerator(config.test_data_filename,
                             config.test_batch_size,
                             config.img_w,
                             config.img_h,
                             train=False)

    print("* Train batch num: %d" % len(train_gen))
    print("* Test batch num: %d" % len(test_gen))

    # Model: pretrained ResNet50 backbone with a custom 3-class head.
    base_model = k.applications.ResNet50(weights='imagenet', include_top=False)
    model = myresnet50(base_model, 3)
    model.summary()

    model.compile(
        optimizer=k.optimizers.Adam(config.learning_rate),
        loss=k.losses.sparse_categorical_crossentropy,
        metrics=['accuracy']
    )

    # Callbacks. NOTE: the original filename template "{epoch:02.d}" is an
    # invalid format spec — str.format raises ValueError the first time a
    # checkpoint is written; "{epoch:02d}" is the correct form.
    callbacks = [
        k.callbacks.EarlyStopping(monitor='val_loss', patience=5),
        k.callbacks.TensorBoard(config.logdir),
        k.callbacks.ModelCheckpoint(config.ckdir + "weights.{epoch:02d}-{val_loss:.2f}.hdf5",
                                    monitor="val_loss",
                                    save_best_only=True,
                                    save_weights_only=False)
    ]

    # Train (fit_generator is the TF1-era API this file targets).
    history = model.fit_generator(train_gen,
                                  epochs=config.epochs,
                                  steps_per_epoch=len(train_gen),
                                  validation_data=test_gen,
                                  validation_steps=len(test_gen),
                                  workers=config.workers,
                                  max_queue_size=16,
                                  use_multiprocessing=True,
                                  callbacks=callbacks)
    return history


if __name__ == '__main__':
    main()

你可能感兴趣的：深度学习-算法