大部分使用 keras 的同学使用 fit() 或者 fit_generator() 进行模型训练, 这两个 api 对于刚接触深度学习的同学非常友好和方便,但是由于其是非常深度的封装,对于希望自定义训练过程的同学就显得不是那么方便(从 torch 转 keras 的同学可能更喜欢自定义训练过程),而且,对于 GAN 这种需要分步进行训练的模型,也无法直接使用 fit 或者 fit_generator 直接训练的。因此,keras 提供了 train_on_batch 这个 api,对一个 mini-batch 的数据进行梯度更新。
总结优点如下:
下面介绍 train_on_batch 的使用
y_pred = Model.train_on_batch(
x,
y=None,
sample_weight=None,
class_weight=None,
reset_metrics=True,
return_dict=False,
)
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'])
y_pred = model.train_on_batch(x=image,y=label)
# y_pred 为标量
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(Adam, loss=['binary_crossentropy'], metrics=['accuracy'])
y_pred = model.train_on_batch(x=image,y=label)
# len(y_pred) == 2, y_pred[0]为loss, y_pred[1]为accuracy
model = keras.models.Model(inputs=inputs, outputs=[output1, output2])
model.compile(Adam,
loss=['binary_crossentropy', 'binary_crossentropy'],
metrics=['accuracy', 'accuracy'])
y_pred = model.train_on_batch(x=image,y=label)
# 查看model.metrics_names来了解返回列表中每个值的含义
注意!训练时对 para_model 操作,保存时对 model 做操作
import tensorflow as tf
import keras
import os
# 初始化GPU的使用个数
gpu = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
gpu_num = len(gpu.split(','))
# model初始化
with tf.device('/cpu:0'):# 使用多GPU时,先在CPU上初始化模型
model = YourModel(input_size, num_classes)
model.load_weights('*.h5') # 如果有权重需要加载,在这里实现
para_model = keras.utils.multi_gpu_model(model, gpus=gpu_num) # 在GPU上初始化多GPU模型
para_model.compile(optimizer, loss=[...], metrics=[...]) # 编译多GPU模型
# 训练和验证,对 para_model 使用 train_on_batch
def train():
para_model.train_on_batch(...)
def evaluate():
para_model.test_on_batch(...)
# 保存模型,注意!训练时对 para_model 操作,保存时对 model 做操作
# 不要使用 para_model.save() 或者 para_model.save_weights(),否则加载时会出问题
model.save('*.h5')
model.save_weights('*.h5')
由于无法使用callback,我们使用 keras.backend.get_value() 和 keras.backend.set_value() 来获取和设置当前学习率。举个栗子, 实现一下最简单阶梯下降学习率,每10个epoch,学习率下降0.1倍
import keras.backend as K
for epoch in range(100):
train_one_epoch()
evaluate()
# 每10个epoch,lr缩小0.1倍
if epoch%10==0 and epoch!=0:
lr = K.get_value(model.optimizer.lr) # 获取当前学习率
lr = lr * 0.1 # 学习率缩小0.1倍
K.set_value(model.optimizer.lr, lr) # 设置学习率
torch 的 dataloader 是目前为止我用过最好用的数据加载方式,使用 train_on_batch 一部分的原因是因为我能够用 torch dataloader 载入数据,然后用 train_on_batch 对模型进行训练,通过合理的控制 cpu worker 的使用个数和 batch_size 的大小,使模型的训练效率最大化
# 定义 torch dataset
class Dataset(torch.utils.data.Dataset):
def __init__(self, root_list, transforms=None):
self.root_list = root_list
self.transforms = transforms
def __getitem__(self, idx):
# 假设是图像分类任务
image = ... # 读取单张图像
label = ... # 读取标签
if self.transforms is not None:
image = self.transforms(image)
return image, label # shape: (H,W,3), salar
def __len__(self):
return len(self.root_list)
# 自定义 collate_fn 使 dataloader 返回 numpy array
def collate_fn(batch):
# 这里的 batch 是 tuple 列表,[(image, label),(image, label),...]
image, label = zip(*batch)
image = np.asarray(image) # (batch_size, H, W, 3)
label = np.asarray(label) # (batch_size)
return image, label # 如果 datast 返回的图像是 ndarray,这样loader返回的也是 ndarray
# 定义dataset
train_dataset = Dataset(train_list)
valid_dataset = Dataset(valid_list)
test_dataset = Dataset(test_list)
# 定义 dataloader, 如果不使用自定义 collate_fn,
# 从 loader 取出的默认是 torch Tensor,需要做一个 .numpy()的转换
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
# 定义 train,evaluate,test
def train():
for i,(inputs, label) in enumerate(train_loader):
# 如果 inputs 和 label 是 torch Tensor
# 请用 inputs = inputs.numpy() 和 label = label.numpy() 转成 ndarray
y_pred = model.train_on_batch(inputs, label)
def evaluate():
for i,(inputs, label) in enumerate(valid_loader):
# 如果 inputs 和 label 是 Tensor,同上
y_pred = model.test_on_batch(inputs, label)
def test():
for i,(inputs, label) in enumerate(test_loader):
# 如果 inputs 和 label 是 Tensor,同上
y_pred = model.test_on_batch(inputs, label)
def run():
for epoch in num_epoch:
train()
evaluate()
test()
if __name__ == "__main__":
run()
还有一些使用 train_on_batch 的地方比如 GAN 的训练,这里就不介绍了,具体可以上 github 上搜索,例如 keras-dcgan。
keras 官方 api: train_on_batch