配置文件
config.json
{
"name": "political",
"root":"../experiments",
"train_data_filename": "data/train.txt",
"test_data_filename": "data/test.txt",
"train_batch_size": 32,
"test_batch_size": 32,
"learning_rate": 0.001,
"img_w": 224,
"img_h": 224,
"epochs": 1000,
"workers": 8
}
config.py
from bunch import Bunch
import shutil
import json
import os
def mkdir(dirname, delete):
    """Ensure *dirname* exists, optionally wiping any previous contents.

    Args:
        dirname: path of the directory to create.
        delete: when True, an existing directory is removed and recreated
            empty; when False, an existing directory is left untouched.
    """
    # Wipe first if requested, then create idempotently. exist_ok=True
    # replaces the original exists()/makedirs() dance and avoids the
    # check-then-act race when the directory appears between the two calls.
    if delete and os.path.exists(dirname):
        shutil.rmtree(dirname)
    os.makedirs(dirname, exist_ok=True)
    print('* Create %s succeed.' % dirname)
def read_json_file(filename):
    """Load a JSON file and wrap it in a Bunch for attribute-style access."""
    with open(filename) as handle:
        raw = json.load(handle)
    return Bunch(raw)
def get_config(filename, delete=True):
    """Read the JSON config at *filename* and create its experiment dirs.

    Attaches ``logdir`` and ``ckdir`` paths under ``root/name`` to the
    returned config, creating both directories (wiping them first when
    *delete* is True).
    """
    config = read_json_file(filename)
    experiment_dir = os.path.join(config.root, config.name)
    config.logdir = os.path.join(experiment_dir, "logs/")
    config.ckdir = os.path.join(experiment_dir, "checkpoints/")
    for dirname in (config.logdir, config.ckdir):
        mkdir(dirname, delete)
    return config
数据读取方式
dataset.py
from tensorflow import keras as k
from PIL import Image
import numpy as np
import random
class DataGenerator(k.utils.Sequence):
    """Keras Sequence yielding (image_batch, label_batch) pairs.

    The listing file has one "<image-path> <label>" pair per line,
    separated by whitespace. Images are resized to (img_w, img_h) and
    returned as float32 arrays; labels are returned as an int array.
    """

    def __init__(self, filename, batch_size, img_w, img_h, train=True):
        self.filename = filename
        self.batch_size = batch_size
        self.img_w = img_w
        self.img_h = img_h
        self.train = train
        self._init_data()

    def _init_data(self):
        """Read (path, label) pairs; shuffle once when in training mode."""
        self.items = []
        # BUG FIX: the original iterated open(self.filename) directly and
        # never closed it, leaking the file handle.
        with open(self.filename) as f:
            for line in f:
                name, label = line.strip('\n').split()
                self.items.append((name, label))
        if self.train:
            random.shuffle(self.items)

    def _parse(self, filename):
        """Load one image, resize it, and return a float32 array."""
        image = Image.open(filename)
        image = image.resize((self.img_w, self.img_h))
        return np.asarray(image, dtype='float32')

    def __len__(self):
        # Number of batches = ceil(items / batch_size); the final partial
        # batch counts. BUG FIX: np.int was removed in NumPy >= 1.24 —
        # use the builtin int instead.
        return int(np.ceil(len(self.items) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return batch *idx* as (x_batch, y_batch) numpy arrays."""
        item_batch = self.items[idx * self.batch_size:(idx + 1) * self.batch_size]
        name_batch, label_batch = zip(*item_batch)
        x_batch = np.array([self._parse(name) for name in name_batch])
        # BUG FIX: np.int removed in NumPy >= 1.24; builtin int is equivalent.
        y_batch = np.array(label_batch).astype(int)
        return x_batch, y_batch
训练
from tensorflow import keras as k
from config import get_config
from dataset import DataGenerator
from models.resnet import myresnet50
import tensorflow as tf
# TF1-style session configuration: grow GPU memory on demand rather than
# grabbing the whole card, and let TF fall back to another device when an
# op cannot be placed as requested.
# CLARITY FIX: the original bound the name `gpu_config` first to a
# GPUOptions and then overwrote it with the ConfigProto; a distinct
# (private) name for the intermediate avoids the confusing shadowing.
# The final, exported `gpu_config` value is unchanged.
_gpu_options = tf.GPUOptions(
    allow_growth=True,
)
gpu_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
    gpu_options=_gpu_options,
)
def main():
    """Train myresnet50 (3 classes) per config.json, validating each epoch.

    Side effects: wipes/recreates the experiment log and checkpoint
    directories (get_config with delete=True), binds a TF session to the
    Keras backend, and writes TensorBoard logs plus best-val_loss
    checkpoints under the configured directories.
    """
    config = get_config('config.json', delete=True)
    sess = tf.Session(config=gpu_config)
    k.backend.set_session(sess)

    train_gen = DataGenerator(config.train_data_filename,
                              config.train_batch_size,
                              config.img_w,
                              config.img_h,
                              train=True)
    test_gen = DataGenerator(config.test_data_filename,
                             config.test_batch_size,
                             config.img_w,
                             config.img_h,
                             train=False)
    print("* Train batch num: %d" % len(train_gen))
    print("* Test batch num: %d" % len(test_gen))

    # ImageNet-pretrained backbone without its classifier head; myresnet50
    # presumably adds a 3-way head on top — confirm against models/resnet.py.
    base_model = k.applications.ResNet50(weights='imagenet', include_top=False)
    model = myresnet50(base_model, 3)
    model.summary()
    model.compile(
        optimizer=k.optimizers.Adam(config.learning_rate),
        loss=k.losses.sparse_categorical_crossentropy,
        metrics=['accuracy']
    )
    callbacks = [
        k.callbacks.EarlyStopping(monitor='val_loss', patience=5),
        k.callbacks.TensorBoard(config.logdir),
        # BUG FIX: the original filepath used "{epoch:02.d}", an invalid
        # format spec that raises ValueError the first time a checkpoint
        # is saved; "{epoch:02d}" is the intended zero-padded epoch.
        k.callbacks.ModelCheckpoint(config.ckdir + "weights.{epoch:02d}-{val_loss:.2f}.hdf5",
                                    monitor="val_loss",
                                    save_best_only=True,
                                    save_weights_only=False)
    ]
    # Return value intentionally discarded (the original assigned it to a
    # misspelled, unused local "histroy").
    model.fit_generator(train_gen,
                        epochs=config.epochs,
                        steps_per_epoch=len(train_gen),
                        validation_data=test_gen,
                        validation_steps=len(test_gen),
                        workers=config.workers,
                        max_queue_size=16,
                        use_multiprocessing=True,
                        callbacks=callbacks)


if __name__ == '__main__':
    main()