Keras框架常用两种训练函数model.fit()和model.fit_generator()函数。
fit( x, y, batch_size=32, epochs=10, verbose=1, callbacks=None,
validation_split=0.0, validation_data=None, shuffle=True,
class_weight=None, sample_weight=None, initial_epoch=0)
优点:
缺点:
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1,
callbacks=None, validation_data=None, validation_steps=None,
class_weight=None, max_queue_size=10, workers=1,
use_multiprocessing=False, shuffle=True, initial_epoch=0)
generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例,
以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:
一个 (inputs, targets) 元组
一个 (inputs, targets, sample_weights) 元组。
这个元组(生成器的单个输出)组成了单个的 batch。 因此,这个元组中的所有数组长度必须相同(与这一个 batch 的大小相等)。不同的 batch 可能大小不同。 例如,一个 epoch 的最后一个 batch 往往比其他 batch 要小, 如果数据集的尺寸不能被batch size 整除。 生成器将无限地在数据集上循环。当运行到第 steps_per_epoch 时,记一个 epoch 结束。
优点:
缺点:
import os
import numpy as np
from keras.applications.resnet50 import ResNet50
from keras.preprocessing.image import ImageDataGenerator
import keras
import cv2
import random
from tqdm import tqdm
# 读取图像数据和标签数据
def data_set(img_path, xml_path):
imgSet = [] # 图像数据集合
labelSet = [] # 标签数据集合
imgfiles = os.listdir(img_path)
# 读取数据
for index in tqdm(range(len(imgfiles))):
img = cv2.imread(os.path.join(img_path,imgfiles[index]), 1)
img = cv2.resize(img, (224, 224)) / 255.0
imgSet.append(img)
xmlfile = os.path.splitext(imgfiles[index])[0] + '.xml'
# 此处可改为自己的获取标签的函数
label = get_label(os.path.join(xml_path, xmlfile))
labelSet.append(label)
imgTrain = np.asarray(imgSet,dtype=np.float32)
labelTrain = np.asarray(labelSet,dtype=np.int32)
index = [i for i in range(len(imgfiles))]
random.shuffle(index)
imgTrain = imgTrain[index]
labelTrain = labelTrain[index]
return imgTrain, labelTrain
if __name__ == '__main__':
# 启动网络
model = ResNet50(
weights=None,
classes=4
)
# 设计优化器
model.compile(optimizer=keras.optimizers.Adam(0.001),loss='categorical_crossentropy',metrics=['accuracy'])
# train
filepath = 'weight/resnet50_epoch-{epoch:02d}_loss-{loss:.4f}.h5'
checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True,mode='auto', save_weights_only=True)
callbacks_list = [checkpoint]
model.fit(x=img, y=label, epochs=200, batch_size=4, validation_split=0.1, verbose=1, shuffle=True)
import os
import numpy as np
from keras.applications.resnet50 import ResNet50
from keras.preprocessing.image import ImageDataGenerator
import keras
import cv2
import random
from tqdm import tqdm
# 读取图像数据和标签数据
def data_set(img_path, xml_path):
imgfiles = os.listdir(img_path) # 从图像文件夹读取图像集合
aug = ImageDataGenerator(rotation_range=10,
zoom_range=0.15,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.15,
horizontal_flip=True,
vertical_flip=Tru
fill_mode="nearest")
# 读取数据 生成器模式
for index in tqdm(range(len(imgfiles))):
img = cv2.imread(os.path.join(img_path, imgfiles[index]), 1)
img = cv2.resize(img, (224, 224)) / 255.0
imgEX = np.expand_dims(img, axis=0)
xmlfile = os.path.splitext(imgfiles[index])[0] + '.xml'
label = get_label(os.path.join(xml_path, xmlfile))
labelEX = np.expand_dims(label, axis=0)
aug.flow(imgEX,labelEX,batch_size=4)
yield imgEX, labelEX
if __name__ == '__main__':
# 启动网络
model = ResNet50(
weights=None,
classes=4
)
# 设计优化器
model.compile(optimizer=keras.optimizers.Adam(0.001),loss='categorical_crossentropy',metrics=['accuracy'])
# train
filepath = 'weight/resnet50_epoch-{epoch:02d}_loss-{loss:.4f}.h5'
checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True,mode='auto', save_weights_only=True)
callbacks_list = [checkpoint]
model.fit_generator(data_generate, steps_per_epoch=4000, epochs=2, verbose=1,callbacks=callbacks_list)
generator生成器,相当于将普通函数中的return关键字,换成yield关键字,并且将输入数据X与标签Y按网络要求设置好即可。